Skip to content
Snippets Groups Projects
Commit 133215aa authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Lookup] fix alignment for ext.mem. lookups

parent 993aecd4
No related branches found
No related tags found
No related merge requests found
......@@ -33,7 +33,6 @@ from math import ceil, log2
from finn.core.datatype import DataType
from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp
from finn.util.basic import roundup_to_integer_multiple
from finn.util.data_packing import (
npy_to_rtlsim_input,
numpy_to_hls_code,
......@@ -345,9 +344,8 @@ class Lookup(HLSCustomOp):
emb_elems_per_ext_mem_width = self.get_folded_output_shape()[-1]
ext_mem_emb_size = self.get_folded_output_shape()[-2]
ext_mem_emb_align = ceil(log2(ext_mem_emb_size))
align_factor = 2**ext_mem_emb_align
aligned_emb_dim = roundup_to_integer_multiple(emb_dim, align_factor)
pad_amount = aligned_emb_dim - emb_dim
align_factor = int((ext_mem_width / 8) * 2**ext_mem_emb_align)
pad_amount = align_factor - emb_dim
embeddings_padded = np.pad(embeddings, [(0, 0), (0, pad_amount)])
# reshape for packing the innermost dim
embeddings_padded = embeddings_padded.reshape(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment