diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py index f3d6c5dde7179bec8fe97e2a6c791afb5733514c..cf91d70ad53e5d9f67f851e5c62342e3314f88ed 100644 --- a/tests/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas/test_brevitas_avg_pool_export.py @@ -5,7 +5,7 @@ import torch import numpy as np import brevitas.onnx as bo from brevitas.nn import QuantAvgPool2d -from brevitas.quant_tensor import pack_quant_tensor +from brevitas.quant_tensor import QuantTensor from brevitas.core.quant import QuantType from finn.core.modelwrapper import ModelWrapper from finn.core.datatype import DataType @@ -41,11 +41,18 @@ def test_brevitas_avg_pool_export( # call forward pass manually once to cache scale factor and bitwidth input_tensor = torch.from_numpy(np.zeros(ishape)).float() scale = np.ones((1, channels, 1, 1)) + zpt = torch.from_numpy(np.zeros((1))).float() output_scale = torch.from_numpy(scale).float() - input_quant_tensor = pack_quant_tensor( - tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor, signed=signed + input_quant_tensor = QuantTensor( + value=input_tensor, + scale=output_scale, + bit_width=ibw_tensor, + signed=signed, + zero_point=zpt, + ) + bo.export_finn_onnx( + b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor ) - bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor) model = ModelWrapper(export_onnx_path) # determine input FINN datatype @@ -62,8 +69,12 @@ def test_brevitas_avg_pool_export( # calculate golden output inp = gen_finn_dt_tensor(dtype, ishape) input_tensor = torch.from_numpy(inp).float() - input_quant_tensor = pack_quant_tensor( - tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor, signed=signed + input_quant_tensor = QuantTensor( + value=input_tensor, + scale=output_scale, + bit_width=ibw_tensor, + signed=signed, + zero_point=zpt, ) b_avgpool.eval() expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy() @@ -81,11 +92,17 @@ def test_brevitas_avg_pool_export( inp_tensor = inp * scale input_tensor = torch.from_numpy(inp_tensor).float() input_scale = torch.from_numpy(scale).float() - input_quant_tensor = pack_quant_tensor( - tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor, signed=signed + input_quant_tensor = QuantTensor( + value=input_tensor, + scale=input_scale, + bit_width=ibw_tensor, + signed=signed, + zero_point=zpt, ) # export again to set the scale values correctly - bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor) + bo.export_finn_onnx( + b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor + ) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) model = model.transform(InferDataTypes())