From 5f3cf6520121894a00c32baf5f6d4aef42d1bd14 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Thu, 14 Oct 2021 16:01:56 +0100
Subject: [PATCH] Added support for BinaryQuant conversion.

---
 docker/Dockerfile.finn                     |   6 +-
 .../qonnx/fold_quant_weights.py            |  47 +++--
 .../qonnx/qonnx_activation_handlers.py     | 167 +++++++++++-------
 .../qonnx/quant_act_to_multithreshold.py   |  14 +-
 tests/transformation/test_qonnx_to_finn.py |   7 +-
 5 files changed, 156 insertions(+), 85 deletions(-)

diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn
index dfb6b8658..3856065f2 100644
--- a/docker/Dockerfile.finn
+++ b/docker/Dockerfile.finn
@@ -86,10 +86,10 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="138f177a2729334444d5f58d96c32a1bf4d4b6c2"
-ARG QONNX_COMMIT="02b15f56d199576cefe9aff672d5e009349e402a"
+ARG FINN_BASE_COMMIT="ec3997c3f4276f7746bfd08a1a9508bd02a132fa"
+ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
-ARG BREVITAS_COMMIT="462f86cdc60f9915baf13afd1676fb21da44c2ee"
+ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
 ARG PYVERILATOR_COMMIT="0c3eb9343500fc1352a02c020a736c8c2db47e8e"
 ARG CNPY_COMMIT="4e8810b1a8637695171ed346ce68f6984e585ef4"
 ARG HLSLIB_COMMIT="fbb07135b3d991602e8abe3f2c51212c11fd392b"
diff --git a/src/finn/transformation/qonnx/fold_quant_weights.py b/src/finn/transformation/qonnx/fold_quant_weights.py
index f7c1725e6..7dd942307 100644
--- a/src/finn/transformation/qonnx/fold_quant_weights.py
+++ b/src/finn/transformation/qonnx/fold_quant_weights.py
@@ -31,10 +31,24 @@ from onnx import TensorProto, helper
 import finn.core.onnx_exec as oxe
 from finn.core.datatype import DataType
+from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
 from finn.transformation.infer_shapes import InferShapes
+def get_dtype(bit_width: int, signed: bool) -> DataType:
+    bit_width = int(bit_width)
+    signed = bool(signed)
+    if bit_width == 1:
+        finn_dt = DataType["BIPOLAR"]
+    else:
+        if signed:
+            finn_dt = DataType["INT" + str(bit_width)]
+        else:
+            finn_dt = DataType["UINT" + str(bit_width)]
+    return finn_dt
+
+
 class FoldQuantWeights(Transformation):
     """Merges Quant nodes, which are used as weights into the initializer
     of the weight tensor.
@@ -47,7 +61,7 @@ class FoldQuantWeights(Transformation):
         execution_context = model.make_empty_exec_context()
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Quant":
+            if n.op_type == "Quant" or n.op_type == "BinaryQuant":
                 node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input))
                 node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits))
                 node_out = n.output[0]
@@ -55,18 +69,25 @@ class FoldQuantWeights(Transformation):
                 ishape = model.get_tensor_shape(n.input[0])
                 is_const_shape = (n.op_type == "Shape") and (ishape is not None)
                 if is_all_constant_inputs or is_const_shape:
-                    if not model.get_initializer(n.input[2]) == 0:
+                    if (
+                        n.op_type == "Quant"
+                        and not model.get_initializer(n.input[2]) == 0
+                    ):
                         raise ValueError(
                             "Only Quant nodes with zero-point == 0 "
                             "are currently supported."
                         )
+                    scale = model.get_initializer(n.input[1])
+                    unity_scale = (scale.flatten() == 1.0).all()
                     # this node has no dynamic inputs, only constant ones -- so we can
                     # do constant folding.
                     oxe.execute_node(n, execution_context, graph)
                     q_node_output = execution_context[node_out]
-                    # Check if the datatype can be directly constant folded
-                    dtype = model.get_tensor_datatype(n.output[0])
-                    if "SCALED" in dtype.name:
+                    # Check if we can directly constant fold
+                    if unity_scale:
+                        # use the execution result as an initializer
+                        model.set_initializer(node_out, q_node_output)
+                    else:
                         # Reshape scale for Conv if required
                         if model.is_fork_node(n):
                             raise RuntimeError(
@@ -95,14 +116,23 @@ class FoldQuantWeights(Transformation):
                                 f"at node: {target_node}."
                             )
-                        # For buth mul and Add:
+                        # For both Mul and Add:
                         # Move the scale factor behind the next operator
                         scale = model.get_initializer(n.input[1])
                         new_initializer = q_node_output / scale
                         # Round, to correct for floating point errors
                         new_initializer = np.round(new_initializer)
                         model.set_initializer(node_out, new_initializer)
-                        new_dtype = DataType[dtype.name.replace("SCALED", "")]
+                        if n.op_type == "Quant":
+                            bit_width = model.get_initializer(n.input[3])
+                            q_inst = getCustomOp(n)
+                            signed = q_inst.get_nodeattr("signed")
+                        elif n.op_type == "BinaryQuant":
+                            bit_width = 1.0
+                            signed = True
+                        else:
+                            raise RuntimeError("Got an unexpected quantizer node type")
+                        new_dtype = get_dtype(bit_width, signed)
                         model.set_tensor_datatype(node_out, new_dtype)
 
                         if target_node.op_type == "Conv" and len(scale.shape) > 0:
@@ -175,9 +205,6 @@ class FoldQuantWeights(Transformation):
                         )
                         graph.node.insert(node_ind, div_node)
-                else:
-                    # use the execution result as an initializer
-                    model.set_initializer(node_out, q_node_output)
                 # remove old node
                 graph.node.remove(n)
                 graph_modified = True
diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index 3b1d84611..0814aecab 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -160,7 +160,7 @@ class QuantActBaseHandler(ABC):
         # behind the MultiThreshold nodes.
         bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
         scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
-        if scale_scalar and bias_scalar and False:
+        if scale_scalar and bias_scalar and self._q_node.op_type == "BinaryQuant":
             # Get Quant parameters
             mul_scale = np.atleast_1d(mul_scale)
             # ONNX only accepts 64bit floats as attributes
@@ -173,10 +173,11 @@ class QuantActBaseHandler(ABC):
             # FINN applies scale first then bias,
             # which is the other way around in Brevitas,
             # we thus need to adjust the bias in the MultiThreshold node
-            mt_inst.set_nodeattr("out_bias", adder_bias[0] * mul_scale[0])
+            finn_bias = adder_bias[0] * mul_scale[0]
+            mt_inst.set_nodeattr("out_bias", finn_bias)
 
             # If the bias and scale are integers, then the output will be as well.
-            if adder_bias % 1 == 0 and mul_scale % 1 == 0:
+            if finn_bias % 1 == 0 and mul_scale % 1 == 0:
                 mt_inst.set_nodeattr("out_dtype", out_dtype)
         else:
             # Set datatype
@@ -289,19 +290,24 @@ class QuantReluHandler(QuantActBaseHandler):
         ]
 
     def _check_compatibility(self):
-        q_inst = getCustomOp(self._q_node)
-        narrow = q_inst.get_nodeattr("narrow")
-        signed = q_inst.get_nodeattr("signed")
-        if signed or narrow:
-            raise ValueError(
-                "FINN only supports unsigned and non-narrow Quant nodes "
-                "for Relu activations."
-            )
-        if not self._model.get_initializer(self._q_node.input[2]) == 0:
-            raise ValueError(
-                "Only Quant nodes with zero-point == 0 "
-                "are currently supported for ReLu activations."
-            )
+        if self._q_node.op_type == "Quant":
+            q_inst = getCustomOp(self._q_node)
+            narrow = q_inst.get_nodeattr("narrow")
+            signed = q_inst.get_nodeattr("signed")
+            if signed or narrow:
+                raise ValueError(
+                    "FINN only supports unsigned and non-narrow Quant nodes "
+                    "for ReLU activations."
+                )
+            if not self._model.get_initializer(self._q_node.input[2]) == 0:
+                raise ValueError(
+                    "Only Quant nodes with zero-point == 0 "
+                    "are currently supported for ReLU activations."
+                )
+        elif self._q_node.op_type == "BinaryQuant":
+            return
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
 
     def _calculate_act_bias(self):
         # No bias allowed for Relu activations, see: https://github.com/Xilinx/
@@ -312,7 +318,12 @@ class QuantReluHandler(QuantActBaseHandler):
 
     def _calculate_thresholds(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
         quant_scale = self._model.get_initializer(self._q_node.input[1])
         # q_inst = getCustomOp(self._q_node)
         # narrow = q_inst.get_nodeattr("narrow")
@@ -388,27 +399,42 @@ class QuantIdentityHandler(QuantActBaseHandler):
 
     def _check_compatibility(self):
         # Gather parameters to check
-        q_inst = getCustomOp(self._q_node)
-        signed = q_inst.get_nodeattr("signed")
-        if not signed:
-            raise ValueError(
-                "FINN only supports signed Quant nodes for identity activations."
-            )
-        if not self._model.get_initializer(self._q_node.input[2]) == 0:
-            raise ValueError(
-                "Only Quant nodes with zero-point == 0 "
-                "are currently supported for identity activations."
-            )
+        if self._q_node.op_type == "Quant":
+            q_inst = getCustomOp(self._q_node)
+            signed = q_inst.get_nodeattr("signed")
+            if not signed:
+                raise ValueError(
+                    "FINN only supports signed Quant nodes for identity activations."
+                )
+            if not self._model.get_initializer(self._q_node.input[2]) == 0:
+                raise ValueError(
+                    "Only Quant nodes with zero-point == 0 "
+                    "are currently supported for identity activations."
+                )
+        elif self._q_node.op_type == "BinaryQuant":
+            quant_scale = self._model.get_initializer(self._q_node.input[1])
+            if (quant_scale.flatten().shape[0] != 1) or quant_scale.flatten()[0] != 1.0:
+                raise ValueError(
+                    "FINN only supports Bipolar identity activations "
+                    "without per-channel scaling, and the scale must be 1."
+                )
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
 
     def _calculate_act_bias(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
         q_inst = getCustomOp(self._q_node)
-        narrow = q_inst.get_nodeattr("narrow")
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+            narrow = q_inst.get_nodeattr("narrow")
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
         # Calculate bias, see: https://github.com/Xilinx/brevitas/blob/
         # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
         # onnx/finn/handler/act.py#L64
-        if bit_width == 1:
+        if bit_width == 1.0:
             bias = np.array([-0.5])
         else:
             if narrow:
@@ -420,45 +446,62 @@ class QuantIdentityHandler(QuantActBaseHandler):
 
     def _calculate_thresholds(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
         quant_scale = self._model.get_initializer(self._q_node.input[1])
         q_inst = getCustomOp(self._q_node)
-        narrow = q_inst.get_nodeattr("narrow")
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+            narrow = q_inst.get_nodeattr("narrow")
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
        # Calculate thresholds, see: https://github.com/Xilinx/brevitas/
        # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
        # export/onnx/finn/handler/act.py#L76
-        if narrow:
-            num_distinct_values = 2 ** bit_width - 1
+        if bit_width == 1.0:
+            thresholds = np.empty([1, 1])
+            thresholds[0] = 0
+            return thresholds
         else:
-            num_distinct_values = 2 ** bit_width
-
-        num_thresholds = int(num_distinct_values - 1)
-        flat_scale = quant_scale.flatten()
-        num_scale_channels = flat_scale.shape[0]
-        step = np.abs(flat_scale)
-        half_step = step / 2.0
-        thresholds = np.empty((num_scale_channels, num_thresholds))
-        # compute the value of the smallest threshold, we'll neg-bias all
-        # generated thresholds by this much
-        min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
-        if not narrow:
-            min_threshold -= step
-        for c in range(num_scale_channels):
-            for t in range(num_thresholds):
-                thresholds[c][t] = min_threshold[c] + step[c] * t
-
-        # ToDo: The index 1 needs to be changed to -1 for the channels last format
-        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
-        final_shape = (num_output_channels, num_thresholds)
-        if thresholds.shape != final_shape:
-            thresholds = np.broadcast_to(thresholds, final_shape)
-
-        return thresholds
+            if narrow:
+                num_distinct_values = 2 ** bit_width - 1
+            else:
+                num_distinct_values = 2 ** bit_width
+
+            num_thresholds = int(num_distinct_values - 1)
+            flat_scale = quant_scale.flatten()
+            num_scale_channels = flat_scale.shape[0]
+            step = np.abs(flat_scale)
+            half_step = step / 2.0
+            thresholds = np.empty((num_scale_channels, num_thresholds))
+            # compute the value of the smallest threshold, we'll neg-bias all
+            # generated thresholds by this much
+            min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
+            if not narrow:
+                min_threshold -= step
+            for c in range(num_scale_channels):
+                for t in range(num_thresholds):
+                    thresholds[c][t] = min_threshold[c] + step[c] * t
+
+            # ToDo: The index 1 needs to be changed to -1 for the channels last format
+            num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[
+                1
+            ]
+            final_shape = (num_output_channels, num_thresholds)
+            if thresholds.shape != final_shape:
+                thresholds = np.broadcast_to(thresholds, final_shape)
+
+            return thresholds
 
     def _calculate_act_scale(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
         quant_scale = self._model.get_initializer(self._q_node.input[1])
         # Calculate scale, see: https://github.com/Xilinx/brevitas/
         # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
@@ -466,12 +509,10 @@
         if bit_width != 1:
             scale = quant_scale
         else:
-            # ToDo: This needs testing and/or rewriting when the BinarayQuant op
-            # comes around
             assert (
                 quant_scale.flatten().shape[0] == 1
             ), "Unsupported BIPOLAR per channel scale"
-            assert quant_scale.flatten().item() == 1.0, "Unsupported BIPOLAR scale != 1"
+            assert quant_scale.flatten()[0] == 1.0, "Unsupported BIPOLAR scale != 1"
             scale = quant_scale * 2
         return scale
 
diff --git a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
index 24b185796..90ccebe8b 100644
--- a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
+++ b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
@@ -70,7 +70,7 @@ class ConvertQuantActToMultiThreshold(Transformation):
 
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Quant":
+            if n.op_type == "Quant" or n.op_type == "BinaryQuant":
                 # Check that the node is in the activation path
                 inp = model.get_initializer(n.input[0])
                 out = model.get_initializer(n.output[0])
@@ -83,15 +83,21 @@ class ConvertQuantActToMultiThreshold(Transformation):
                     predecessor_op_type = predecessor
                 if model.is_fork_node(n):
                     raise ValueError(
-                        "Forking Quant nodes are not currently supported by FINN."
+                        "Forking Quant/BinaryQuant nodes are currently "
+                        "not supported by FINN."
                     )
-                if not model.get_initializer(n.input[2]) == 0:
+                if n.op_type == "Quant" and not model.get_initializer(n.input[2]) == 0:
                     raise ValueError(
                         "Only Quant nodes with zero-point == 0 are currently supported."
                     )
 
                 # Check if the bit width is low enough
-                bit_width = model.get_initializer(n.input[3])
+                if n.op_type == "Quant":
+                    bit_width = model.get_initializer(n.input[3])
+                elif n.op_type == "BinaryQuant":
+                    bit_width = 1.0
+                else:
+                    raise RuntimeError("Got an unexpected quantizer node type")
                 if bit_width is None:
                     raise ValueError("Quant nodes must have a static bit width.")
                 if bit_width > self.max_multithreshold_bit_width:
diff --git a/tests/transformation/test_qonnx_to_finn.py b/tests/transformation/test_qonnx_to_finn.py
index db2292319..5041b409f 100644
--- a/tests/transformation/test_qonnx_to_finn.py
+++ b/tests/transformation/test_qonnx_to_finn.py
@@ -89,6 +89,7 @@ def analysis_testing_for_no_quant_nodes(model):
 
 # ToDo: Add KWS networks, when they are ready to be added to finn-examples.
 # ToDo: Add RadioML_VGG10, if possible
+# This test currently takes about 4 minutes and 42 seconds
 @pytest.mark.parametrize("abits", [1, 2])
 @pytest.mark.parametrize("wbits", [1, 2])
 @pytest.mark.parametrize("model_name", ["TFC", "SFC", "LFC", "CNV", "mobilenet"])
 def test_QONNX_to_FINN(model_name, wbits, abits):
@@ -97,13 +98,9 @@ def test_QONNX_to_FINN(model_name, wbits, abits):
         pytest.skip("No wbits > abits cases at the moment")
     if model_name == "LFC" and wbits == 2 and abits == 2:
         pytest.skip("No LFC-w2a2 present at the moment")
-    if model_name == "mobilenet" and wbits < 2 and abits < 2:
+    if model_name == "mobilenet" and (wbits != 2 or abits != 2):
         pytest.skip("Mobilenet only runs at W2A2, though it's technically W4A4.")
 
-    # ToDo: Remove the following restriction when QONNX supports binary operations.
-    if wbits == 1 or abits == 1:
-        pytest.skip("wbits == 1 or abits == 1 is currently not supported by QONNX.")
-
     brev_model, in_shape, input_tensor = get_brev_model_and_sample_inputs(
         model_name, wbits, abits
    )
-- 
GitLab
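
The new get_dtype helper in fold_quant_weights.py reduces to a small mapping from a (bit_width, signed) pair to a FINN datatype, with 1-bit quantizers (i.e. BinaryQuant) always landing on BIPOLAR. A minimal standalone sketch of that mapping; finn_dtype_name is a hypothetical stand-in that returns plain strings instead of finn.core.datatype.DataType entries, so it runs without a FINN checkout:

def finn_dtype_name(bit_width, signed):
    # bit widths arrive as floats from the ONNX initializer, normalize first
    bit_width = int(bit_width)
    if bit_width == 1:
        # BinaryQuant always maps to the bipolar {-1, +1} datatype
        return "BIPOLAR"
    # Quant nodes map to signed/unsigned integers of the given width
    return ("INT" if signed else "UINT") + str(bit_width)

assert finn_dtype_name(1.0, True) == "BIPOLAR"  # BinaryQuant weights
assert finn_dtype_name(4, True) == "INT4"       # signed 4-bit Quant
assert finn_dtype_name(8, False) == "UINT8"     # unsigned 8-bit Quant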
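
The bipolar branches added to QuantIdentityHandler describe sign(x) as a MultiThreshold: _calculate_thresholds returns a single threshold at 0, _calculate_act_bias returns -0.5, and _calculate_act_scale returns 2 * quant_scale (with quant_scale == 1 enforced by _check_compatibility). Because FINN applies the scale before the bias, the fused out_bias in the scalar path above becomes -0.5 * 2 = -1. A small numpy sketch of that arithmetic, assuming thresholds are counted with >= comparisons; this is an illustration, not FINN's actual MultiThreshold implementation:

import numpy as np

act_scale = 2 * 1.0           # _calculate_act_scale: 2 * quant_scale, scale == 1 enforced
act_bias = -0.5               # _calculate_act_bias for bit_width == 1.0
thresholds = np.array([0.0])  # _calculate_thresholds: single threshold at 0
out_bias = act_bias * act_scale  # fused bias, scale applied first as in FINN

x = np.array([-3.0, -0.1, 0.0, 2.5])
count = (x[:, None] >= thresholds[None, :]).sum(axis=1)  # thresholds crossed
y = act_scale * count + out_bias
print(y)  # [-1. -1.  1.  1.] -- the bipolar encoding of sign(x)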
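
For contrast, the retained multi-bit identity path builds 2 ** bit_width - 1 thresholds (one fewer when narrow) spaced |scale| apart. A worked example with illustrative values bit_width = 2, scale = 1.0, narrow = False, vectorizing the per-channel loop from the diff:

import numpy as np

bit_width, narrow = 2, False   # illustrative values, not taken from the patch
flat_scale = np.array([1.0])   # a single scale channel
num_distinct_values = 2 ** bit_width - (1 if narrow else 0)
num_thresholds = int(num_distinct_values - 1)  # 3 thresholds for 4 levels
step = np.abs(flat_scale)
half_step = step / 2.0
# smallest threshold; the rest are offset from it in steps of |scale|
min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
if not narrow:
    min_threshold -= step
thresholds = min_threshold[:, None] + step[:, None] * np.arange(num_thresholds)
print(thresholds)  # [[-1.5 -0.5  0.5]]

After the activation bias is applied, these three thresholds should correspond to the four non-narrow signed INT2 levels {-2, -1, 0, 1}.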