Skip to content
Snippets Groups Projects
Commit 31daf6d2 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] use new avgpool iface to fix avgpool export test

parent 9b521da3
No related branches found
No related tags found
No related merge requests found
......@@ -23,7 +23,7 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx"
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("signed", [False, True])
@pytest.mark.parametrize("bit_width", [2, 4])
@pytest.mark.parametrize("input_bit_width", [4, 8, 32])
@pytest.mark.parametrize("input_bit_width", [4, 8, 16])
@pytest.mark.parametrize("channels", [2, 4])
@pytest.mark.parametrize("idim", [7, 8])
def test_brevitas_avg_pool_export(
......@@ -35,9 +35,7 @@ def test_brevitas_avg_pool_export(
b_avgpool = QuantAvgPool2d(
kernel_size=kernel_size,
stride=stride,
signed=signed,
min_overall_bit_width=bit_width,
max_overall_bit_width=bit_width,
bit_width=bit_width,
quant_type=QuantType.INT,
)
# call forward pass manually once to cache scale factor and bitwidth
......@@ -55,7 +53,7 @@ def test_brevitas_avg_pool_export(
prefix = "INT"
else:
prefix = "UINT"
dt_name = prefix + str(input_bit_width // 2)
dt_name = prefix + str(input_bit_width)
dtype = DataType[dt_name]
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment