From 133215aa7b79404ddd26007607653c2067a23ff6 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 25 Mar 2022 13:12:28 +0100
Subject: [PATCH] [Lookup] fix alignment for ext.mem. lookups

---
 src/finn/custom_op/fpgadataflow/lookup.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/lookup.py b/src/finn/custom_op/fpgadataflow/lookup.py
index 1a4c917fd..dcf67e4c4 100644
--- a/src/finn/custom_op/fpgadataflow/lookup.py
+++ b/src/finn/custom_op/fpgadataflow/lookup.py
@@ -33,7 +33,6 @@ from math import ceil, log2
 
 from finn.core.datatype import DataType
 from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp
-from finn.util.basic import roundup_to_integer_multiple
 from finn.util.data_packing import (
     npy_to_rtlsim_input,
     numpy_to_hls_code,
@@ -345,9 +344,8 @@ class Lookup(HLSCustomOp):
             emb_elems_per_ext_mem_width = self.get_folded_output_shape()[-1]
             ext_mem_emb_size = self.get_folded_output_shape()[-2]
             ext_mem_emb_align = ceil(log2(ext_mem_emb_size))
-            align_factor = 2**ext_mem_emb_align
-            aligned_emb_dim = roundup_to_integer_multiple(emb_dim, align_factor)
-            pad_amount = aligned_emb_dim - emb_dim
+            align_factor = int((ext_mem_width / 8) * 2**ext_mem_emb_align)
+            pad_amount = align_factor - emb_dim
             embeddings_padded = np.pad(embeddings, [(0, 0), (0, pad_amount)])
             # reshape for packing the innermost dim
             embeddings_padded = embeddings_padded.reshape(
-- 
GitLab