From e014553c9654515346b24c886fd6ff36015aefcf Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Thu, 14 Oct 2021 16:56:45 +0200
Subject: [PATCH] [Test] cppsim in Lookup test

---
 tests/fpgadataflow/test_lookup.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/tests/fpgadataflow/test_lookup.py b/tests/fpgadataflow/test_lookup.py
index 14e046633..4b9d4e941 100644
--- a/tests/fpgadataflow/test_lookup.py
+++ b/tests/fpgadataflow/test_lookup.py
@@ -36,6 +36,10 @@ from torch import nn
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
 from finn.core.onnx_exec import execute_onnx
+from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
+from finn.transformation.fpgadataflow.convert_to_hls_layers import InferLookupLayer
+from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
+from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
 from finn.util.basic import gen_finn_dt_tensor
@@ -73,7 +77,7 @@ def make_lookup_model(embeddings, ishape, idt, edt):
     return model
 
 
-def test_lookup_export():
+def test_lookup_export_convert():
     export_path = tmpdir + "/test_lookup_export.onnx"
     ishape = (1, 10)
     idt = DataType["UINT8"]
@@ -100,3 +104,15 @@ def test_lookup_export():
     ret = execute_onnx(model, {iname: itensor})
     exp_out = np.take(embeddings, itensor, axis=0)
     assert (exp_out == ret[oname]).all()
+    # call transformation to convert to HLS and verify conversion
+    model = model.transform(InferLookupLayer())
+    assert model.graph.node[0].op_type == "Lookup"
+    assert model.graph.node[0].input[0] == iname
+    assert model.graph.node[0].input[1] == ename
+    assert model.graph.node[0].output[0] == oname
+    # prepare and execute cppsim
+    model = model.transform(PrepareCppSim())
+    model = model.transform(CompileCppSim())
+    model = model.transform(SetExecMode("cppsim"))
+    ret = execute_onnx(model, {iname: itensor})
+    assert (exp_out == ret[oname]).all()
-- 
GitLab