Skip to content
Snippets Groups Projects
Commit 28cef50c authored by Hendrik Borras
Browse files

Added support for QONNX to test_brevitas_avg_pool_export test.

parent 8a6fb6ff
No related branches found
No related tags found
No related merge requests found
...@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg ...@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
# git-based Python repo dependencies # git-based Python repo dependencies
# these are installed in editable mode for easier co-development # these are installed in editable mode for easier co-development
ARG FINN_BASE_COMMIT="52f0947b597687e1d7d336e1e175ccfc389648be" ARG FINN_BASE_COMMIT="78e4098ad3fc78f72db40b6a3cf29c82c2a567b1"
ARG QONNX_COMMIT="6d55dce220c7f744ef23585686460b9370b672a0" ARG QONNX_COMMIT="6d55dce220c7f744ef23585686460b9370b672a0"
ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221" ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
ARG BREVITAS_COMMIT="efc1217b94a71d616e3b4a37e56bd28a07c55be0" ARG BREVITAS_COMMIT="efc1217b94a71d616e3b4a37e56bd28a07c55be0"
......
...@@ -31,19 +31,23 @@ import numpy as np ...@@ -31,19 +31,23 @@ import numpy as np
import os import os
import torch import torch
from brevitas.export import FINNManager from brevitas.export import FINNManager
from brevitas.export.onnx.generic.manager import BrevitasONNXManager
from brevitas.nn import QuantAvgPool2d from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import QuantTensor from brevitas.quant_tensor import QuantTensor
from qonnx.util.cleanup import cleanup as qonnx_cleanup
import finn.core.onnx_exec as oxe import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_shapes import InferShapes
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.util.basic import gen_finn_dt_tensor from finn.util.basic import gen_finn_dt_tensor
export_onnx_path = "test_brevitas_avg_pool_export.onnx" base_export_onnx_path = "test_brevitas_avg_pool_export.onnx"
@pytest.mark.parametrize("QONNX_export", [False, True])
@pytest.mark.parametrize("kernel_size", [2, 3]) @pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("signed", [True, False]) @pytest.mark.parametrize("signed", [True, False])
...@@ -52,11 +56,23 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx" ...@@ -52,11 +56,23 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx"
@pytest.mark.parametrize("channels", [2, 4]) @pytest.mark.parametrize("channels", [2, 4])
@pytest.mark.parametrize("idim", [7, 8]) @pytest.mark.parametrize("idim", [7, 8])
def test_brevitas_avg_pool_export( def test_brevitas_avg_pool_export(
kernel_size, stride, signed, bit_width, input_bit_width, channels, idim kernel_size,
stride,
signed,
bit_width,
input_bit_width,
channels,
idim,
QONNX_export,
): ):
export_onnx_path = base_export_onnx_path.replace(
".onnx", f"test_QONNX-{QONNX_export}.onnx"
)
quant_avgpool = QuantAvgPool2d( quant_avgpool = QuantAvgPool2d(
kernel_size=kernel_size, stride=stride, bit_width=bit_width kernel_size=kernel_size,
stride=stride,
bit_width=bit_width,
return_quant_tensor=False,
) )
quant_avgpool.eval() quant_avgpool.eval()
...@@ -69,31 +85,57 @@ def test_brevitas_avg_pool_export( ...@@ -69,31 +85,57 @@ def test_brevitas_avg_pool_export(
# Brevitas QuantAvgPool layers need QuantTensors to export correctly # Brevitas QuantAvgPool layers need QuantTensors to export correctly
# which requires setting up a QuantTensor instance with the scale # which requires setting up a QuantTensor instance with the scale
# factor, zero point, bitwidth and signedness # factor, zero point, bitwidth and signedness
scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype( scale_array = np.ones((1, channels, 1, 1)).astype(np.float32)
np.float32 scale_array *= 0.5
)
input_tensor = torch.from_numpy(input_array * scale_array).float() input_tensor = torch.from_numpy(input_array * scale_array).float()
scale_tensor = torch.from_numpy(scale_array).float() scale_tensor = torch.from_numpy(scale_array).float()
zp = torch.tensor(0.0) zp = torch.tensor(0.0)
input_quant_tensor = QuantTensor( input_quant_tensor = QuantTensor(
input_tensor, scale_tensor, zp, input_bit_width, signed input_tensor, scale_tensor, zp, input_bit_width, signed, training=False
) )
# export # export
FINNManager.export( if QONNX_export:
quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor BrevitasONNXManager.export(
) quant_avgpool,
export_path=export_onnx_path,
input_t=input_quant_tensor,
)
model = ModelWrapper(export_onnx_path)
# Statically set the additional inputs generated by the BrevitasONNXManager
model.graph.input.remove(model.graph.input[3])
model.graph.input.remove(model.graph.input[2])
model.graph.input.remove(model.graph.input[1])
model.set_initializer("1", scale_array)
model.set_initializer("2", np.array(0.0).astype(np.float32))
model.set_initializer("3", np.array(input_bit_width).astype(np.float32))
model.save(export_onnx_path)
qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(ConvertQONNXtoFINN())
model.save(export_onnx_path)
else:
FINNManager.export(
quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
)
model = ModelWrapper(export_onnx_path) model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes()) model = model.transform(InferShapes())
model = model.transform(InferDataTypes()) model = model.transform(InferDataTypes())
# reference brevitas output # reference brevitas output
ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy() ref_output_array = quant_avgpool(input_quant_tensor).detach().numpy()
# finn output # finn output
idict = {model.graph.input[0].name: input_array} if QONNX_export:
# Manually apply the Quant tensor scaling for QONNX
idict = {model.graph.input[0].name: input_array * scale_array}
else:
idict = {model.graph.input[0].name: input_array}
odict = oxe.execute_onnx(model, idict, True) odict = oxe.execute_onnx(model, idict, True)
finn_output = odict[model.graph.output[0].name] finn_output = odict[model.graph.output[0].name]
# compare outputs # compare outputs
assert np.isclose(ref_output_array, finn_output).all() assert np.isclose(ref_output_array, finn_output).all()
# cleanup # cleanup
# assert False
os.remove(export_onnx_path) os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment