From bd887ebae155dbcde03b852411d21ead49226439 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Thu, 4 Jun 2020 23:20:38 +0100 Subject: [PATCH] [Test] rework avgpool test case to comply with QuantTensor export --- .../brevitas/test_brevitas_avg_pool_export.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py index 01da19b93..3bf6a8ed6 100644 --- a/tests/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas/test_brevitas_avg_pool_export.py @@ -1,6 +1,11 @@ import onnx # noqa +import torch +import numpy as np import brevitas.onnx as bo from brevitas.nn import QuantAvgPool2d +from brevitas.quant_tensor import pack_quant_tensor +from brevitas.core.quant import QuantType + import pytest export_onnx_path = "test_avg_pool.onnx" @@ -11,7 +16,10 @@ export_onnx_path = "test_avg_pool.onnx" @pytest.mark.parametrize("signed", [False]) @pytest.mark.parametrize("bit_width", [4]) def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width): - ishape = (1, 1024, 7, 7) + ch = 4 + ishape = (1, ch, 7, 7) + input_bit_width = 32 + ibw_tensor = torch.Tensor([input_bit_width]) b_avgpool = QuantAvgPool2d( kernel_size=kernel_size, @@ -19,5 +27,12 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width): signed=signed, min_overall_bit_width=bit_width, max_overall_bit_width=bit_width, + quant_type=QuantType.INT, + ) + # call forward pass manually once to cache scale factor and bitwidth + input_tensor = torch.from_numpy(np.zeros(ishape)).float() + output_scale = torch.from_numpy(np.ones((1, ch, 1, 1))).float() + input_quant_tensor = pack_quant_tensor( + tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor ) - bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path) + bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor) -- GitLab