From bd887ebae155dbcde03b852411d21ead49226439 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 4 Jun 2020 23:20:38 +0100
Subject: [PATCH] [Test] rework avgpool test case to comply with QuantTensor
 export

---
 .../brevitas/test_brevitas_avg_pool_export.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index 01da19b93..3bf6a8ed6 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -1,6 +1,11 @@
 import onnx  # noqa
+import torch
+import numpy as np
 import brevitas.onnx as bo
 from brevitas.nn import QuantAvgPool2d
+from brevitas.quant_tensor import pack_quant_tensor
+from brevitas.core.quant import QuantType
+
 import pytest
 
 export_onnx_path = "test_avg_pool.onnx"
@@ -11,7 +16,10 @@ export_onnx_path = "test_avg_pool.onnx"
 @pytest.mark.parametrize("signed", [False])
 @pytest.mark.parametrize("bit_width", [4])
 def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
-    ishape = (1, 1024, 7, 7)
+    ch = 4
+    ishape = (1, ch, 7, 7)
+    input_bit_width = 32
+    ibw_tensor = torch.Tensor([input_bit_width])
 
     b_avgpool = QuantAvgPool2d(
         kernel_size=kernel_size,
@@ -19,5 +27,12 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
         signed=signed,
         min_overall_bit_width=bit_width,
         max_overall_bit_width=bit_width,
+        quant_type=QuantType.INT,
+    )
+    # call forward pass manually once to cache scale factor and bitwidth
+    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
+    output_scale = torch.from_numpy(np.ones((1, ch, 1, 1))).float()
+    input_quant_tensor = pack_quant_tensor(
+        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
     )
-    bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path)
+    bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
-- 
GitLab