From e014553c9654515346b24c886fd6ff36015aefcf Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Thu, 14 Oct 2021 16:56:45 +0200 Subject: [PATCH] [Test] cppsim in Lookup test --- tests/fpgadataflow/test_lookup.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/fpgadataflow/test_lookup.py b/tests/fpgadataflow/test_lookup.py index 14e046633..4b9d4e941 100644 --- a/tests/fpgadataflow/test_lookup.py +++ b/tests/fpgadataflow/test_lookup.py @@ -36,6 +36,10 @@ from torch import nn from finn.core.datatype import DataType from finn.core.modelwrapper import ModelWrapper from finn.core.onnx_exec import execute_onnx +from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim +from finn.transformation.fpgadataflow.convert_to_hls_layers import InferLookupLayer +from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.infer_shapes import InferShapes from finn.util.basic import gen_finn_dt_tensor @@ -73,7 +77,7 @@ def make_lookup_model(embeddings, ishape, idt, edt): return model -def test_lookup_export(): +def test_lookup_export_convert(): export_path = tmpdir + "/test_lookup_export.onnx" ishape = (1, 10) idt = DataType["UINT8"] @@ -100,3 +104,15 @@ def test_lookup_export(): ret = execute_onnx(model, {iname: itensor}) exp_out = np.take(embeddings, itensor, axis=0) assert (exp_out == ret[oname]).all() + # call transformation to convert to HLS and verify conversion + model = model.transform(InferLookupLayer()) + assert model.graph.node[0].op_type == "Lookup" + assert model.graph.node[0].input[0] == iname + assert model.graph.node[0].input[1] == ename + assert model.graph.node[0].output[0] == oname + # prepare and execute cppsim + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + model = model.transform(SetExecMode("cppsim")) + ret = execute_onnx(model, {iname: itensor}) + assert (exp_out == ret[oname]).all() -- GitLab