From 28cef50ccad919e0260125ecb9779d2439a827ff Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Tue, 19 Oct 2021 19:53:59 +0100
Subject: [PATCH] Added support for QONNX to test_brevitas_avg_pool_export
 test.

---
 docker/Dockerfile.finn                        |  2 +-
 .../brevitas/test_brevitas_avg_pool_export.py | 68 +++++++++++++++----
 2 files changed, 56 insertions(+), 14 deletions(-)

diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn
index c0640ac6f..39f612812 100644
--- a/docker/Dockerfile.finn
+++ b/docker/Dockerfile.finn
@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="52f0947b597687e1d7d336e1e175ccfc389648be"
+ARG FINN_BASE_COMMIT="78e4098ad3fc78f72db40b6a3cf29c82c2a567b1"
 ARG QONNX_COMMIT="6d55dce220c7f744ef23585686460b9370b672a0"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
 ARG BREVITAS_COMMIT="efc1217b94a71d616e3b4a37e56bd28a07c55be0"
diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index 68e563da6..1b38914a8 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -31,19 +31,23 @@ import numpy as np
 import os
 import torch
 from brevitas.export import FINNManager
+from brevitas.export.onnx.generic.manager import BrevitasONNXManager
 from brevitas.nn import QuantAvgPool2d
 from brevitas.quant_tensor import QuantTensor
+from qonnx.util.cleanup import cleanup as qonnx_cleanup
 
 import finn.core.onnx_exec as oxe
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
 from finn.util.basic import gen_finn_dt_tensor
 
-export_onnx_path = "test_brevitas_avg_pool_export.onnx"
+base_export_onnx_path = "test_brevitas_avg_pool_export.onnx"
 
 
+@pytest.mark.parametrize("QONNX_export", [False, True])
 @pytest.mark.parametrize("kernel_size", [2, 3])
 @pytest.mark.parametrize("stride", [1, 2])
 @pytest.mark.parametrize("signed", [True, False])
@@ -52,11 +56,23 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx"
 @pytest.mark.parametrize("channels", [2, 4])
 @pytest.mark.parametrize("idim", [7, 8])
 def test_brevitas_avg_pool_export(
-    kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
+    kernel_size,
+    stride,
+    signed,
+    bit_width,
+    input_bit_width,
+    channels,
+    idim,
+    QONNX_export,
 ):
-
+    export_onnx_path = base_export_onnx_path.replace(
+        ".onnx", f"test_QONNX-{QONNX_export}.onnx"
+    )
     quant_avgpool = QuantAvgPool2d(
-        kernel_size=kernel_size, stride=stride, bit_width=bit_width
+        kernel_size=kernel_size,
+        stride=stride,
+        bit_width=bit_width,
+        return_quant_tensor=False,
     )
     quant_avgpool.eval()
 
@@ -69,31 +85,57 @@ def test_brevitas_avg_pool_export(
     # Brevitas QuantAvgPool layers need QuantTensors to export correctly
     # which requires setting up a QuantTensor instance with the scale
     # factor, zero point, bitwidth and signedness
-    scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
-        np.float32
-    )
+    scale_array = np.ones((1, channels, 1, 1)).astype(np.float32)
+    scale_array *= 0.5
     input_tensor = torch.from_numpy(input_array * scale_array).float()
     scale_tensor = torch.from_numpy(scale_array).float()
     zp = torch.tensor(0.0)
     input_quant_tensor = QuantTensor(
-        input_tensor, scale_tensor, zp, input_bit_width, signed
+        input_tensor, scale_tensor, zp, input_bit_width, signed, training=False
     )
 
     # export
-    FINNManager.export(
-        quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
-    )
+    if QONNX_export:
+        BrevitasONNXManager.export(
+            quant_avgpool,
+            export_path=export_onnx_path,
+            input_t=input_quant_tensor,
+        )
+        model = ModelWrapper(export_onnx_path)
+
+        # Statically set the additional inputs generated by the BrevitasONNXManager
+        model.graph.input.remove(model.graph.input[3])
+        model.graph.input.remove(model.graph.input[2])
+        model.graph.input.remove(model.graph.input[1])
+        model.set_initializer("1", scale_array)
+        model.set_initializer("2", np.array(0.0).astype(np.float32))
+        model.set_initializer("3", np.array(input_bit_width).astype(np.float32))
+        model.save(export_onnx_path)
+
+        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
+        model = ModelWrapper(export_onnx_path)
+        model = model.transform(ConvertQONNXtoFINN())
+        model.save(export_onnx_path)
+    else:
+        FINNManager.export(
+            quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
+        )
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
 
     # reference brevitas output
-    ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy()
+    ref_output_array = quant_avgpool(input_quant_tensor).detach().numpy()
     # finn output
-    idict = {model.graph.input[0].name: input_array}
+    if QONNX_export:
+        # Manually apply the Quant tensor scaling for QONNX
+        idict = {model.graph.input[0].name: input_array * scale_array}
+    else:
+        idict = {model.graph.input[0].name: input_array}
     odict = oxe.execute_onnx(model, idict, True)
     finn_output = odict[model.graph.output[0].name]
     # compare outputs
     assert np.isclose(ref_output_array, finn_output).all()
     # cleanup
+    # assert False
     os.remove(export_onnx_path)
-- 
GitLab