From 40514ced69cb3f98af1e6ecae0d67108fbb702fa Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Thu, 10 Feb 2022 20:42:27 +0100
Subject: [PATCH] [Lookup] reverse output layout + associated fixes

---
 src/finn/custom_op/fpgadataflow/lookup.py      | 10 +++++++---
 tests/fpgadataflow/test_fpgadataflow_lookup.py |  2 +-
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/lookup.py b/src/finn/custom_op/fpgadataflow/lookup.py
index 2c78685a4..984483c26 100644
--- a/src/finn/custom_op/fpgadataflow/lookup.py
+++ b/src/finn/custom_op/fpgadataflow/lookup.py
@@ -186,7 +186,7 @@ class Lookup(HLSCustomOp):
         oshape_cpp_str = str(oshape).replace("(", "{").replace(")", "}")
 
         self.code_gen_dict["$DATAOUTSTREAM$"] = [
-            'apintstream2npy<%s, %s, %d, %s>(out, %s, "%s");'
+            'apintstream2npy<%s, %s, %d, %s>(out, %s, "%s", %s);'
             % (
                 packed_hls_type,
                 elem_hls_type,
@@ -194,6 +194,7 @@ class Lookup(HLSCustomOp):
                 npy_type,
                 oshape_cpp_str,
                 npy_out,
+                "false",
             )
         ]
 
@@ -245,8 +246,11 @@ class Lookup(HLSCustomOp):
         assert np.vectorize(edt.allowed)(
             embeddings
         ).all(), "Embeddings can't be expressed with type %s" % str(edt)
+        # reverse innermost dim in embeddings to remain compatible with
+        # how we normally encode the data in FINN
+        embeddings_rev = np.flip(embeddings, -1)
         embeddings_hls_code = numpy_to_hls_code(
-            embeddings, edt, "embeddings", True, False
+            embeddings_rev, edt, "embeddings", True, False
         )
         f_thresh = open(weight_filename, "w")
         f_thresh.write(embeddings_hls_code)
@@ -310,7 +314,7 @@ class Lookup(HLSCustomOp):
                 out_shape,
                 packed_bits,
                 target_bits,
-                reverse_inner=False,
+                reverse_inner=True,
             )
             # load and reshape output
             output = np.load(out_npy_path)
diff --git a/tests/fpgadataflow/test_fpgadataflow_lookup.py b/tests/fpgadataflow/test_fpgadataflow_lookup.py
index 45678bbdf..8f029adbb 100644
--- a/tests/fpgadataflow/test_fpgadataflow_lookup.py
+++ b/tests/fpgadataflow/test_fpgadataflow_lookup.py
@@ -124,7 +124,7 @@ def test_fpgadataflow_lookup(edt, embedding_cfg, exec_mode):
         model = model.transform(SetExecMode("cppsim"))
     elif exec_mode == "rtlsim":
         model = model.transform(GiveUniqueNodeNames())
-        model = model.transform(PrepareIP("xc7z020clg400-1", 10))
+        model = model.transform(PrepareIP("xczu3eg-sbva484-1-e", 10))
         model = model.transform(HLSSynthIP())
         model = model.transform(SetExecMode("rtlsim"))
         model = model.transform(PrepareRTLSim())
-- 
GitLab