From bc78cba1f4bb379e60bfb2668aa3e9d9b3739977 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Fri, 15 Oct 2021 16:18:12 +0100
Subject: [PATCH] Fixed test_brevitas_avg_pool_export to properly integrate the
 sign.

---
 .../brevitas/test_brevitas_avg_pool_export.py | 24 ++++++++++++++-----
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index fc6d50f8c..e202b6328 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -32,7 +32,7 @@ import os
 import torch
 from brevitas.export import FINNManager
 from brevitas.export.onnx.generic.manager import BrevitasONNXManager
-from brevitas.nn import QuantAvgPool2d, QuantIdentity
+from brevitas.nn import QuantAvgPool2d, QuantIdentity, QuantReLU
 from qonnx.util.cleanup import cleanup as qonnx_cleanup
 from torch import nn
 
@@ -44,17 +44,17 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
 from finn.util.basic import gen_finn_dt_tensor
 
-export_onnx_path = "test_brevitas_avg_pool_export.onnx"
+base_export_onnx_path = "test_brevitas_avg_pool_export.onnx"
 
 
+@pytest.mark.parametrize("QONNX_export", [False, True])
+@pytest.mark.parametrize("signed", [True, False])
 @pytest.mark.parametrize("kernel_size", [2, 3])
 @pytest.mark.parametrize("stride", [1, 2])
-@pytest.mark.parametrize("signed", [True, False])
 @pytest.mark.parametrize("bit_width", [2, 4])
 @pytest.mark.parametrize("input_bit_width", [4, 8, 16])
 @pytest.mark.parametrize("channels", [2, 4])
 @pytest.mark.parametrize("idim", [7, 8])
-@pytest.mark.parametrize("QONNX_export", [False, True])
 def test_brevitas_avg_pool_export(
     kernel_size,
     stride,
@@ -65,12 +65,24 @@ def test_brevitas_avg_pool_export(
     idim,
     QONNX_export,
 ):
-
+    export_onnx_path = base_export_onnx_path.replace(
+        ".onnx", f"test_QONNX-{QONNX_export}.onnx"
+    )
     # To do a proper static export Brevitas requires a quantized input tensor.
     # For the BrevitasONNXManager these requirements are even more stringent,
     # such that in-model quantization and de-quantization at the end are required.
+    if signed:
+        # Signed is only supported by the QuantIdentity in FINN
+        act_layer = QuantIdentity(
+            signed=signed, bit_width=input_bit_width, return_quant_tensor=True
+        )
+    else:
+        # Unsigned is only supported by the QuantReLU in FINN
+        act_layer = QuantReLU(
+            signed=signed, bit_width=input_bit_width, return_quant_tensor=True
+        )
     quant_avgpool = nn.Sequential(
-        QuantIdentity(bit_width=input_bit_width, return_quant_tensor=True),
+        act_layer,
         QuantAvgPool2d(
             kernel_size=kernel_size,
             stride=stride,
-- 
GitLab