Unverified commit b628d1a8, authored by Yaman Umuroglu, committed by GitHub

Merge pull request #144 from Xilinx/feature/onnx_export_quant_avg_pool

Feature/onnx export quant avg pool
Parents: dde30255, c2e51d31
@@ -13,7 +13,7 @@ gecho () {
 # checkout the correct dependency repo commits
 # the repos themselves are cloned in the Dockerfile
-BREVITAS_COMMIT=7696326e5f279cacffd5b6ac8d9e8d81deec3978
+BREVITAS_COMMIT=026a509186b7e7b0b65d46a2f905043d41069306
 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
 HLSLIB_COMMIT=13e9b0772a27a3a1efc40c878d8e78ed09efb716
 PYVERILATOR_COMMIT=c97a5ba41bbc7c419d6f25c74cdf3bdc3393174f
finn/custom_op/quantavgpool2d.py (new file):

import numpy as np
from onnx import TensorProto, helper
import onnxruntime as rt

from finn.custom_op import CustomOp
from finn.core.datatype import DataType


class QuantAvgPool2d(CustomOp):
    """Class that corresponds to the quantized average pooling
    layer from brevitas"""

    def get_nodeattr_types(self):
        return {
            "stride": ("i", True, 1),
            "kernel": ("i", True, 1),
            "ibits": ("i", True, 1),
            "obits": ("i", True, 1),
            "signed": ("i", True, 0),
        }

    def make_shape_compatible_op(self, model):
        node = self.onnx_node
        k = self.get_nodeattr("kernel")
        s = self.get_nodeattr("stride")
        return helper.make_node(
            "AveragePool",
            inputs=[node.input[0]],
            outputs=[node.output[0]],
            kernel_shape=[k, k],
            strides=[s, s],
        )

    def infer_node_datatype(self, model):
        node = self.onnx_node
        bw = self.get_nodeattr("obits")
        if bw in [2, 4, 8, 16, 32]:
            if self.get_nodeattr("signed") == 0:
                dtype = DataType["UINT%d" % bw]
            else:
                dtype = DataType["INT%d" % bw]
        else:
            raise Exception("Unsupported output datatype for QuantAvgPool2d")
        model.set_tensor_datatype(node.output[0], dtype)

    def execute_node(self, context, graph):
        # create a standard average pooling node to help calculate the result
        node = self.onnx_node
        k = self.get_nodeattr("kernel")
        s = self.get_nodeattr("stride")
        ishape = context[node.input[0]].shape
        oshape = context[node.output[0]].shape
        inp = helper.make_tensor_value_info(node.input[0], TensorProto.FLOAT, ishape)
        outp = helper.make_tensor_value_info(node.output[0], TensorProto.FLOAT, oshape)
        node_avgpool = helper.make_node(
            "AveragePool",
            inputs=[node.input[0]],
            outputs=[node.output[0]],
            kernel_shape=[k, k],
            strides=[s, s],
        )
        graph_avgpool = helper.make_graph(
            nodes=[node_avgpool],
            name="single-avgpool-exec",
            inputs=[inp],
            outputs=[outp],
        )
        model_avgpool = helper.make_model(graph_avgpool)
        idict = {node.input[0]: context[node.input[0]]}
        sess = rt.InferenceSession(model_avgpool.SerializeToString())
        result_temp = sess.run(None, idict)
        # remove scaling introduced by average
        result_temp = result_temp[0] * (k * k)
        ibits = self.get_nodeattr("ibits")
        max_value = 2 ** ibits - 1
        max_value = max_value * k * k
        max_bit_width = int(max_value).bit_length()
        shift_bits = max_bit_width - self.get_nodeattr("obits")
        result = np.right_shift(result_temp.astype(int), shift_bits)
        context[node.output[0]] = result.astype(np.float32)

    def verify_node(self):
        pass
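The requantization at the end of execute_node is worth unpacking: a k×k window of ibits-bit values sums to at most (2**ibits - 1) * k * k, so the accumulator needs max_bit_width bits, and right-shifting by max_bit_width - obits truncates the result back to the declared output width. A minimal standalone sketch of that arithmetic, using made-up attribute values (plain Python, no FINN imports needed):

# standalone sketch of QuantAvgPool2d's requantization step
# (example values; ibits/obits/k mirror the node attributes above)
ibits, obits, k = 4, 4, 2                     # 4-bit in, 4-bit out, 2x2 kernel
acc = 15 + 14 + 13 + 12                       # sum over one 2x2 window = 54
max_value = (2 ** ibits - 1) * k * k          # largest possible accumulator: 60
max_bit_width = int(max_value).bit_length()   # 60 needs 6 bits
shift_bits = max_bit_width - obits            # 6 - 4 = 2
print(acc >> shift_bits)                      # 54 >> 2 = 13, fits in 4 bits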
@@ -48,6 +48,7 @@ from finn.custom_op.fpgadataflow.fmpadding import FMPadding_Batch
 from finn.custom_op.fpgadataflow.thresholding_batch import Thresholding_Batch
 from finn.custom_op.fpgadataflow.addstreams_batch import AddStreams_Batch
 from finn.custom_op.fpgadataflow.labelselect_batch import LabelSelect_Batch
+from finn.custom_op.quantavgpool2d import QuantAvgPool2d
 from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch

 # create a mapping of all known CustomOp names and classes
@@ -69,6 +70,7 @@ custom_op["FMPadding_Batch"] = FMPadding_Batch
 custom_op["Thresholding_Batch"] = Thresholding_Batch
 custom_op["AddStreams_Batch"] = AddStreams_Batch
 custom_op["LabelSelect_Batch"] = LabelSelect_Batch
+custom_op["QuantAvgPool2d"] = QuantAvgPool2d
 custom_op["DuplicateStreams_Batch"] = DuplicateStreams_Batch
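Registering the class under its op_type string is what lets the rest of FINN resolve a QuantAvgPool2d node to its Python implementation. A rough sketch of how such a lookup works; the construct-from-ONNX-node signature is an assumption here, based on how the class above reads self.onnx_node:

# hypothetical lookup sketch, mirroring how FINN resolves custom ops
from finn.custom_op.registry import custom_op

def resolve_custom_op(node):
    # node.op_type is the string key, e.g. "QuantAvgPool2d";
    # a KeyError means the op was never added to the mapping above
    return custom_op[node.op_type](node)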
@@ -71,7 +71,13 @@ def _infer_node_datatype(model, node):
     else:
         # unknown, assume node produces float32 outputs
         for o in node.output:
-            model.set_tensor_datatype(o, DataType.FLOAT32)
+            # check if output datatype is already set to a value != FLOAT32
+            odtype = model.get_tensor_datatype(o)
+            if odtype is not None and odtype != DataType.FLOAT32:
+                # don't change data type
+                model.set_tensor_datatype(o, odtype)
+            else:
+                model.set_tensor_datatype(o, DataType.FLOAT32)
     # compare old and new output dtypes to see if anything changed
     new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     graph_modified = new_odtypes != odtypes
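This guard matters for cases like the new export path: integer datatype annotations placed on tensors before InferDataTypes runs would otherwise be clobbered with FLOAT32 whenever FINN does not recognize the producing node. The new behavior, isolated into a tiny hypothetical helper purely for illustration:

# sketch of the new guard in isolation (hypothetical helper, not FINN API)
from finn.core.datatype import DataType

def resolve_output_dtype(existing):
    # keep a pre-existing non-float annotation, otherwise default to FLOAT32
    if existing is not None and existing != DataType.FLOAT32:
        return existing
    return DataType.FLOAT32

assert resolve_output_dtype(DataType.UINT4) == DataType.UINT4
assert resolve_output_dtype(None) == DataType.FLOAT32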
@@ -30,6 +30,7 @@ from onnx import helper as oh

 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.core.datatype import DataType


 class CollapseRepeatedOp(Transformation):
@@ -83,6 +84,9 @@ class CollapseRepeatedOp(Transformation):
                 graph.node.insert(node_ind, new_node)
                 # replace parameter value
                 model.set_initializer(new_node_param_name, new_param)
+                # be conservative with param/output DataTypes
+                model.set_tensor_datatype(new_node_param_name, DataType.FLOAT32)
+                model.set_tensor_datatype(end_name, DataType.FLOAT32)
                 # remove old nodes
                 graph.node.remove(n)
                 graph.node.remove(consumer)
@@ -106,6 +106,14 @@ def get_finn_root():
         )


+def get_execution_error_thresh():
+    "Return the max error that is allowed for rounding in FINN execution."
+    try:
+        return float(os.environ["ERROR_THRESH"])
+    except KeyError:
+        return 1e-2
+
+
 def make_build_dir(prefix=""):
     """Creates a temporary folder with given prefix to be used as a build dir.
     Use this function instead of tempfile.mkdtemp to ensure any generated files
@@ -305,7 +313,7 @@ def sanitize_quant_values(model, node_tensors, execution_context, check_values=F
         )
         # check if rounded values are not too far from original values
         max_error = max(np.abs(current_values - updated_values).flatten())
-        if max_error <= 1e-4:
+        if max_error <= get_execution_error_thresh():
             if check_values is True:
                 # check again if values can now be represented with set finn datatype
                 # TODO: vectorize with numpy
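With this change the execution-time rounding tolerance is no longer hard-coded: it defaults to 1e-2 but can be overridden through the ERROR_THRESH environment variable. For example, to restore the old, stricter 1e-4 behavior for a run:

import os

# must be set before sanitize_quant_values is called;
# get_execution_error_thresh() will then return 1e-4 instead of the 1e-2 default
os.environ["ERROR_THRESH"] = "1e-4"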
New test (test_brevitas_avg_pool_export):

import os

import onnx  # noqa
import torch
import numpy as np
import brevitas.onnx as bo
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import pack_quant_tensor
from brevitas.core.quant import QuantType
from finn.core.modelwrapper import ModelWrapper
from finn.core.datatype import DataType
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.util.basic import gen_finn_dt_tensor
import finn.core.onnx_exec as oxe

import pytest

export_onnx_path = "test_avg_pool.onnx"


@pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("signed", [False, True])
@pytest.mark.parametrize("bit_width", [2, 4])
@pytest.mark.parametrize("input_bit_width", [4, 8, 32])
@pytest.mark.parametrize("channels", [2, 4])
@pytest.mark.parametrize("idim", [7, 8])
def test_brevitas_avg_pool_export(
    kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
):
    ishape = (1, channels, idim, idim)
    ibw_tensor = torch.Tensor([input_bit_width])

    b_avgpool = QuantAvgPool2d(
        kernel_size=kernel_size,
        stride=stride,
        signed=signed,
        min_overall_bit_width=bit_width,
        max_overall_bit_width=bit_width,
        quant_type=QuantType.INT,
    )
    # call forward pass manually once to cache scale factor and bitwidth
    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
    scale = np.ones((1, channels, 1, 1))
    output_scale = torch.from_numpy(scale).float()
    input_quant_tensor = pack_quant_tensor(
        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
    )
    bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)

    # determine input FINN datatype
    if signed is True:
        prefix = "INT"
    else:
        prefix = "UINT"
    dt_name = prefix + str(input_bit_width // 2)
    dtype = DataType[dt_name]
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # execution with input tensor using integers and scale = 1
    # calculate golden output
    inp = gen_finn_dt_tensor(dtype, ishape)
    input_tensor = torch.from_numpy(inp).float()
    input_quant_tensor = pack_quant_tensor(
        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
    )
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()

    # finn execution
    idict = {model.graph.input[0].name: inp}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    assert (expected == produced).all()

    # execution with input tensor using float and scale != 1
    scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
        np.float32
    )
    inp_tensor = inp * scale
    input_tensor = torch.from_numpy(inp_tensor).float()
    input_scale = torch.from_numpy(scale).float()
    input_quant_tensor = pack_quant_tensor(
        tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor
    )
    # export again to set the scale values correctly
    bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()

    # finn execution
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    assert np.isclose(expected, produced).all()

    os.remove(export_onnx_path)
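The seven parametrize decorators expand to 2·2·2·2·3·2·2 = 192 test cases. To run just this test (all combinations, or filter further with -k expressions), something like the following works from the repo root:

# equivalent to `pytest -k test_brevitas_avg_pool_export -v` on the command line
import pytest

pytest.main(["-k", "test_brevitas_avg_pool_export", "-v"])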