From 33aecc82f87fbe6d43412191e7bf7810c572e8e8 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Thu, 14 Oct 2021 23:31:11 +0200
Subject: [PATCH] [Test] restructure lookup test

---
 ..._lookup.py => test_fpgadataflow_lookup.py} | 52 ++++++++++---------
 1 file changed, 27 insertions(+), 25 deletions(-)
 rename tests/fpgadataflow/{test_lookup.py => test_fpgadataflow_lookup.py} (81%)

diff --git a/tests/fpgadataflow/test_lookup.py b/tests/fpgadataflow/test_fpgadataflow_lookup.py
similarity index 81%
rename from tests/fpgadataflow/test_lookup.py
rename to tests/fpgadataflow/test_fpgadataflow_lookup.py
index 11f5a27c9..45678bbdf 100644
--- a/tests/fpgadataflow/test_lookup.py
+++ b/tests/fpgadataflow/test_fpgadataflow_lookup.py
@@ -26,9 +26,9 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import pytest
+
 import numpy as np
-import onnx  # noqa
-import os
 import torch
 from brevitas.export import FINNManager
 from torch import nn
@@ -48,8 +48,6 @@ from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
 from finn.util.basic import gen_finn_dt_tensor
 
-tmpdir = os.environ["FINN_BUILD_DIR"]
-
 
 def make_lookup_model(embeddings, ishape, idt, edt):
     num_embeddings, embedding_dim = embeddings.shape
@@ -81,18 +79,23 @@ def make_lookup_model(embeddings, ishape, idt, edt):
     return model
 
 
-def test_lookup_export_convert():
-    export_path = tmpdir + "/test_lookup_export.onnx"
+# embedding DataType
+@pytest.mark.parametrize("edt", [DataType["FIXED<8,2>"]])
+# other embedding config
+@pytest.mark.parametrize(
+    "embedding_cfg", [(130, DataType["UINT8"], 25), (5145, DataType["UINT16"], 20)]
+)
+# execution mode
+@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
+@pytest.mark.vivado
+@pytest.mark.slow
+def test_fpgadataflow_lookup(edt, embedding_cfg, exec_mode):
     ishape = (1, 10)
-    idt = DataType["UINT8"]
-    edt = DataType["FIXED<8,2>"]
-    num_embeddings = 2 ** idt.bitwidth()
-    embedding_dim = 2
+    num_embeddings, idt, embedding_dim = embedding_cfg
     eshape = (num_embeddings, embedding_dim)
     exp_oshape = tuple(list(ishape) + [embedding_dim])
     embeddings = gen_finn_dt_tensor(edt, eshape)
     model = make_lookup_model(embeddings, ishape, idt, edt)
-    model.save(export_path)
     assert len(model.graph.node) == 1
     assert model.graph.node[0].op_type == "Gather"
     iname = model.graph.input[0].name
@@ -105,6 +108,7 @@ def test_lookup_export_convert():
     assert tuple(model.get_tensor_shape(oname)) == exp_oshape
     assert (model.get_initializer(ename) == embeddings).all()
     itensor = gen_finn_dt_tensor(idt, ishape).astype(np.int64)
+    itensor = np.clip(itensor, 0, num_embeddings - 1)
     ret = execute_onnx(model, {iname: itensor})
     exp_out = np.take(embeddings, itensor, axis=0)
     assert (exp_out == ret[oname]).all()
@@ -114,17 +118,15 @@ def test_lookup_export_convert():
     assert model.graph.node[0].input[0] == iname
     assert model.graph.node[0].input[1] == ename
     assert model.graph.node[0].output[0] == oname
-    # prepare and execute cppsim
-    model = model.transform(PrepareCppSim())
-    model = model.transform(CompileCppSim())
-    model = model.transform(SetExecMode("cppsim"))
-    ret_cppsim = execute_onnx(model, {iname: itensor})
-    assert (exp_out == ret_cppsim[oname]).all()
-    # prepare and execute rtlsim
-    model = model.transform(GiveUniqueNodeNames())
-    model = model.transform(PrepareIP("xc7z020clg400-1", 10))
-    model = model.transform(HLSSynthIP())
-    model = model.transform(SetExecMode("rtlsim"))
-    model = model.transform(PrepareRTLSim())
-    ret_rtlsim = execute_onnx(model, {iname: itensor})
-    assert (exp_out == ret_rtlsim[oname]).all()
+    if exec_mode == "cppsim":
+        model = model.transform(PrepareCppSim())
+        model = model.transform(CompileCppSim())
+        model = model.transform(SetExecMode("cppsim"))
+    elif exec_mode == "rtlsim":
+        model = model.transform(GiveUniqueNodeNames())
+        model = model.transform(PrepareIP("xc7z020clg400-1", 10))
+        model = model.transform(HLSSynthIP())
+        model = model.transform(SetExecMode("rtlsim"))
+        model = model.transform(PrepareRTLSim())
+    ret_sim = execute_onnx(model, {iname: itensor})
+    assert (exp_out == ret_sim[oname]).all()
-- 
GitLab