import numpy as np
from onnx import TensorProto, helper
import onnxruntime as rt

from finn.custom_op import CustomOp
from finn.core.datatype import DataType


class QuantAvgPool2d(CustomOp):
    """CustomOp corresponding to the quantized average pooling
    layer exported from Brevitas (QuantAvgPool2d)."""

    def get_nodeattr_types(self):
        # stride/kernel: spatial parameters of the pooling window
        # ibits/obits: input/output integer bit widths
        # signed: 1 if the output datatype is signed, else 0
        return {
            "stride": ("i", True, 1),
            "kernel": ("i", True, 1),
            "ibits": ("i", True, 1),
            "obits": ("i", True, 1),
            "signed": ("i", True, 0),
        }

    def make_shape_compatible_op(self, model):
        """Return a standard ONNX AveragePool node with the same spatial
        parameters, so shape inference can reuse the builtin ONNX rules."""
        node = self.onnx_node
        k = self.get_nodeattr("kernel")
        s = self.get_nodeattr("stride")
        return helper.make_node(
            "AveragePool",
            inputs=[node.input[0]],
            outputs=[node.output[0]],
            kernel_shape=[k, k],
            strides=[s, s],
        )

    def infer_node_datatype(self, model):
        """Annotate the output tensor with the FINN DataType implied by the
        obits/signed node attributes.

        Raises:
            Exception: if obits is not one of the supported bit widths.
        """
        node = self.onnx_node
        bw = self.get_nodeattr("obits")
        if bw not in [2, 4, 8, 16, 32]:
            raise Exception("Unsupported output datatype for QuantAvgPool2d")
        prefix = "INT" if self.get_nodeattr("signed") else "UINT"
        dtype = DataType["%s%d" % (prefix, bw)]
        model.set_tensor_datatype(node.output[0], dtype)

    def execute_node(self, context, graph):
        """Execute by running a float AveragePool through onnxruntime, then
        undoing the 1/(k*k) averaging scale and truncating down to obits."""
        # create a standard average pooling node to help calculate the result
        node = self.onnx_node
        k = self.get_nodeattr("kernel")
        s = self.get_nodeattr("stride")
        ishape = context[node.input[0]].shape
        oshape = context[node.output[0]].shape
        inp = helper.make_tensor_value_info(node.input[0], TensorProto.FLOAT, ishape)
        outp = helper.make_tensor_value_info(node.output[0], TensorProto.FLOAT, oshape)
        node_avgpool = helper.make_node(
            "AveragePool",
            inputs=[node.input[0]],
            outputs=[node.output[0]],
            kernel_shape=[k, k],
            strides=[s, s],
        )
        graph_avgpool = helper.make_graph(
            nodes=[node_avgpool],
            name="single-avgpool-exec",
            inputs=[inp],
            outputs=[outp],
        )
        model_avgpool = helper.make_model(graph_avgpool)
        idict = {node.input[0]: context[node.input[0]]}
        sess = rt.InferenceSession(model_avgpool.SerializeToString())
        result_temp = sess.run(None, idict)
        # remove scaling introduced by the average
        result_temp = result_temp[0] * (k * k)
        # worst-case accumulator: k*k entries at the maximum ibits value
        # NOTE(review): assumes an unsigned input range (2**ibits - 1);
        # confirm whether signed inputs need 2**(ibits-1) - 1 here
        ibits = self.get_nodeattr("ibits")
        max_value = (2 ** ibits - 1) * k * k
        max_bit_width = int(max_value).bit_length()
        # number of LSBs to drop so the accumulator fits into obits; if the
        # output is already wide enough no truncation is needed — clamping at
        # 0 avoids an invalid negative shift count for np.right_shift
        shift_bits = max(0, max_bit_width - self.get_nodeattr("obits"))
        # round before the integer cast: the onnxruntime AveragePool divides
        # by k*k in float32, so multiplying back can leave values a hair
        # below the exact integer, which a bare astype(int) would truncate
        result = np.right_shift(np.round(result_temp).astype(np.int64), shift_bits)
        context[node.output[0]] = result.astype(np.float32)

    def verify_node(self):
        # TODO(review): return the list of verification messages expected by
        # the CustomOp verification flow once checks are implemented
        pass
finn.custom_op.quantavgpool2d import QuantAvgPool2d from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch # create a mapping of all known CustomOp names and classes @@ -69,6 +70,7 @@ custom_op["FMPadding_Batch"] = FMPadding_Batch custom_op["Thresholding_Batch"] = Thresholding_Batch custom_op["AddStreams_Batch"] = AddStreams_Batch custom_op["LabelSelect_Batch"] = LabelSelect_Batch +custom_op["QuantAvgPool2d"] = QuantAvgPool2d custom_op["DuplicateStreams_Batch"] = DuplicateStreams_Batch diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py index 1acd4e3abe2d77248810cf15c15475e806a3bd32..39b7a787be8c725e7b6d474757dd96fc4848dfe0 100644 --- a/src/finn/transformation/infer_datatypes.py +++ b/src/finn/transformation/infer_datatypes.py @@ -71,7 +71,13 @@ def _infer_node_datatype(model, node): else: # unknown, assume node produces float32 outputs for o in node.output: - model.set_tensor_datatype(o, DataType.FLOAT32) + # check if output datatype is already set to a value != FLOAT32 + odtype = model.get_tensor_datatype(o) + if odtype is not None and odtype != DataType.FLOAT32: + # don't change data type + model.set_tensor_datatype(o, odtype) + else: + model.set_tensor_datatype(o, DataType.FLOAT32) # compare old and new output dtypes to see if anything changed new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output)) graph_modified = new_odtypes != odtypes diff --git a/src/finn/transformation/streamline/collapse_repeated.py b/src/finn/transformation/streamline/collapse_repeated.py index 67824ad4f633983b93e3178d03118927a1ddd85b..769bed841ce07c1c9c62f762de4b2c0937a6d68f 100644 --- a/src/finn/transformation/streamline/collapse_repeated.py +++ b/src/finn/transformation/streamline/collapse_repeated.py @@ -30,6 +30,7 @@ from onnx import helper as oh from finn.transformation import Transformation from finn.transformation.infer_shapes import InferShapes +from finn.core.datatype import 
def get_execution_error_thresh():
    """Return the maximum error allowed for rounding in FINN execution."""
    # the threshold may be overridden through the ERROR_THRESH environment
    # variable; fall back to 1e-2 when it is not set
    return float(os.environ.get("ERROR_THRESH", 1e-2))
import os

import onnx  # noqa
import torch
import numpy as np
import brevitas.onnx as bo
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import pack_quant_tensor
from brevitas.core.quant import QuantType
from finn.core.modelwrapper import ModelWrapper
from finn.core.datatype import DataType
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.util.basic import gen_finn_dt_tensor
import finn.core.onnx_exec as oxe

import pytest

export_onnx_path = "test_avg_pool.onnx"


@pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("signed", [False, True])
@pytest.mark.parametrize("bit_width", [2, 4])
@pytest.mark.parametrize("input_bit_width", [4, 8, 32])
@pytest.mark.parametrize("channels", [2, 4])
@pytest.mark.parametrize("idim", [7, 8])
def test_brevitas_avg_pool_export(
    kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
):
    """Export a Brevitas QuantAvgPool2d to FINN-ONNX and check that FINN
    execution matches the Brevitas forward pass, both for integer inputs
    with scale 1 and for float inputs with a random per-channel scale."""
    ishape = (1, channels, idim, idim)
    ibw_tensor = torch.Tensor([input_bit_width])

    b_avgpool = QuantAvgPool2d(
        kernel_size=kernel_size,
        stride=stride,
        signed=signed,
        min_overall_bit_width=bit_width,
        max_overall_bit_width=bit_width,
        quant_type=QuantType.INT,
    )
    # call forward pass manually once to cache scale factor and bitwidth
    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
    scale = np.ones((1, channels, 1, 1))
    output_scale = torch.from_numpy(scale).float()
    input_quant_tensor = pack_quant_tensor(
        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
    )
    bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)

    # determine input FINN datatype
    if signed is True:
        prefix = "INT"
    else:
        prefix = "UINT"
    # NOTE(review): input data is generated with half the declared input bit
    # width — presumably to leave accumulator headroom; confirm intent
    dt_name = prefix + str(input_bit_width // 2)
    dtype = DataType[dt_name]
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # execution with input tensor using integers and scale = 1
    # calculate golden output
    inp = gen_finn_dt_tensor(dtype, ishape)
    input_tensor = torch.from_numpy(inp).float()
    input_quant_tensor = pack_quant_tensor(
        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
    )
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()

    # finn execution
    idict = {model.graph.input[0].name: inp}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    # integer path with unit scale must match bit-exactly
    assert (expected == produced).all()

    # execution with input tensor using float and scale != 1
    scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
        np.float32
    )
    inp_tensor = inp * scale
    input_tensor = torch.from_numpy(inp_tensor).float()
    input_scale = torch.from_numpy(scale).float()
    input_quant_tensor = pack_quant_tensor(
        tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor
    )
    # export again to set the scale values correctly
    bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
    # finn execution
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]

    # float path: only closeness is required due to float32 scaling noise
    assert np.isclose(expected, produced).all()

    os.remove(export_onnx_path)