From f39486184f73ff88601ec6f47cfbad4ce309f278 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Fri, 15 Oct 2021 14:20:38 +0100
Subject: [PATCH] Added QONNX_export test to test_brevitas_cnv_export_exec
 test.

---
 docker/Dockerfile.finn              |  2 +-
 tests/brevitas/test_brevitas_cnv.py | 16 ++++++++++++++--
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn
index 3c165e7f8..d809d99b9 100644
--- a/docker/Dockerfile.finn
+++ b/docker/Dockerfile.finn
@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="22886049c96048d7150efe261a82b9c3f469c100"
+ARG FINN_BASE_COMMIT="352cb9c41676fa509f57f20a32c7362c6c09039a"
 ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
 ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
diff --git a/tests/brevitas/test_brevitas_cnv.py b/tests/brevitas/test_brevitas_cnv.py
index 8a1783ae9..78ca36136 100644
--- a/tests/brevitas/test_brevitas_cnv.py
+++ b/tests/brevitas/test_brevitas_cnv.py
@@ -34,12 +34,15 @@ import brevitas.onnx as bo
 import numpy as np
 import os
 import torch
+from brevitas.export.onnx.generic.manager import BrevitasONNXManager
+from qonnx.util.cleanup import cleanup as qonnx_cleanup
 
 import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import GiveUniqueNodeNames, RemoveStaticGraphInputs
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
 from finn.util.test import get_test_model_trained
 
 export_onnx_path = "test_brevitas_cnv.onnx"
@@ -47,11 +50,20 @@ export_onnx_path = "test_brevitas_cnv.onnx"
 
 @pytest.mark.parametrize("abits", [1, 2])
 @pytest.mark.parametrize("wbits", [1, 2])
-def test_brevitas_cnv_export_exec(wbits, abits):
+@pytest.mark.parametrize("QONNX_export", [False, True])
+def test_brevitas_cnv_export_exec(wbits, abits, QONNX_export):
     if wbits > abits:
         pytest.skip("No wbits > abits cases at the moment")
     cnv = get_test_model_trained("CNV", wbits, abits)
-    bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
+    ishape = (1, 3, 32, 32)
+    if QONNX_export:
+        BrevitasONNXManager.export(cnv, ishape, export_onnx_path)
+        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
+        model = ModelWrapper(export_onnx_path)
+        model = model.transform(ConvertQONNXtoFINN())
+        model.save(export_onnx_path)
+    else:
+        bo.export_finn_onnx(cnv, ishape, export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(InferShapes())
-- 
GitLab