From b60468a641f2eb47fb236209d13eecb75a87e8dc Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Fri, 15 Oct 2021 15:49:59 +0100
Subject: [PATCH] Added QONNX_export parametrization to the
 test_brevitas_avg_pool_export test (required test modifications).
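
For reference, the QONNX export path exercised by the new QONNX_export
parameter boils down to the sketch below (condensed from the code added in
this patch; quant_avgpool, input_shape and the "model.onnx" path stand in
for the values used in the test):

    from brevitas.export.onnx.generic.manager import BrevitasONNXManager
    from qonnx.util.cleanup import cleanup as qonnx_cleanup

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

    # 1. export the Brevitas model in the generic (QONNX) ONNX format
    BrevitasONNXManager.export(quant_avgpool, input_shape, export_path="model.onnx")
    # 2. tidy up the exported model before conversion
    qonnx_cleanup("model.onnx", out_file="model.onnx")
    # 3. convert the QONNX model to the FINN-ONNX dialect used by the rest of the test
    model = ModelWrapper("model.onnx")
    model = model.transform(ConvertQONNXtoFINN())
    model.save("model.onnx")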

---
 docker/Dockerfile.finn                        |  2 +-
 .../brevitas/test_brevitas_avg_pool_export.py | 62 ++++++++++++-------
 2 files changed, 41 insertions(+), 23 deletions(-)

diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn
index d809d99b9..2464d505c 100644
--- a/docker/Dockerfile.finn
+++ b/docker/Dockerfile.finn
@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="352cb9c41676fa509f57f20a32c7362c6c09039a"
+ARG FINN_BASE_COMMIT="535b27013de83ff36925f2996745b12c9ba64d23"
 ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
 ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index 68e563da6..fc6d50f8c 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -31,14 +31,17 @@ import numpy as np
 import os
 import torch
 from brevitas.export import FINNManager
-from brevitas.nn import QuantAvgPool2d
-from brevitas.quant_tensor import QuantTensor
+from brevitas.export.onnx.generic.manager import BrevitasONNXManager
+from brevitas.nn import QuantAvgPool2d, QuantIdentity
+from qonnx.util.cleanup import cleanup as qonnx_cleanup
+from torch import nn
 
 import finn.core.onnx_exec as oxe
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
 from finn.util.basic import gen_finn_dt_tensor
 
 export_onnx_path = "test_brevitas_avg_pool_export.onnx"
@@ -51,12 +54,29 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx"
 @pytest.mark.parametrize("input_bit_width", [4, 8, 16])
 @pytest.mark.parametrize("channels", [2, 4])
 @pytest.mark.parametrize("idim", [7, 8])
+@pytest.mark.parametrize("QONNX_export", [False, True])
 def test_brevitas_avg_pool_export(
-    kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
+    kernel_size,
+    stride,
+    signed,
+    bit_width,
+    input_bit_width,
+    channels,
+    idim,
+    QONNX_export,
 ):
 
-    quant_avgpool = QuantAvgPool2d(
-        kernel_size=kernel_size, stride=stride, bit_width=bit_width
+    # To do a proper static export, Brevitas requires a quantized input tensor.
+    # For the BrevitasONNXManager these requirements are even more stringent:
+    # in-model quantization and de-quantization at the end are required.
+    quant_avgpool = nn.Sequential(
+        QuantIdentity(bit_width=input_bit_width, return_quant_tensor=True),
+        QuantAvgPool2d(
+            kernel_size=kernel_size,
+            stride=stride,
+            bit_width=bit_width,
+            return_quant_tensor=False,
+        ),
     )
     quant_avgpool.eval()
 
@@ -66,31 +86,29 @@ def test_brevitas_avg_pool_export(
     dtype = DataType[dt_name]
     input_shape = (1, channels, idim, idim)
     input_array = gen_finn_dt_tensor(dtype, input_shape)
-    # Brevitas QuantAvgPool layers need QuantTensors to export correctly
-    # which requires setting up a QuantTensor instance with the scale
-    # factor, zero point, bitwidth and signedness
-    scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
-        np.float32
-    )
-    input_tensor = torch.from_numpy(input_array * scale_array).float()
-    scale_tensor = torch.from_numpy(scale_array).float()
-    zp = torch.tensor(0.0)
-    input_quant_tensor = QuantTensor(
-        input_tensor, scale_tensor, zp, input_bit_width, signed
-    )
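+    # the QuantIdentity layer at the start of the model now handles input
+    # quantization, so a plain float tensor replaces the hand-built QuantTensor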
+    input_tensor = torch.from_numpy(input_array).float()
 
     # export
-    FINNManager.export(
-        quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
-    )
+    if QONNX_export:
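+        # export in the generic (QONNX) ONNX format, clean the model up and
+        # convert it to FINN-ONNX before running the usual FINN transformations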
+        BrevitasONNXManager.export(
+            quant_avgpool,
+            input_shape,
+            export_path=export_onnx_path,
+        )
+        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
+        model = ModelWrapper(export_onnx_path)
+        model = model.transform(ConvertQONNXtoFINN())
+        model.save(export_onnx_path)
+    else:
+        FINNManager.export(quant_avgpool, input_shape, export_path=export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
 
     # reference brevitas output
-    ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy()
+    ref_output_array = quant_avgpool(input_tensor).detach().numpy()
     # finn output
-    idict = {model.graph.input[0].name: input_array}
+    idict = {model.graph.input[0].name: input_tensor.detach().numpy()}
     odict = oxe.execute_onnx(model, idict, True)
     finn_output = odict[model.graph.output[0].name]
     # compare outputs
-- 
GitLab