Unverified commit 444bce22, authored by auphelia, committed by GitHub

Merge pull request #791 from Xilinx/fix/infer_quant_avg_pool

Fix for QONNX -> FINN-ONNX conversion (QuantAvgPool)
Parents: 3918bfb2, b056303b
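
In summary, the fix has three parts: a newer qonnx commit pin, more robust datatype checks in the QuantAvgPool signedness inference (value equality instead of object identity, plus an explicit ONNX opset version for getCustomOp), and a rewritten Brevitas average-pool export test that drives the QONNX export path end to end.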
@@ -27,7 +27,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-QONNX_COMMIT="d9ac34c638ccbdcd3b3f5cd236fe76d611b08f6a"
+QONNX_COMMIT="20a34289cf2297d2b2bbbe75d6ac152ece86e3b4"
 FINN_EXP_COMMIT="0aa7e1c44b20cf085b6fe42cff360f0a832afd2c"
 BREVITAS_COMMIT="c65f9c13dc124971f14739349531bbcda5c2a4aa"
 PYVERILATOR_COMMIT="766e457465f5c0dd315490d7b9cc5d74f9a76f4f"
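
The only change in this hunk is the qonnx commit pin; FINN_EXP_COMMIT, BREVITAS_COMMIT and PYVERILATOR_COMMIT stay as they were. The new pin presumably pulls in the upstream qonnx changes that the conversion fix below relies on.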
@@ -46,7 +46,7 @@ def _get_signed_from_upstream(model, trunc_node):
     # Check if the input of this node already has a FINN datatype
     signed = None
     inp_dt = model.get_tensor_datatype(node.input[0])
-    if inp_dt is not None and inp_dt is not DataType["FLOAT32"]:
+    if inp_dt is not None and inp_dt != DataType["FLOAT32"]:
         signed = inp_dt.signed()
     # Go further up the graph, since the datatype inference works top down
     # these nodes should either be sign preserving ops or they already have a
@@ -67,23 +67,23 @@ def _get_signed_from_upstream(model, trunc_node):
        )
        next_node = next_node[0]
        out_dt = model.get_tensor_datatype(next_node.output[0])
-        if out_dt is not None and out_dt is not DataType["FLOAT32"]:
+        if out_dt is not None and out_dt != DataType["FLOAT32"]:
            signed = out_dt.signed()
            break
        # Special cases where the node has an internal or intrinsic datatype.
        if next_node.op_type == "MultiThreshold":
-            mt_inst = getCustomOp(next_node)
+            mt_inst = getCustomOp(next_node, onnx_opset_version=9)
            out_dt = DataType[mt_inst.get_nodeattr("out_dtype")]
-            if out_dt is not None and out_dt is not DataType["FLOAT32"]:
+            if out_dt is not None and out_dt != DataType["FLOAT32"]:
                signed = out_dt.signed()
                break
        if next_node.op_type == "BipolarQuant":
            signed = True
            break
        if next_node.op_type == "Quant":
-            q_inst = getCustomOp(next_node)
+            q_inst = getCustomOp(next_node, onnx_opset_version=9)
            out_dt = q_inst.get_integer_datatype(model)
-            if out_dt is not None and out_dt is not DataType["FLOAT32"]:
+            if out_dt is not None and out_dt != DataType["FLOAT32"]:
                signed = out_dt.signed()
                break
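
Two patterns recur in this hunk. First, every datatype comparison moves from object identity (is not) to value inequality (!=). A minimal sketch of the pitfall, under the assumption that the pinned qonnx can resolve a DataType[...] lookup to a freshly constructed instance rather than a cached singleton:

    from qonnx.core.datatype import DataType

    a = DataType["FLOAT32"]
    b = DataType["FLOAT32"]
    assert a == b  # value equality is reliable
    # "a is b" checks object identity and is not guaranteed to hold,
    # so "inp_dt is not DataType['FLOAT32']" could evaluate True even
    # for a FLOAT32 tensor, wrongly letting the signedness inference
    # read a sign off a float datatype.

Second, the getCustomOp calls now pass onnx_opset_version=9 explicitly, presumably to match the CustomOp lookup API of the newly pinned qonnx commit.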
@@ -30,9 +30,8 @@ import pytest
 import numpy as np
 import os
 import torch
-from brevitas.export import export_finn_onnx, export_qonnx
-from brevitas.nn import QuantAvgPool2d
-from brevitas.quant_tensor import QuantTensor
+from brevitas.export import export_qonnx
+from brevitas.nn import QuantAvgPool2d, QuantIdentity, QuantReLU
 from qonnx.core.datatype import DataType
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.infer_datatypes import InferDataTypes
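
The import changes mirror the rewritten test body further down: export_finn_onnx and the hand-built QuantTensor input are gone, while QuantIdentity and QuantReLU come in so the test can prepend an input quantizer to the pooling layer and let export_qonnx derive the input quantization on its own, as shown in the function-body hunk below.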
@@ -47,10 +46,9 @@ base_export_onnx_path = "test_brevitas_avg_pool_export.onnx"
 @pytest.mark.brevitas_export
-@pytest.mark.parametrize("QONNX_export", [False, True])
 @pytest.mark.parametrize("kernel_size", [2, 3])
 @pytest.mark.parametrize("stride", [1, 2])
-@pytest.mark.parametrize("signed", [True, False])
+@pytest.mark.parametrize("signed", [True])  # TODO: Add unsigned test case
 @pytest.mark.parametrize("bit_width", [2, 4])
 @pytest.mark.parametrize("input_bit_width", [4, 8, 16])
 @pytest.mark.parametrize("channels", [2, 4])
@@ -63,79 +61,56 @@ def test_brevitas_avg_pool_export(
     input_bit_width,
     channels,
     idim,
-    QONNX_export,
 ):
-    export_onnx_path = base_export_onnx_path.replace(
-        ".onnx", f"test_QONNX-{QONNX_export}.onnx"
-    )
+    export_onnx_path = base_export_onnx_path.replace(".onnx", "test_QONNX.onnx")
+    if signed:
+        quant_node = QuantIdentity(
+            bit_width=input_bit_width,
+            return_quant_tensor=True,
+        )
+    else:
+        quant_node = QuantReLU(
+            bit_width=input_bit_width,
+            return_quant_tensor=True,
+        )
     quant_avgpool = QuantAvgPool2d(
         kernel_size=kernel_size,
         stride=stride,
         bit_width=bit_width,
         return_quant_tensor=False,
+        float_to_int_impl_type="FLOOR",
     )
-    quant_avgpool.eval()
+    model_brevitas = torch.nn.Sequential(quant_node, quant_avgpool)
+    model_brevitas.eval()
     # determine input
-    prefix = "INT" if signed else "UINT"
-    dt_name = prefix + str(input_bit_width)
-    dtype = DataType[dt_name]
     input_shape = (1, channels, idim, idim)
-    input_array = gen_finn_dt_tensor(dtype, input_shape)
-    # Brevitas QuantAvgPool layers need QuantTensors to export correctly
-    # which requires setting up a QuantTensor instance with the scale
-    # factor, zero point, bitwidth and signedness
-    scale_array = np.ones((1, channels, 1, 1)).astype(np.float32)
-    scale_array *= 0.5
-    input_tensor = torch.from_numpy(input_array * scale_array).float()
-    scale_tensor = torch.from_numpy(scale_array).float()
-    zp = torch.tensor(0.0)
-    input_quant_tensor = QuantTensor(
-        input_tensor, scale_tensor, zp, input_bit_width, signed, training=False
-    )
+    input_array = gen_finn_dt_tensor(DataType["FLOAT32"], input_shape)
-    # export
-    if QONNX_export:
-        export_qonnx(
-            quant_avgpool,
-            export_path=export_onnx_path,
-            input_t=input_quant_tensor,
-        )
-        model = ModelWrapper(export_onnx_path)
+    input_tensor = torch.from_numpy(input_array).float()
-        # Statically set the additional inputs generated by the Brevitas ONNX export
-        model.graph.input.remove(model.graph.input[3])
-        model.graph.input.remove(model.graph.input[2])
-        model.graph.input.remove(model.graph.input[1])
-        model.set_initializer("1", scale_array)
-        model.set_initializer("2", np.array(0.0).astype(np.float32))
-        model.set_initializer("3", np.array(input_bit_width).astype(np.float32))
-        model.save(export_onnx_path)
+    # export
+    export_qonnx(
+        model_brevitas,
+        export_path=export_onnx_path,
+        input_t=input_tensor,
+    )
+    model = ModelWrapper(export_onnx_path)
+    model.save(export_onnx_path)
-        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
-        model = ModelWrapper(export_onnx_path)
-        model = model.transform(ConvertQONNXtoFINN())
-        model.save(export_onnx_path)
-    else:
-        export_finn_onnx(
-            quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
-        )
+    qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(ConvertQONNXtoFINN())
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
     # reference brevitas output
-    ref_output_array = quant_avgpool(input_quant_tensor).detach().numpy()
+    ref_output_array = model_brevitas(input_tensor).detach().numpy()
     # finn output
-    if QONNX_export:
-        # Manually apply the Quant tensor scaling for QONNX
-        idict = {model.graph.input[0].name: input_array * scale_array}
-    else:
-        idict = {model.graph.input[0].name: input_array}
+    idict = {model.graph.input[0].name: input_array}
     odict = oxe.execute_onnx(model, idict, True)
     finn_output = odict[model.graph.output[0].name]
     # compare outputs
     assert np.isclose(ref_output_array, finn_output).all()
     # cleanup
+    # assert False
     os.remove(export_onnx_path)
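
For orientation, the round trip the updated test exercises condenses to the following sketch. It is illustrative rather than part of the commit: the layer parameters and the avgpool.onnx path are made up, and it assumes the Brevitas, qonnx and FINN APIs already imported by the test above.

    import torch
    from brevitas.export import export_qonnx
    from brevitas.nn import QuantAvgPool2d, QuantIdentity
    from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.util.cleanup import cleanup as qonnx_cleanup

    # Input quantizer first, so the exporter knows the avgpool input bit width.
    model = torch.nn.Sequential(
        QuantIdentity(bit_width=4, return_quant_tensor=True),
        QuantAvgPool2d(
            kernel_size=2,
            stride=2,
            bit_width=4,
            return_quant_tensor=False,
            float_to_int_impl_type="FLOOR",
        ),
    ).eval()
    export_qonnx(model, export_path="avgpool.onnx", input_t=torch.randn(1, 2, 8, 8))
    qonnx_cleanup("avgpool.onnx", out_file="avgpool.onnx")
    finn_model = ModelWrapper("avgpool.onnx").transform(ConvertQONNXtoFINN())

The FLOOR rounding mode is set explicitly, presumably so the exported Trunc node carries the rounding behavior that FINN's QuantAvgPool conversion supports.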