diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py new file mode 100644 index 0000000000000000000000000000000000000000..01da19b935545d23abe27ed70ca3aa85ed3200cc --- /dev/null +++ b/tests/brevitas/test_brevitas_avg_pool_export.py @@ -0,0 +1,23 @@ +import onnx # noqa +import brevitas.onnx as bo +from brevitas.nn import QuantAvgPool2d +import pytest + +export_onnx_path = "test_avg_pool.onnx" + + +@pytest.mark.parametrize("kernel_size", [7]) +@pytest.mark.parametrize("stride", [1]) +@pytest.mark.parametrize("signed", [False]) +@pytest.mark.parametrize("bit_width", [4]) +def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width): + ishape = (1, 1024, 7, 7) + + b_avgpool = QuantAvgPool2d( + kernel_size=kernel_size, + stride=stride, + signed=signed, + min_overall_bit_width=bit_width, + max_overall_bit_width=bit_width, + ) + bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path)