From 28f93d5ac66c0fc09c7af5aa24afb58d3909d7fe Mon Sep 17 00:00:00 2001 From: Hendrik Borras <hendrikborras@web.de> Date: Tue, 19 Oct 2021 13:33:30 +0100 Subject: [PATCH] Refactored filter lambda for ConvertQuantActToMultiThreshold transformation to only use the function for filtering. --- src/finn/builder/build_dataflow_steps.py | 7 +- .../qonnx/convert_qonnx_to_finn.py | 37 ++++---- .../qonnx/quant_act_to_multithreshold.py | 88 +++++++++++-------- 3 files changed, 73 insertions(+), 59 deletions(-) diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py index 7fce5b35f..96cda8e0e 100644 --- a/src/finn/builder/build_dataflow_steps.py +++ b/src/finn/builder/build_dataflow_steps.py @@ -96,6 +96,9 @@ from finn.transformation.infer_shapes import InferShapes from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul from finn.transformation.move_reshape import RemoveCNVtoFCFlatten from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN +from finn.transformation.qonnx.quant_act_to_multithreshold import ( + default_filter_function_generator, +) from finn.transformation.streamline import Streamline from finn.transformation.streamline.reorder import MakeMaxPoolNHWC from finn.util.config import extract_model_config_to_json @@ -208,7 +211,9 @@ def step_qonnx_to_finn(model: ModelWrapper, cfg: DataflowBuildConfig): # QONNX to FINN-ONNX model = model.transform( ConvertQONNXtoFINN( - max_multithreshold_bit_width=cfg.max_multithreshold_bit_width + filter_function=default_filter_function_generator( + max_multithreshold_bit_width=cfg.max_multithreshold_bit_width + ) ) ) diff --git a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py index f2b4cc3ec..b1b1dd0bb 100644 --- a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py +++ b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py @@ -36,6 +36,7 @@ from finn.transformation.qonnx.fold_quant_weights import FoldQuantWeights from finn.transformation.qonnx.infer_QuantAvgPool2d import AvgPoolAndTruncToQuantAvgPool from finn.transformation.qonnx.quant_act_to_multithreshold import ( ConvertQuantActToMultiThreshold, + default_filter_function_generator, ) from finn.transformation.remove import RemoveEmptyPadding @@ -47,31 +48,28 @@ class ConvertQONNXtoFINN(Transformation): the activations. If incompatibilities are found a ValueError or RuntimeError is raised. - The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda` - present a way to control which Quant and BipolarQuant nodes in the activation path - are converted to MultiThreshold nodes. - The filters which are represented by `max_multithreshold_bit_width` and - `filter_lambda` are internally connected by an `AND` operation. A warning - will be emitted when a Quant node is not converted to a MultiThreshold node. + The optional keyword argument `filter_function` + presents a way to control which Quant and BipolarQuant nodes in the activation path + are converted to MultiThreshold nodes. A warning will be emitted when a Quant node + is not converted to a MultiThreshold node. - :param max_multithreshold_bit_width: The value of max_multithreshold_bit_width is - checked against the bit width of any given Quant node and the transformation to a - MultiTrheshold node is rejected, when the bitwidth of the Quant node is larger - than value of max_multithreshold_bit_with. Defaults to: 8 - :type max_multithreshold_bit_width: `int`, optional - :param filter_lambda: Each candidate Quant and BinaryQant node is first evaluated - by this lambda function. If the function returns False, + :param filter_function: Each candidate Quant and BinaryQant node is first evaluated + by this function. If the function returns False, then the node is not converted to a MultiTrheshold node. - Defaults to: lambda q_node: True - :type filter_lambda: `lambda`, optional + The function is given the model and candidate node as parameters. + Per default a filter function is inserted, which disables the conversion of + Quant nodes, which have a bit width of larger than 8. + Defaults to: default_filter_function_generator(max_multithreshold_bit_width=8) """ def __init__( - self, max_multithreshold_bit_width=8, filter_lambda=lambda q_node: True + self, + filter_function=default_filter_function_generator( + max_multithreshold_bit_width=8 + ), ): super().__init__() - self.max_multithreshold_bit_width = max_multithreshold_bit_width - self._filter_lambda = filter_lambda + self._filter_function = filter_function def apply(self, model): # Extract the bias from Conv node @@ -86,8 +84,7 @@ class ConvertQONNXtoFINN(Transformation): # Convert activations model = model.transform( ConvertQuantActToMultiThreshold( - max_multithreshold_bit_width=self.max_multithreshold_bit_width, - filter_lambda=self._filter_lambda, + filter_function=self._filter_function, ) ) # Recompute datatypes diff --git a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py index 520ed0b82..29ba93dfc 100644 --- a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py +++ b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py @@ -33,35 +33,64 @@ from finn.transformation.base import Transformation from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler +def default_filter_function_generator(max_multithreshold_bit_width=8): + """ + This function generates the default filter function for the + ConvertQuantActToMultiThreshold transformation. Per default the returned + function disables the conversion of Quant nodes which have a bit width above 8 bit. + + This function generator can be used as a template to write custom + filter functions. + """ + + def filter_function(model, q_node): + if q_node.op_type == "Quant": + bit_width = model.get_initializer(q_node.input[3]) + elif q_node.op_type == "BipolarQuant": + bit_width = 1.0 + else: + raise RuntimeError("Got an unexpected quantizer node type") + if bit_width is None: + raise ValueError("Quant nodes must have a static bit width.") + if bit_width > max_multithreshold_bit_width: + warnings.warn( + f'The Quant node with name: "{q_node.name}" was not converted to a ' + f"MultiThreshold node, because its bit width of {bit_width} is " + f"higher than the configured maximum bit width of " + f"{max_multithreshold_bit_width}." + ) + return False + return True + + return filter_function + + class ConvertQuantActToMultiThreshold(Transformation): """ Converts Quant nodes in the activation path to MultiThreshold nodes. - The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda` - present a way to control which Quant and BipolarQuant nodes in the activation path - are converted to MultiThreshold nodes. - The filters which are represented by `max_multithreshold_bit_width` and - `filter_lambda` are internally connected by an `AND` operation. A warning - will be emitted when a Quant node is not converted to a MultiThreshold node. - - :param max_multithreshold_bit_width: The value of max_multithreshold_bit_width is - checked against the bit width of any given Quant node and the transformation to a - MultiTrheshold node is rejected, when the bitwidth of the Quant node is larger - than value of max_multithreshold_bit_with. Defaults to: 8 - :type max_multithreshold_bit_width: `int`, optional - :param filter_lambda: Each candidate Quant and BinaryQant node is first evaluated - by this lambda function. If the function returns False, + The optional keyword argument `filter_function` + presents a way to control which Quant and BipolarQuant nodes in the activation path + are converted to MultiThreshold nodes. A warning will be emitted when a Quant node + is not converted to a MultiThreshold node. + + :param filter_function: Each candidate Quant and BinaryQant node is first evaluated + by this function. If the function returns False, then the node is not converted to a MultiTrheshold node. - Defaults to: lambda q_node: True - :type filter_lambda: `lambda`, optional + The function is given the model and candidate node as parameters. + Per default a filter function is inserted, which disables the conversion of + Quant nodes, which have a bit width of larger than 8. + Defaults to: default_filter_function_generator(max_multithreshold_bit_width=8) """ def __init__( - self, max_multithreshold_bit_width=8, filter_lambda=lambda q_node: True + self, + filter_function=default_filter_function_generator( + max_multithreshold_bit_width=8 + ), ): super().__init__() - self.max_multithreshold_bit_width = max_multithreshold_bit_width - self._filter_lambda = filter_lambda + self._filter_function = filter_function def apply(self, model): graph = model.graph @@ -91,28 +120,11 @@ class ConvertQuantActToMultiThreshold(Transformation): "Only Quant nodes with zero-point == 0 are currently supported." ) - # Check if the bit width is low enough - if n.op_type == "Quant": - bit_width = model.get_initializer(n.input[3]) - elif n.op_type == "BipolarQuant": - bit_width = 1.0 - else: - raise RuntimeError("Got an unexpected quantizer node type") - if bit_width is None: - raise ValueError("Quant nodes must have a static bit width.") - if bit_width > self.max_multithreshold_bit_width: - warnings.warn( - f'The Quant node with name: "{n.name}" was not converted to a ' - f"MultiThreshold node, because its bit width of {bit_width} is " - f"higher than the configured maximum bit width of " - f"{self.max_multithreshold_bit_width}." - ) - continue # Check that this node passes the user filter - if not self._filter_lambda(n): + if not self._filter_function(model, n): warnings.warn( f'The Quant node with name: "{n.name}" was not converted to a ' - f"MultiThreshold node, because the filtering lambda function " + f"MultiThreshold node, because the filtering function " f"returned False for this node." ) continue -- GitLab