Skip to content
Snippets Groups Projects
Commit b60468a6 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Added QONNX_export test to test_brevitas_avg_pool_export test (required test modifications).

parent 0c63afca
No related branches found
No related tags found
No related merge requests found
......@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
# git-based Python repo dependencies
# these are installed in editable mode for easier co-development
ARG FINN_BASE_COMMIT="352cb9c41676fa509f57f20a32c7362c6c09039a"
ARG FINN_BASE_COMMIT="535b27013de83ff36925f2996745b12c9ba64d23"
ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
......
......@@ -31,14 +31,17 @@ import numpy as np
import os
import torch
from brevitas.export import FINNManager
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import QuantTensor
from brevitas.export.onnx.generic.manager import BrevitasONNXManager
from brevitas.nn import QuantAvgPool2d, QuantIdentity
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from torch import nn
import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.util.basic import gen_finn_dt_tensor
export_onnx_path = "test_brevitas_avg_pool_export.onnx"
......@@ -51,12 +54,29 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx"
@pytest.mark.parametrize("input_bit_width", [4, 8, 16])
@pytest.mark.parametrize("channels", [2, 4])
@pytest.mark.parametrize("idim", [7, 8])
@pytest.mark.parametrize("QONNX_export", [False, True])
def test_brevitas_avg_pool_export(
kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
kernel_size,
stride,
signed,
bit_width,
input_bit_width,
channels,
idim,
QONNX_export,
):
quant_avgpool = QuantAvgPool2d(
kernel_size=kernel_size, stride=stride, bit_width=bit_width
# To do a proper static export Brevitas requires a quantized input tensor.
# For the BrevitasONNXManager these requirements are even more stringent,
# such that in-model quantization and de-quantization at the end are required.
quant_avgpool = nn.Sequential(
QuantIdentity(bit_width=input_bit_width, return_quant_tensor=True),
QuantAvgPool2d(
kernel_size=kernel_size,
stride=stride,
bit_width=bit_width,
return_quant_tensor=False,
),
)
quant_avgpool.eval()
......@@ -66,31 +86,29 @@ def test_brevitas_avg_pool_export(
dtype = DataType[dt_name]
input_shape = (1, channels, idim, idim)
input_array = gen_finn_dt_tensor(dtype, input_shape)
# Brevitas QuantAvgPool layers need QuantTensors to export correctly
# which requires setting up a QuantTensor instance with the scale
# factor, zero point, bitwidth and signedness
scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
np.float32
)
input_tensor = torch.from_numpy(input_array * scale_array).float()
scale_tensor = torch.from_numpy(scale_array).float()
zp = torch.tensor(0.0)
input_quant_tensor = QuantTensor(
input_tensor, scale_tensor, zp, input_bit_width, signed
)
input_tensor = torch.from_numpy(input_array).float()
# export
FINNManager.export(
quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
)
if QONNX_export:
BrevitasONNXManager.export(
quant_avgpool,
input_shape,
export_path=export_onnx_path,
)
qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(ConvertQONNXtoFINN())
model.save(export_onnx_path)
else:
FINNManager.export(quant_avgpool, input_shape, export_path=export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
# reference brevitas output
ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy()
ref_output_array = quant_avgpool(input_tensor).detach().numpy()
# finn output
idict = {model.graph.input[0].name: input_array}
idict = {model.graph.input[0].name: input_tensor.detach().numpy()}
odict = oxe.execute_onnx(model, idict, True)
finn_output = odict[model.graph.output[0].name]
# compare outputs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment