From 7e04af91e497d7cc5a15016b24b59872eb26a1cd Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Tue, 19 Oct 2021 14:22:38 +0100
Subject: [PATCH] Enable QONNX ingestion for  test_end2end_cybsec_mlp_* tests.

---
 tests/end2end/test_end2end_cybsec_mlp.py | 71 +++++++++++++++++-------
 1 file changed, 52 insertions(+), 19 deletions(-)

diff --git a/tests/end2end/test_end2end_cybsec_mlp.py b/tests/end2end/test_end2end_cybsec_mlp.py
index 7b4cebb52..d69a87ffd 100644
--- a/tests/end2end/test_end2end_cybsec_mlp.py
+++ b/tests/end2end/test_end2end_cybsec_mlp.py
@@ -40,13 +40,16 @@ import torch
 import torch.nn as nn
 import wget
 from brevitas.core.quant import QuantType
+from brevitas.export.onnx.generic.manager import BrevitasONNXManager
 from brevitas.nn import QuantIdentity, QuantLinear, QuantReLU
 from brevitas.quant_tensor import QuantTensor
+from qonnx.util.cleanup import cleanup as qonnx_cleanup
 
 import finn.builder.build_dataflow as build
 import finn.builder.build_dataflow_config as build_cfg
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
 from finn.util.basic import make_build_dir
 from finn.util.test import get_build_env, load_test_checkpoint_or_skip
 
@@ -55,13 +58,13 @@ build_kind = "zynq"
 build_dir = os.environ["FINN_BUILD_DIR"]
 
 
-def get_checkpoint_name(step):
+def get_checkpoint_name(step, QONNX_export):
     if step == "build":
         # checkpoint for build step is an entire dir
-        return build_dir + "/end2end_cybsecmlp_build"
+        return build_dir + "/end2end_cybsecmlp_build_QONNX-%d" % (QONNX_export)
     else:
         # other checkpoints are onnx files
-        return build_dir + "/end2end_cybsecmlp_%s.onnx" % (step)
+        return build_dir + "/end2end_cybsecmlp_QONNX-%d_%s.onnx" % (QONNX_export, step)
 
 
 class CybSecMLPForExport(nn.Module):
@@ -82,7 +85,8 @@ class CybSecMLPForExport(nn.Module):
         return out_final
 
 
-def test_end2end_cybsec_mlp_export():
+@pytest.mark.parametrize("QONNX_export", [False, True])
+def test_end2end_cybsec_mlp_export(QONNX_export):
     assets_dir = pk.resource_filename("finn.qnn-data", "cybsec-mlp/")
     # load up trained net in Brevitas
     input_size = 593
@@ -116,7 +120,7 @@ def test_end2end_cybsec_mlp_export():
     W_new = np.pad(W_orig, [(0, 0), (0, 7)])
     model[0].weight.data = torch.from_numpy(W_new)
     model_for_export = CybSecMLPForExport(model)
-    export_onnx_path = get_checkpoint_name("export")
+    export_onnx_path = get_checkpoint_name("export", QONNX_export)
     input_shape = (1, 600)
     # create a QuantTensor instance to mark the input as bipolar during export
     input_a = np.random.randint(0, 1, size=input_shape).astype(np.float32)
@@ -127,32 +131,60 @@ def test_end2end_cybsec_mlp_export():
         input_t, scale=torch.tensor(scale), bit_width=torch.tensor(1.0), signed=True
     )
 
-    bo.export_finn_onnx(
-        model_for_export, export_path=export_onnx_path, input_t=input_qt
-    )
+    if QONNX_export:
+        # With the BrevitasONNXManager we need to manually set
+        # the FINN DataType at the input
+        BrevitasONNXManager.export(
+            model_for_export, input_shape, export_path=export_onnx_path
+        )
+        model = ModelWrapper(export_onnx_path)
+        model.set_tensor_datatype(model.graph.input[0].name, DataType["BIPOLAR"])
+        model.save(export_onnx_path)
+        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
+        model = ModelWrapper(export_onnx_path)
+        model = model.transform(ConvertQONNXtoFINN())
+        model.save(export_onnx_path)
+    else:
+        bo.export_finn_onnx(
+            model_for_export, export_path=export_onnx_path, input_t=input_qt
+        )
     assert os.path.isfile(export_onnx_path)
     # fix input datatype
     finn_model = ModelWrapper(export_onnx_path)
     finnonnx_in_tensor_name = finn_model.graph.input[0].name
     assert tuple(finn_model.get_tensor_shape(finnonnx_in_tensor_name)) == (1, 600)
     # verify a few exported ops
-    assert finn_model.graph.node[1].op_type == "Add"
-    assert finn_model.graph.node[2].op_type == "Div"
-    assert finn_model.graph.node[3].op_type == "MatMul"
-    assert finn_model.graph.node[-1].op_type == "MultiThreshold"
+    if QONNX_export:
+        # The first "Mul" node dosen't exist in the QONNX export,
+        # because the QuantTensor scale is not exported.
+        # However, this node would have been unity scale anyways and
+        # the models are still equivalent.
+        assert finn_model.graph.node[0].op_type == "Add"
+        assert finn_model.graph.node[1].op_type == "Div"
+        assert finn_model.graph.node[2].op_type == "MatMul"
+        assert finn_model.graph.node[-1].op_type == "MultiThreshold"
+    else:
+        assert finn_model.graph.node[0].op_type == "Mul"
+        assert finn_model.get_initializer(finn_model.graph.node[0].input[1]) == 1.0
+        assert finn_model.graph.node[1].op_type == "Add"
+        assert finn_model.graph.node[2].op_type == "Div"
+        assert finn_model.graph.node[3].op_type == "MatMul"
+        assert finn_model.graph.node[-1].op_type == "MultiThreshold"
     # verify datatypes on some tensors
     assert finn_model.get_tensor_datatype(finnonnx_in_tensor_name) == DataType.BIPOLAR
-    first_matmul_w_name = finn_model.graph.node[3].input[1]
+    first_matmul_w_name = finn_model.get_nodes_by_op_type("MatMul")[0].input[1]
     assert finn_model.get_tensor_datatype(first_matmul_w_name) == DataType.INT2
 
 
 @pytest.mark.slow
 @pytest.mark.vivado
-def test_end2end_cybsec_mlp_build():
-    model_file = get_checkpoint_name("export")
+# @pytest.mark.parametrize("QONNX_export", [False, True])
+@pytest.mark.parametrize("QONNX_export", [True])
+def test_end2end_cybsec_mlp_build(QONNX_export):
+    model_file = get_checkpoint_name("export", QONNX_export)
     load_test_checkpoint_or_skip(model_file)
     build_env = get_build_env(build_kind, target_clk_ns)
-    output_dir = make_build_dir("test_end2end_cybsec_mlp_build")
+    output_dir = make_build_dir(f"test_end2end_cybsec_mlp_build_QONNX-{QONNX_export}")
 
     cfg = build.DataflowBuildConfig(
         output_dir=output_dir,
@@ -190,13 +222,14 @@ def test_end2end_cybsec_mlp_build():
         est_res_dict = json.load(f)
         assert est_res_dict["total"]["LUT"] == 11360.0
         assert est_res_dict["total"]["BRAM_18K"] == 36.0
-    shutil.copytree(output_dir + "/deploy", get_checkpoint_name("build"))
+    shutil.copytree(output_dir + "/deploy", get_checkpoint_name("build", QONNX_export))
 
 
-def test_end2end_cybsec_mlp_run_on_hw():
+@pytest.mark.parametrize("QONNX_export", [False, True])
+def test_end2end_cybsec_mlp_run_on_hw(QONNX_export):
     build_env = get_build_env(build_kind, target_clk_ns)
     assets_dir = pk.resource_filename("finn.qnn-data", "cybsec-mlp/")
-    deploy_dir = get_checkpoint_name("build")
+    deploy_dir = get_checkpoint_name("build", QONNX_export)
     if not os.path.isdir(deploy_dir):
         pytest.skip(deploy_dir + " not found from previous test step, skipping")
     driver_dir = deploy_dir + "/driver"
-- 
GitLab