Commit 28cef50c authored by Hendrik Borras

Added support for QONNX export to the test_brevitas_avg_pool_export test.

parent 8a6fb6ff
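For orientation before the diff: with QONNX_export=True the test no longer calls FINNManager directly, but instead exports a generic QONNX graph with BrevitasONNXManager, runs the qonnx cleanup pass, and converts the result into FINN's internal dialect with ConvertQONNXtoFINN. Below is a minimal standalone sketch of that flow, assuming the same Brevitas/FINN APIs the test uses; the layer configuration, random input, and file name are illustrative, and the test's extra step of freezing the exporter's additional quantization inputs is omitted for brevity.

    import numpy as np
    import torch
    from brevitas.export.onnx.generic.manager import BrevitasONNXManager
    from brevitas.nn import QuantAvgPool2d
    from brevitas.quant_tensor import QuantTensor
    from qonnx.util.cleanup import cleanup as qonnx_cleanup
    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

    # Illustrative values; the test sweeps these via pytest parametrization.
    channels, idim, input_bit_width, signed = 2, 7, 4, True
    export_path = "quant_avgpool_qonnx_sketch.onnx"  # hypothetical file name

    layer = QuantAvgPool2d(
        kernel_size=2, stride=2, bit_width=4, return_quant_tensor=False
    )
    layer.eval()

    # Wrap a dequantized input in a QuantTensor so the exporter knows the
    # scale factor, zero point, bit width and signedness (as in the test).
    ints = np.random.randint(-8, 8, size=(1, channels, idim, idim)).astype(np.float32)
    scale = np.full((1, channels, 1, 1), 0.5, dtype=np.float32)
    input_quant_tensor = QuantTensor(
        torch.from_numpy(ints * scale).float(),
        torch.from_numpy(scale).float(),
        torch.tensor(0.0),
        input_bit_width,
        signed,
        training=False,
    )

    # QONNX path: generic Brevitas export, cleanup, then conversion to FINN.
    BrevitasONNXManager.export(layer, export_path=export_path, input_t=input_quant_tensor)
    qonnx_cleanup(export_path, out_file=export_path)
    model = ModelWrapper(export_path)
    model = model.transform(ConvertQONNXtoFINN())
    model.save(export_path)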
@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="52f0947b597687e1d7d336e1e175ccfc389648be"
+ARG FINN_BASE_COMMIT="78e4098ad3fc78f72db40b6a3cf29c82c2a567b1"
 ARG QONNX_COMMIT="6d55dce220c7f744ef23585686460b9370b672a0"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
 ARG BREVITAS_COMMIT="efc1217b94a71d616e3b4a37e56bd28a07c55be0"
@@ -31,19 +31,23 @@ import numpy as np
 import os
 import torch
 from brevitas.export import FINNManager
+from brevitas.export.onnx.generic.manager import BrevitasONNXManager
 from brevitas.nn import QuantAvgPool2d
 from brevitas.quant_tensor import QuantTensor
+from qonnx.util.cleanup import cleanup as qonnx_cleanup

 import finn.core.onnx_exec as oxe
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
 from finn.util.basic import gen_finn_dt_tensor

-export_onnx_path = "test_brevitas_avg_pool_export.onnx"
+base_export_onnx_path = "test_brevitas_avg_pool_export.onnx"

+@pytest.mark.parametrize("QONNX_export", [False, True])
 @pytest.mark.parametrize("kernel_size", [2, 3])
 @pytest.mark.parametrize("stride", [1, 2])
 @pytest.mark.parametrize("signed", [True, False])
@@ -52,11 +56,23 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx"
 @pytest.mark.parametrize("channels", [2, 4])
 @pytest.mark.parametrize("idim", [7, 8])
 def test_brevitas_avg_pool_export(
-    kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
+    kernel_size,
+    stride,
+    signed,
+    bit_width,
+    input_bit_width,
+    channels,
+    idim,
+    QONNX_export,
 ):
+    export_onnx_path = base_export_onnx_path.replace(
+        ".onnx", f"test_QONNX-{QONNX_export}.onnx"
+    )
     quant_avgpool = QuantAvgPool2d(
-        kernel_size=kernel_size, stride=stride, bit_width=bit_width
+        kernel_size=kernel_size,
+        stride=stride,
+        bit_width=bit_width,
+        return_quant_tensor=False,
     )
     quant_avgpool.eval()
@@ -69,31 +85,57 @@ def test_brevitas_avg_pool_export(
     # Brevitas QuantAvgPool layers need QuantTensors to export correctly
     # which requires setting up a QuantTensor instance with the scale
     # factor, zero point, bitwidth and signedness
-    scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
-        np.float32
-    )
+    scale_array = np.ones((1, channels, 1, 1)).astype(np.float32)
+    scale_array *= 0.5
     input_tensor = torch.from_numpy(input_array * scale_array).float()
     scale_tensor = torch.from_numpy(scale_array).float()
     zp = torch.tensor(0.0)
     input_quant_tensor = QuantTensor(
-        input_tensor, scale_tensor, zp, input_bit_width, signed
+        input_tensor, scale_tensor, zp, input_bit_width, signed, training=False
     )

     # export
-    FINNManager.export(
-        quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
-    )
+    if QONNX_export:
+        BrevitasONNXManager.export(
+            quant_avgpool,
+            export_path=export_onnx_path,
+            input_t=input_quant_tensor,
+        )
+        model = ModelWrapper(export_onnx_path)
+        # Statically set the additional inputs generated by the BrevitasONNXManager
+        model.graph.input.remove(model.graph.input[3])
+        model.graph.input.remove(model.graph.input[2])
+        model.graph.input.remove(model.graph.input[1])
+        model.set_initializer("1", scale_array)
+        model.set_initializer("2", np.array(0.0).astype(np.float32))
+        model.set_initializer("3", np.array(input_bit_width).astype(np.float32))
+        model.save(export_onnx_path)
+        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
+        model = ModelWrapper(export_onnx_path)
+        model = model.transform(ConvertQONNXtoFINN())
+        model.save(export_onnx_path)
+    else:
+        FINNManager.export(
+            quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
+        )
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())

     # reference brevitas output
-    ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy()
+    ref_output_array = quant_avgpool(input_quant_tensor).detach().numpy()
     # finn output
-    idict = {model.graph.input[0].name: input_array}
+    if QONNX_export:
+        # Manually apply the Quant tensor scaling for QONNX
+        idict = {model.graph.input[0].name: input_array * scale_array}
+    else:
+        idict = {model.graph.input[0].name: input_array}
     odict = oxe.execute_onnx(model, idict, True)
     finn_output = odict[model.graph.output[0].name]
     # compare outputs
     assert np.isclose(ref_output_array, finn_output).all()
     # cleanup
+    # assert False
     os.remove(export_onnx_path)
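One detail worth noting in the branches above: the two export paths expect differently scaled inputs at execution time. The FINN-exported graph consumes the integer representation directly (idict uses input_array), while the QONNX-exported graph starts from dequantized real values, hence the manual input_array * scale_array in the QONNX branch. This is the usual affine dequantization, value = (int - zero_point) * scale; a tiny illustrative sketch with hypothetical values:

    import numpy as np

    # Hypothetical integer representation of a signed activation and its scale.
    int_repr = np.array([-2.0, -1.0, 0.0, 1.0], dtype=np.float32)
    scale = np.float32(0.5)
    zero_point = np.float32(0.0)  # the test uses a zero offset

    # Dequantized values: what the QuantTensor wraps and what the QONNX graph consumes.
    dequant = (int_repr - zero_point) * scale
    print(dequant)  # -> [-1.  -0.5  0.   0.5]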