From 692179c7a6a91981aa27fcd318558cdb84e1be62 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Tue, 9 Jun 2020 13:55:53 +0100
Subject: [PATCH] [Test] Update avg pool test

---
 .../brevitas/test_brevitas_avg_pool_export.py | 42 +++++++++++++++----
 1 file changed, 33 insertions(+), 9 deletions(-)

diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index 0aff5fbf8..a423b89ff 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -19,14 +19,17 @@ import pytest
 export_onnx_path = "test_avg_pool.onnx"
 
 
-@pytest.mark.parametrize("kernel_size", [7])
-@pytest.mark.parametrize("stride", [1])
-@pytest.mark.parametrize("signed", [False])
-@pytest.mark.parametrize("bit_width", [4])
-def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
-    ch = 4
-    ishape = (1, ch, 7, 7)
-    input_bit_width = 32
+@pytest.mark.parametrize("kernel_size", [2, 3])
+@pytest.mark.parametrize("stride", [1, 2])
+@pytest.mark.parametrize("signed", [False, True])
+@pytest.mark.parametrize("bit_width", [2, 4])
+@pytest.mark.parametrize("input_bit_width", [4, 8, 32])
+@pytest.mark.parametrize("channels", [2, 4])
+@pytest.mark.parametrize("idim", [7, 8])
+def test_brevitas_avg_pool_export(
+    kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
+):
+    ishape = (1, channels, idim, idim)
     ibw_tensor = torch.Tensor([input_bit_width])
 
     b_avgpool = QuantAvgPool2d(
@@ -39,7 +42,8 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
     )
     # call forward pass manually once to cache scale factor and bitwidth
     input_tensor = torch.from_numpy(np.zeros(ishape)).float()
-    output_scale = torch.from_numpy(np.ones((1, ch, 1, 1))).float()
+    scale = np.ones((1, channels, 1, 1))
+    output_scale = torch.from_numpy(scale).float()
     input_quant_tensor = pack_quant_tensor(
         tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
     )
@@ -56,6 +60,7 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
 
+    # execution with input tensor using integers and scale = 1
     # calculate golden output
     inp = gen_finn_dt_tensor(dtype, ishape)
     input_tensor = torch.from_numpy(inp).float()
@@ -71,4 +76,23 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
     produced = odict[model.graph.output[0].name]
     assert (expected == produced).all()
 
+    # execution with input tensor using float and scale != 1
+    scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
+        np.float32
+    )
+    inp_tensor = inp * scale
+    input_tensor = torch.from_numpy(inp_tensor).float()
+    input_scale = torch.from_numpy(scale).float()
+    input_quant_tensor = pack_quant_tensor(
+        tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor
+    )
+    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
+    # finn execution
+    idict = {model.graph.input[0].name: inp_tensor}
+    model.set_initializer(model.graph.input[1].name, scale)
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+
+    assert np.isclose(expected, produced, rtol=1e-3).all()
+
     os.remove(export_onnx_path)
-- 
GitLab