Skip to content
Snippets Groups Projects
Commit bc78cba1 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Fixed test_brevitas_avg_pool_export to properly integrate the sign.

parent b60468a6
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,7 @@ import os
import torch
from brevitas.export import FINNManager
from brevitas.export.onnx.generic.manager import BrevitasONNXManager
from brevitas.nn import QuantAvgPool2d, QuantIdentity
from brevitas.nn import QuantAvgPool2d, QuantIdentity, QuantReLU
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from torch import nn
......@@ -44,17 +44,17 @@ from finn.transformation.infer_shapes import InferShapes
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.util.basic import gen_finn_dt_tensor
export_onnx_path = "test_brevitas_avg_pool_export.onnx"
base_export_onnx_path = "test_brevitas_avg_pool_export.onnx"
@pytest.mark.parametrize("QONNX_export", [False, True])
@pytest.mark.parametrize("signed", [True, False])
@pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("signed", [True, False])
@pytest.mark.parametrize("bit_width", [2, 4])
@pytest.mark.parametrize("input_bit_width", [4, 8, 16])
@pytest.mark.parametrize("channels", [2, 4])
@pytest.mark.parametrize("idim", [7, 8])
@pytest.mark.parametrize("QONNX_export", [False, True])
def test_brevitas_avg_pool_export(
kernel_size,
stride,
......@@ -65,12 +65,24 @@ def test_brevitas_avg_pool_export(
idim,
QONNX_export,
):
export_onnx_path = base_export_onnx_path.replace(
".onnx", f"test_QONNX-{QONNX_export}.onnx"
)
# To do a proper static export Brevitas requires a quantized input tensor.
# For the BrevitasONNXManager these requirements are even more stringent,
# such that in-model quantization and de-quantization at the end are required.
if signed:
# Signed is only supported by the QuantIdentity in FINN
act_layer = QuantIdentity(
signed=signed, bit_width=input_bit_width, return_quant_tensor=True
)
else:
# Unsigned is only supported by the QuantReLU in FINN
act_layer = QuantReLU(
signed=signed, bit_width=input_bit_width, return_quant_tensor=True
)
quant_avgpool = nn.Sequential(
QuantIdentity(bit_width=input_bit_width, return_quant_tensor=True),
act_layer,
QuantAvgPool2d(
kernel_size=kernel_size,
stride=stride,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment