From bc78cba1f4bb379e60bfb2668aa3e9d9b3739977 Mon Sep 17 00:00:00 2001 From: Hendrik Borras <hendrikborras@web.de> Date: Fri, 15 Oct 2021 16:18:12 +0100 Subject: [PATCH] Fixed test_brevitas_avg_pool_export to properly integrate the sign. --- .../brevitas/test_brevitas_avg_pool_export.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py index fc6d50f8c..e202b6328 100644 --- a/tests/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas/test_brevitas_avg_pool_export.py @@ -32,7 +32,7 @@ import os import torch from brevitas.export import FINNManager from brevitas.export.onnx.generic.manager import BrevitasONNXManager -from brevitas.nn import QuantAvgPool2d, QuantIdentity +from brevitas.nn import QuantAvgPool2d, QuantIdentity, QuantReLU from qonnx.util.cleanup import cleanup as qonnx_cleanup from torch import nn @@ -44,17 +44,17 @@ from finn.transformation.infer_shapes import InferShapes from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN from finn.util.basic import gen_finn_dt_tensor -export_onnx_path = "test_brevitas_avg_pool_export.onnx" +base_export_onnx_path = "test_brevitas_avg_pool_export.onnx" +@pytest.mark.parametrize("QONNX_export", [False, True]) +@pytest.mark.parametrize("signed", [True, False]) @pytest.mark.parametrize("kernel_size", [2, 3]) @pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("signed", [True, False]) @pytest.mark.parametrize("bit_width", [2, 4]) @pytest.mark.parametrize("input_bit_width", [4, 8, 16]) @pytest.mark.parametrize("channels", [2, 4]) @pytest.mark.parametrize("idim", [7, 8]) -@pytest.mark.parametrize("QONNX_export", [False, True]) def test_brevitas_avg_pool_export( kernel_size, stride, @@ -65,12 +65,24 @@ def test_brevitas_avg_pool_export( idim, QONNX_export, ): - + export_onnx_path = base_export_onnx_path.replace( + ".onnx", f"test_QONNX-{QONNX_export}.onnx" + ) # To do a proper static export Brevitas requires a quantized input tensor. # For the BrevitasONNXManager these requirements are even more stringent, # such that in-model quantization and de-quantization at the end are required. + if signed: + # Signed is only supported by the QuantIdentity in FINN + act_layer = QuantIdentity( + signed=signed, bit_width=input_bit_width, return_quant_tensor=True + ) + else: + # Unsigned is only supported by the QuantReLU in FINN + act_layer = QuantReLU( + signed=signed, bit_width=input_bit_width, return_quant_tensor=True + ) quant_avgpool = nn.Sequential( - QuantIdentity(bit_width=input_bit_width, return_quant_tensor=True), + act_layer, QuantAvgPool2d( kernel_size=kernel_size, stride=stride, -- GitLab