From 6ec329d65a050396b29fe11fbc12681ecf3f74cb Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Fri, 15 Oct 2021 14:25:00 +0100
Subject: [PATCH] Added QONNX_export test to test_brevitas_QConv2d test.

---
 tests/brevitas/test_brevitas_QConv2d.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/tests/brevitas/test_brevitas_QConv2d.py b/tests/brevitas/test_brevitas_QConv2d.py
index c1f790946..9d042b85d 100644
--- a/tests/brevitas/test_brevitas_QConv2d.py
+++ b/tests/brevitas/test_brevitas_QConv2d.py
@@ -36,12 +36,15 @@ from brevitas.core.quant import QuantType
 from brevitas.core.restrict_val import RestrictValueType
 from brevitas.core.scaling import ScalingImplType
 from brevitas.core.stats import StatsOp
+from brevitas.export.onnx.generic.manager import BrevitasONNXManager
 from brevitas.nn import QuantConv2d
+from qonnx.util.cleanup import cleanup as qonnx_cleanup
 
 import finn.core.onnx_exec as oxe
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
 from finn.util.basic import gen_finn_dt_tensor
 
 export_onnx_path = "test_brevitas_conv.onnx"
@@ -50,7 +53,8 @@ export_onnx_path = "test_brevitas_conv.onnx"
 @pytest.mark.parametrize("dw", [False, True])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("in_channels", [32])
-def test_brevitas_QConv2d(dw, bias, in_channels):
+@pytest.mark.parametrize("QONNX_export", [False, True])
+def test_brevitas_QConv2d(dw, bias, in_channels, QONNX_export):
     ishape = (1, 32, 111, 111)
     if dw is True:
         groups = in_channels
@@ -89,7 +93,15 @@ def test_brevitas_QConv2d(dw, bias, in_channels):
     weight_tensor = gen_finn_dt_tensor(DataType.INT4, w_shape)
     b_conv.weight = torch.nn.Parameter(torch.from_numpy(weight_tensor).float())
     b_conv.eval()
-    bo.export_finn_onnx(b_conv, ishape, export_onnx_path)
+    if QONNX_export:
+        m_path = export_onnx_path
+        BrevitasONNXManager.export(b_conv, ishape, m_path)
+        qonnx_cleanup(m_path, out_file=m_path)
+        model = ModelWrapper(m_path)
+        model = model.transform(ConvertQONNXtoFINN())
+        model.save(m_path)
+    else:
+        bo.export_finn_onnx(b_conv, ishape, export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     inp_tensor = np.random.uniform(low=-1.0, high=1.0, size=ishape).astype(np.float32)
-- 
GitLab