diff --git a/src/finn/builder/build_dataflow_config.py b/src/finn/builder/build_dataflow_config.py index 49f800f62a6f9c646c5223504d09cb90d3d144f7..4d91845455be67b6e27c48f2adc60bb2b2b4c363 100644 --- a/src/finn/builder/build_dataflow_config.py +++ b/src/finn/builder/build_dataflow_config.py @@ -295,6 +295,14 @@ class DataflowBuildConfig: #: If given, stop at this step. stop_step: Optional[str] = None + #: The optional argument `max_multithreshold_bit_width` affects which Quant nodes + #: of the QONNX format get converted to the MultiThreshold nodes of FINN. This + #: only affects Quant nodes in the activation path. Quant nodes, which define a + #: bit width larger than `max_multithreshold_bit_width` are not converted to + #: MultiThreshold nodes and a warning is raised instead. + #: If not given `max_multithreshold_bit_width` defaults to 4. + max_multithreshold_bit_width: Optional[int] = 4 + def _resolve_hls_clk_period(self): if self.hls_clk_period_ns is None: # use same clk for synth and hls if not explicitly specified diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py index 2216fcb3535360b6da5a43942bc46e81780fa341..7fce5b35faedb823254cd513e1c4e844d129a277 100644 --- a/src/finn/builder/build_dataflow_steps.py +++ b/src/finn/builder/build_dataflow_steps.py @@ -206,7 +206,11 @@ def step_qonnx_to_finn(model: ModelWrapper, cfg: DataflowBuildConfig): # QONNX cleanup model = cleanup_model(model) # QONNX to FINN-ONNX - model = model.transform(ConvertQONNXtoFINN()) + model = model.transform( + ConvertQONNXtoFINN( + max_multithreshold_bit_width=cfg.max_multithreshold_bit_width + ) + ) if VerificationStepType.QONNX_TO_FINN_PYTHON in cfg._resolve_verification_steps(): verify_step(model, cfg, "initial_python", need_parent=False) diff --git a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py index 
184a9785322fa9603f7f20dcc914dc68d41462a8..d4eba367c305d4228d572e9f082b191c4081166e 100644 --- a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py +++ b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py @@ -43,8 +43,33 @@ class ConvertQONNXtoFINN(Transformation): then the ConvertQuantActToMultiThreshold transformation is used to convert the activations. If incompatibilities are found a ValueError or RuntimeError is raised. + + The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda` + present a way to control which Quant and BinaryQuant nodes in the activation path + are converted to MultiThreshold nodes. + The filters which are represented by `max_multithreshold_bit_width` and + `filter_lambda` are internally connected by an `AND` operation. A warning + will be emitted when a Quant node is not converted to a MultiThreshold node. + + :param max_multithreshold_bit_width: The value of max_multithreshold_bit_width is + checked against the bit width of any given Quant node and the transformation to a + MultiThreshold node is rejected, when the bit width of the Quant node is larger + than the value of max_multithreshold_bit_width. Defaults to: 4 + :type max_multithreshold_bit_width: `int`, optional + :param filter_lambda: Each candidate Quant and BinaryQuant node is first evaluated + by this lambda function. If the function returns False, + then the node is not converted to a MultiThreshold node. 
+ Defaults to: lambda q_node: True + :type filter_lambda: `lambda`, optional """ + def __init__( + self, max_multithreshold_bit_width=4, filter_lambda=lambda q_node: True + ): + super().__init__() + self.max_multithreshold_bit_width = max_multithreshold_bit_width + self._filter_lambda = filter_lambda + def apply(self, model): # Gemm operations are not supported by FINN, so we convert them to MatMul model = model.transform(GemmToMatMul()) @@ -54,7 +79,12 @@ class ConvertQONNXtoFINN(Transformation): # Fold weights model = model.transform(FoldQuantWeights()) # Convert activations - model = model.transform(ConvertQuantActToMultiThreshold()) + model = model.transform( + ConvertQuantActToMultiThreshold( + max_multithreshold_bit_width=self.max_multithreshold_bit_width, + filter_lambda=self._filter_lambda, + ) + ) # Recompute datatypes model = model.transform(InferDataTypes()) diff --git a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py index 641cef0c3aaaf5b452dac7c8831a35242bab7313..cd203e63f88323c867e5235a2a2f774b505a205a 100644 --- a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py +++ b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py @@ -27,12 +27,41 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import warnings + from finn.transformation.base import Transformation from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler class ConvertQuantActToMultiThreshold(Transformation): - """Converts Quant nodes in the activation path to MultiThreshold nodes.""" + """ + Converts Quant nodes in the activation path to MultiThreshold nodes. + + The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda` + present a way to control which Quant and BinaryQuant nodes in the activation path + are converted to MultiThreshold nodes. 
+ The filters which are represented by `max_multithreshold_bit_width` and + `filter_lambda` are internally connected by an `AND` operation. A warning + will be emitted when a Quant node is not converted to a MultiThreshold node. + + :param max_multithreshold_bit_width: The value of max_multithreshold_bit_width is + checked against the bit width of any given Quant node and the transformation to a + MultiThreshold node is rejected, when the bit width of the Quant node is larger + than the value of max_multithreshold_bit_width. Defaults to: 4 + :type max_multithreshold_bit_width: `int`, optional + :param filter_lambda: Each candidate Quant and BinaryQuant node is first evaluated + by this lambda function. If the function returns False, + then the node is not converted to a MultiThreshold node. + Defaults to: lambda q_node: True + :type filter_lambda: `lambda`, optional + """ + + def __init__( + self, max_multithreshold_bit_width=4, filter_lambda=lambda q_node: True + ): + super().__init__() + self.max_multithreshold_bit_width = max_multithreshold_bit_width + self._filter_lambda = filter_lambda def apply(self, model): graph = model.graph @@ -61,6 +90,27 @@ class ConvertQuantActToMultiThreshold(Transformation): "Only Quant nodes with zero-point == 0 are currently supported." ) + + # Check if the bit width is low enough + bit_width = model.get_initializer(n.input[3]) + if bit_width is None: + raise ValueError("Quant nodes must have a static bit width.") + if bit_width > self.max_multithreshold_bit_width: + warnings.warn( + f'The Quant node with name: "{n.name}" was not converted to a ' + f"MultiThreshold node, because its bit width of {bit_width} is " + f"higher than the configured maximum bit width of " + f"{self.max_multithreshold_bit_width}." 
+ ) + continue + # Check that this node passes the user filter + if not self._filter_lambda(n): + warnings.warn( + f'The Quant node with name: "{n.name}" was not converted to a ' + f"MultiThreshold node, because the filtering lambda function " + f"returned False for this node." + ) + continue + # Check for possible ambiguity in handler selection valid_predecessors = [] for cls in QuantActBaseHandler.__subclasses__():