Commit 692179c7 authored by auphelia

[Test] Update avg pool test

parent a7f4b4a0
@@ -19,14 +19,17 @@ import pytest
 export_onnx_path = "test_avg_pool.onnx"
-@pytest.mark.parametrize("kernel_size", [7])
-@pytest.mark.parametrize("stride", [1])
-@pytest.mark.parametrize("signed", [False])
-@pytest.mark.parametrize("bit_width", [4])
-def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
-    ch = 4
-    ishape = (1, ch, 7, 7)
-    input_bit_width = 32
+@pytest.mark.parametrize("kernel_size", [2, 3])
+@pytest.mark.parametrize("stride", [1, 2])
+@pytest.mark.parametrize("signed", [False, True])
+@pytest.mark.parametrize("bit_width", [2, 4])
+@pytest.mark.parametrize("input_bit_width", [4, 8, 32])
+@pytest.mark.parametrize("channels", [2, 4])
+@pytest.mark.parametrize("idim", [7, 8])
+def test_brevitas_avg_pool_export(
+    kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
+):
+    ishape = (1, channels, idim, idim)
     ibw_tensor = torch.Tensor([input_bit_width])
     b_avgpool = QuantAvgPool2d(
@@ -39,7 +42,8 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
     )
     # call forward pass manually once to cache scale factor and bitwidth
     input_tensor = torch.from_numpy(np.zeros(ishape)).float()
-    output_scale = torch.from_numpy(np.ones((1, ch, 1, 1))).float()
+    scale = np.ones((1, channels, 1, 1))
+    output_scale = torch.from_numpy(scale).float()
     input_quant_tensor = pack_quant_tensor(
         tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
     )
@@ -56,6 +60,7 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
+    # execution with input tensor using integers and scale = 1
     # calculate golden output
     inp = gen_finn_dt_tensor(dtype, ishape)
     input_tensor = torch.from_numpy(inp).float()
@@ -71,4 +76,23 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
     produced = odict[model.graph.output[0].name]
     assert (expected == produced).all()
+    # execution with input tensor using float and scale != 1
+    scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
+        np.float32
+    )
+    inp_tensor = inp * scale
+    input_tensor = torch.from_numpy(inp_tensor).float()
+    input_scale = torch.from_numpy(scale).float()
+    input_quant_tensor = pack_quant_tensor(
+        tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor
+    )
+    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
+    # finn execution
+    idict = {model.graph.input[0].name: inp_tensor}
+    model.set_initializer(model.graph.input[1].name, scale)
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+    assert np.isclose(expected, produced, rtol=1e-3).all()
     os.remove(export_onnx_path)
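
For context: stacking @pytest.mark.parametrize decorators makes pytest run the Cartesian product of all value lists, so this commit widens the single original configuration into a full test grid. Below is a minimal sketch (not part of the commit) that counts the resulting combinations, using the same value lists as the diff above.

import itertools

# Value lists copied from the updated parametrize decorators above.
params = {
    "kernel_size": [2, 3],
    "stride": [1, 2],
    "signed": [False, True],
    "bit_width": [2, 4],
    "input_bit_width": [4, 8, 32],
    "channels": [2, 4],
    "idim": [7, 8],
}

# pytest generates one test case per element of the Cartesian product.
combinations = list(itertools.product(*params.values()))
print(len(combinations))  # 2 * 2 * 2 * 2 * 3 * 2 * 2 = 192 test cases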