From 692179c7a6a91981aa27fcd318558cdb84e1be62 Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Tue, 9 Jun 2020 13:55:53 +0100 Subject: [PATCH] [Test] Update avg pool test --- .../brevitas/test_brevitas_avg_pool_export.py | 42 +++++++++++++++---- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py index 0aff5fbf8..a423b89ff 100644 --- a/tests/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas/test_brevitas_avg_pool_export.py @@ -19,14 +19,17 @@ import pytest export_onnx_path = "test_avg_pool.onnx" -@pytest.mark.parametrize("kernel_size", [7]) -@pytest.mark.parametrize("stride", [1]) -@pytest.mark.parametrize("signed", [False]) -@pytest.mark.parametrize("bit_width", [4]) -def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width): - ch = 4 - ishape = (1, ch, 7, 7) - input_bit_width = 32 +@pytest.mark.parametrize("kernel_size", [2, 3]) +@pytest.mark.parametrize("stride", [1, 2]) +@pytest.mark.parametrize("signed", [False, True]) +@pytest.mark.parametrize("bit_width", [2, 4]) +@pytest.mark.parametrize("input_bit_width", [4, 8, 32]) +@pytest.mark.parametrize("channels", [2, 4]) +@pytest.mark.parametrize("idim", [7, 8]) +def test_brevitas_avg_pool_export( + kernel_size, stride, signed, bit_width, input_bit_width, channels, idim +): + ishape = (1, channels, idim, idim) ibw_tensor = torch.Tensor([input_bit_width]) b_avgpool = QuantAvgPool2d( @@ -39,7 +42,8 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width): ) # call forward pass manually once to cache scale factor and bitwidth input_tensor = torch.from_numpy(np.zeros(ishape)).float() - output_scale = torch.from_numpy(np.ones((1, ch, 1, 1))).float() + scale = np.ones((1, channels, 1, 1)) + output_scale = torch.from_numpy(scale).float() input_quant_tensor = pack_quant_tensor( tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor ) @@ -56,6 +60,7 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width): model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) + # execution with input tensor using integers and scale = 1 # calculate golden output inp = gen_finn_dt_tensor(dtype, ishape) input_tensor = torch.from_numpy(inp).float() @@ -71,4 +76,23 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width): produced = odict[model.graph.output[0].name] assert (expected == produced).all() + # execution with input tensor using float and scale != 1 + scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype( + np.float32 + ) + inp_tensor = inp * scale + input_tensor = torch.from_numpy(inp_tensor).float() + input_scale = torch.from_numpy(scale).float() + input_quant_tensor = pack_quant_tensor( + tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor + ) + expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy() + # finn execution + idict = {model.graph.input[0].name: inp_tensor} + model.set_initializer(model.graph.input[1].name, scale) + odict = oxe.execute_onnx(model, idict, True) + produced = odict[model.graph.output[0].name] + + assert np.isclose(expected, produced, rtol=1e-3).all() + os.remove(export_onnx_path) -- GitLab