From 52f268354791b9247476102fd22cfc9958a3e96f Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Sun, 9 May 2021 23:22:02 +0100 Subject: [PATCH] [Test] fix avgpool export test --- .../brevitas/test_brevitas_avg_pool_export.py | 116 ++++++------------ 1 file changed, 35 insertions(+), 81 deletions(-) diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py index 4674c8327..4b88b0f78 100644 --- a/tests/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas/test_brevitas_avg_pool_export.py @@ -25,31 +25,29 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - import os -import onnx # noqa import torch import numpy as np -import brevitas.onnx as bo -from brevitas.nn import QuantAvgPool2d -from brevitas.quant_tensor import QuantTensor -from brevitas.core.quant import QuantType +import pytest +import finn.core.onnx_exec as oxe from finn.core.modelwrapper import ModelWrapper from finn.core.datatype import DataType from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_datatypes import InferDataTypes from finn.util.basic import gen_finn_dt_tensor -import finn.core.onnx_exec as oxe -import pytest +from brevitas.export import FINNManager +from brevitas.nn import QuantAvgPool2d +from brevitas.quant_tensor import QuantTensor + export_onnx_path = "test_brevitas_avg_pool_export.onnx" @pytest.mark.parametrize("kernel_size", [2, 3]) @pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("signed", [False, True]) +@pytest.mark.parametrize("signed", [True, False]) @pytest.mark.parametrize("bit_width", [2, 4]) @pytest.mark.parametrize("input_bit_width", [4, 8, 16]) @pytest.mark.parametrize("channels", [2, 4]) @@ -57,90 +55,46 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx" def test_brevitas_avg_pool_export( kernel_size, stride, signed, bit_width, input_bit_width, channels, idim ): - ishape = (1, channels, idim, idim) - ibw_tensor = torch.Tensor([input_bit_width]) - b_avgpool = QuantAvgPool2d( - kernel_size=kernel_size, - stride=stride, - bit_width=bit_width, - quant_type=QuantType.INT, - ) - # call forward pass manually once to cache scale factor and bitwidth - input_tensor = torch.from_numpy(np.zeros(ishape)).float() - scale = np.ones((1, channels, 1, 1)) - zpt = torch.from_numpy(np.zeros((1))).float() - output_scale = torch.from_numpy(scale).float() - input_quant_tensor = QuantTensor( - value=input_tensor, - scale=output_scale, - bit_width=ibw_tensor, - signed=signed, - zero_point=zpt, - ) - bo.export_finn_onnx( - b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor + quant_avgpool = QuantAvgPool2d( + kernel_size=kernel_size, stride=stride, bit_width=bit_width ) - model = ModelWrapper(export_onnx_path) + quant_avgpool.eval() - # determine input FINN datatype - if signed is True: - prefix = "INT" - else: - prefix = "UINT" + # determine input + prefix = "INT" if signed else "UINT" dt_name = prefix + str(input_bit_width) dtype = DataType[dt_name] - model = model.transform(InferShapes()) - model = model.transform(InferDataTypes()) - - # execution with input tensor using integers and scale = 1 - # calculate golden output - inp = gen_finn_dt_tensor(dtype, ishape) - input_tensor = torch.from_numpy(inp).float() - input_quant_tensor = QuantTensor( - value=input_tensor, - scale=output_scale, - bit_width=ibw_tensor, - signed=signed, - zero_point=zpt, - ) - b_avgpool.eval() - expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy() - - # finn execution - idict = {model.graph.input[0].name: inp} - odict = oxe.execute_onnx(model, idict, True) - produced = odict[model.graph.output[0].name] - assert (expected == produced).all() - - # execution with input tensor using float and scale != 1 - scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype( + input_shape = (1, channels, idim, idim) + input_array = gen_finn_dt_tensor(dtype, input_shape) + # Brevitas QuantAvgPool layers need QuantTensors to export correctly + # which requires setting up a QuantTensor instance with the scale + # factor, zero point, bitwidth and signedness + scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype( np.float32 ) - inp_tensor = inp * scale - input_tensor = torch.from_numpy(inp_tensor).float() - input_scale = torch.from_numpy(scale).float() + input_tensor = torch.from_numpy(input_array * scale_array).float() + scale_tensor = torch.from_numpy(scale_array).float() + zp = torch.tensor(0.0) input_quant_tensor = QuantTensor( - value=input_tensor, - scale=input_scale, - bit_width=ibw_tensor, - signed=signed, - zero_point=zpt, + input_tensor, scale_tensor, zp, input_bit_width, signed ) - # export again to set the scale values correctly - bo.export_finn_onnx( - b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor + + # export + FINNManager.export( + quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor ) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) - b_avgpool.eval() - expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy() - # finn execution - idict = {model.graph.input[0].name: inp_tensor} - odict = oxe.execute_onnx(model, idict, True) - produced = odict[model.graph.output[0].name] - - assert np.isclose(expected, produced).all() + # reference brevitas output + ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy() + # finn output + idict = {model.graph.input[0].name: input_array} + odict = oxe.execute_onnx(model, idict, True) + finn_output = odict[model.graph.output[0].name] + # compare outputs + assert np.isclose(ref_output_array, finn_output).all() + # cleanup os.remove(export_onnx_path) -- GitLab