From b60468a641f2eb47fb236209d13eecb75a87e8dc Mon Sep 17 00:00:00 2001 From: Hendrik Borras <hendrikborras@web.de> Date: Fri, 15 Oct 2021 15:49:59 +0100 Subject: [PATCH] Added QONNX_export test to test_brevitas_avg_pool_export test (required test modifications). --- docker/Dockerfile.finn | 2 +- .../brevitas/test_brevitas_avg_pool_export.py | 62 ++++++++++++------- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn index d809d99b9..2464d505c 100644 --- a/docker/Dockerfile.finn +++ b/docker/Dockerfile.finn @@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg # git-based Python repo dependencies # these are installed in editable mode for easier co-development -ARG FINN_BASE_COMMIT="352cb9c41676fa509f57f20a32c7362c6c09039a" +ARG FINN_BASE_COMMIT="535b27013de83ff36925f2996745b12c9ba64d23" ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b" ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221" ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042" diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py index 68e563da6..fc6d50f8c 100644 --- a/tests/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas/test_brevitas_avg_pool_export.py @@ -31,14 +31,17 @@ import numpy as np import os import torch from brevitas.export import FINNManager -from brevitas.nn import QuantAvgPool2d -from brevitas.quant_tensor import QuantTensor +from brevitas.export.onnx.generic.manager import BrevitasONNXManager +from brevitas.nn import QuantAvgPool2d, QuantIdentity +from qonnx.util.cleanup import cleanup as qonnx_cleanup +from torch import nn import finn.core.onnx_exec as oxe from finn.core.datatype import DataType from finn.core.modelwrapper import ModelWrapper from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.infer_shapes import InferShapes +from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN from finn.util.basic import gen_finn_dt_tensor export_onnx_path = "test_brevitas_avg_pool_export.onnx" @@ -51,12 +54,29 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx" @pytest.mark.parametrize("input_bit_width", [4, 8, 16]) @pytest.mark.parametrize("channels", [2, 4]) @pytest.mark.parametrize("idim", [7, 8]) +@pytest.mark.parametrize("QONNX_export", [False, True]) def test_brevitas_avg_pool_export( - kernel_size, stride, signed, bit_width, input_bit_width, channels, idim + kernel_size, + stride, + signed, + bit_width, + input_bit_width, + channels, + idim, + QONNX_export, ): - quant_avgpool = QuantAvgPool2d( - kernel_size=kernel_size, stride=stride, bit_width=bit_width + # To do a proper static export Brevitas requires a quantized input tensor. + # For the BrevitasONNXManager these requirements are even more stringent, + # such that in-model quantization and de-quantization at the end are required. + quant_avgpool = nn.Sequential( + QuantIdentity(bit_width=input_bit_width, return_quant_tensor=True), + QuantAvgPool2d( + kernel_size=kernel_size, + stride=stride, + bit_width=bit_width, + return_quant_tensor=False, + ), ) quant_avgpool.eval() @@ -66,31 +86,29 @@ def test_brevitas_avg_pool_export( dtype = DataType[dt_name] input_shape = (1, channels, idim, idim) input_array = gen_finn_dt_tensor(dtype, input_shape) - # Brevitas QuantAvgPool layers need QuantTensors to export correctly - # which requires setting up a QuantTensor instance with the scale - # factor, zero point, bitwidth and signedness - scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype( - np.float32 - ) - input_tensor = torch.from_numpy(input_array * scale_array).float() - scale_tensor = torch.from_numpy(scale_array).float() - zp = torch.tensor(0.0) - input_quant_tensor = QuantTensor( - input_tensor, scale_tensor, zp, input_bit_width, signed - ) + input_tensor = torch.from_numpy(input_array).float() # export - FINNManager.export( - quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor - ) + if QONNX_export: + BrevitasONNXManager.export( + quant_avgpool, + input_shape, + export_path=export_onnx_path, + ) + qonnx_cleanup(export_onnx_path, out_file=export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform(ConvertQONNXtoFINN()) + model.save(export_onnx_path) + else: + FINNManager.export(quant_avgpool, input_shape, export_path=export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) # reference brevitas output - ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy() + ref_output_array = quant_avgpool(input_tensor).detach().numpy() # finn output - idict = {model.graph.input[0].name: input_array} + idict = {model.graph.input[0].name: input_tensor.detach().numpy()} odict = oxe.execute_onnx(model, idict, True) finn_output = odict[model.graph.output[0].name] # compare outputs -- GitLab