"""Export test for Brevitas QuantAvgPool2d.

Exports a quantized average-pooling layer to FINN-ONNX and checks that
FINN's ONNX execution matches the Brevitas forward pass: first with
integer inputs and scale = 1 (bit-exact comparison), then with scaled
floating-point inputs (tolerance-based comparison).
"""

import os

import onnx  # noqa
import numpy as np
import pytest
import torch

import brevitas.onnx as bo
from brevitas.core.quant import QuantType
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import pack_quant_tensor

import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import gen_finn_dt_tensor

export_onnx_path = "test_brevitas_avg_pool_export.onnx"


@pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("signed", [False, True])
@pytest.mark.parametrize("bit_width", [2, 4])
@pytest.mark.parametrize("input_bit_width", [4, 8, 16])
@pytest.mark.parametrize("channels", [2, 4])
@pytest.mark.parametrize("idim", [7, 8])
def test_brevitas_avg_pool_export(
    kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
):
    ishape = (1, channels, idim, idim)
    ibw_tensor = torch.Tensor([input_bit_width])

    b_avgpool = QuantAvgPool2d(
        kernel_size=kernel_size,
        stride=stride,
        bit_width=bit_width,
        quant_type=QuantType.INT,
    )
    # the export needs one forward pass to cache the scale factor and bit
    # width, so pass a zero-filled QuantTensor of the right shape via input_t
    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
    scale = np.ones((1, channels, 1, 1))
    input_scale = torch.from_numpy(scale).float()
    input_quant_tensor = pack_quant_tensor(
        tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor, signed=signed
    )
    bo.export_finn_onnx(
        b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor
    )
    model = ModelWrapper(export_onnx_path)

    # determine the FINN datatype of the input
    prefix = "INT" if signed else "UINT"
    dtype = DataType[prefix + str(input_bit_width)]
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # part 1: execution with an integer input tensor and scale = 1
    # calculate golden output with Brevitas
    inp = gen_finn_dt_tensor(dtype, ishape)
    input_tensor = torch.from_numpy(inp).float()
    input_quant_tensor = pack_quant_tensor(
        tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor, signed=signed
    )
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
    # FINN execution must match bit-exactly
    idict = {model.graph.input[0].name: inp}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    assert (expected == produced).all()

    # part 2: execution with a float input tensor and a random per-channel
    # scale != 1
    scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
        np.float32
    )
    inp_tensor = inp * scale
    input_tensor = torch.from_numpy(inp_tensor).float()
    input_scale = torch.from_numpy(scale).float()
    input_quant_tensor = pack_quant_tensor(
        tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor, signed=signed
    )
    # export again so the new scale values are baked into the model
    bo.export_finn_onnx(
        b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor
    )
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
    # FINN execution must match within floating-point tolerance
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    assert np.isclose(expected, produced).all()

    os.remove(export_onnx_path)