Skip to content
Snippets Groups Projects
Commit bc78cba1 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Fixed test_brevitas_avg_pool_export to properly integrate the sign.

parent b60468a6
No related branches found
No related tags found
No related merge requests found
...@@ -32,7 +32,7 @@ import os ...@@ -32,7 +32,7 @@ import os
import torch import torch
from brevitas.export import FINNManager from brevitas.export import FINNManager
from brevitas.export.onnx.generic.manager import BrevitasONNXManager from brevitas.export.onnx.generic.manager import BrevitasONNXManager
from brevitas.nn import QuantAvgPool2d, QuantIdentity from brevitas.nn import QuantAvgPool2d, QuantIdentity, QuantReLU
from qonnx.util.cleanup import cleanup as qonnx_cleanup from qonnx.util.cleanup import cleanup as qonnx_cleanup
from torch import nn from torch import nn
...@@ -44,17 +44,17 @@ from finn.transformation.infer_shapes import InferShapes ...@@ -44,17 +44,17 @@ from finn.transformation.infer_shapes import InferShapes
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.util.basic import gen_finn_dt_tensor from finn.util.basic import gen_finn_dt_tensor
export_onnx_path = "test_brevitas_avg_pool_export.onnx" base_export_onnx_path = "test_brevitas_avg_pool_export.onnx"
@pytest.mark.parametrize("QONNX_export", [False, True])
@pytest.mark.parametrize("signed", [True, False])
@pytest.mark.parametrize("kernel_size", [2, 3]) @pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("signed", [True, False])
@pytest.mark.parametrize("bit_width", [2, 4]) @pytest.mark.parametrize("bit_width", [2, 4])
@pytest.mark.parametrize("input_bit_width", [4, 8, 16]) @pytest.mark.parametrize("input_bit_width", [4, 8, 16])
@pytest.mark.parametrize("channels", [2, 4]) @pytest.mark.parametrize("channels", [2, 4])
@pytest.mark.parametrize("idim", [7, 8]) @pytest.mark.parametrize("idim", [7, 8])
@pytest.mark.parametrize("QONNX_export", [False, True])
def test_brevitas_avg_pool_export( def test_brevitas_avg_pool_export(
kernel_size, kernel_size,
stride, stride,
...@@ -65,12 +65,24 @@ def test_brevitas_avg_pool_export( ...@@ -65,12 +65,24 @@ def test_brevitas_avg_pool_export(
idim, idim,
QONNX_export, QONNX_export,
): ):
export_onnx_path = base_export_onnx_path.replace(
".onnx", f"test_QONNX-{QONNX_export}.onnx"
)
# To do a proper static export Brevitas requires a quantized input tensor. # To do a proper static export Brevitas requires a quantized input tensor.
# For the BrevitasONNXManager these requirements are even more stringent, # For the BrevitasONNXManager these requirements are even more stringent,
# such that in-model quantization and de-quantization at the end are required. # such that in-model quantization and de-quantization at the end are required.
if signed:
# Signed is only supported by the QuantIdentity in FINN
act_layer = QuantIdentity(
signed=signed, bit_width=input_bit_width, return_quant_tensor=True
)
else:
# Unsigned is only supported by the QuantReLU in FINN
act_layer = QuantReLU(
signed=signed, bit_width=input_bit_width, return_quant_tensor=True
)
quant_avgpool = nn.Sequential( quant_avgpool = nn.Sequential(
QuantIdentity(bit_width=input_bit_width, return_quant_tensor=True), act_layer,
QuantAvgPool2d( QuantAvgPool2d(
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment