From 72d2343d4e3ac382b0bdf7662367127b670682b2 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Tue, 12 Oct 2021 18:08:52 +0100
Subject: [PATCH] Added support for filtering which Quant nodes to convert to
 MultiThreshold nodes.

---
 src/finn/builder/build_dataflow_config.py     |  8 +++
 src/finn/builder/build_dataflow_steps.py      |  6 ++-
 .../qonnx/convert_qonnx_to_finn.py            | 32 +++++++++++-
 .../qonnx/quant_act_to_multithreshold.py      | 52 ++++++++++++++++++-
 4 files changed, 95 insertions(+), 3 deletions(-)

diff --git a/src/finn/builder/build_dataflow_config.py b/src/finn/builder/build_dataflow_config.py
index 49f800f62..4d9184545 100644
--- a/src/finn/builder/build_dataflow_config.py
+++ b/src/finn/builder/build_dataflow_config.py
@@ -295,6 +295,14 @@ class DataflowBuildConfig:
     #: If given, stop at this step.
     stop_step: Optional[str] = None
 
+    #: The optional argument `max_multithreshold_bit_width` affects which Quant nodes
+    #: of the QONNX format get converted to the MultiThreshold nodes of FINN. This
+    #: only affects Quant nodes in the activation path. Quant nodes that define a
+    #: bit width larger than `max_multithreshold_bit_width` are not converted to
+    #: MultiThreshold nodes and a warning is emitted instead.
+    #: If not given, `max_multithreshold_bit_width` defaults to 4.
+    max_multithreshold_bit_width: Optional[int] = 4
+
     def _resolve_hls_clk_period(self):
         if self.hls_clk_period_ns is None:
             # use same clk for synth and hls if not explicitly specified
diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py
index 2216fcb35..7fce5b35f 100644
--- a/src/finn/builder/build_dataflow_steps.py
+++ b/src/finn/builder/build_dataflow_steps.py
@@ -206,7 +206,11 @@ def step_qonnx_to_finn(model: ModelWrapper, cfg: DataflowBuildConfig):
     # QONNX cleanup
     model = cleanup_model(model)
     # QONNX to FINN-ONNX
-    model = model.transform(ConvertQONNXtoFINN())
+    model = model.transform(
+        ConvertQONNXtoFINN(
+            max_multithreshold_bit_width=cfg.max_multithreshold_bit_width
+        )
+    )
 
     if VerificationStepType.QONNX_TO_FINN_PYTHON in cfg._resolve_verification_steps():
         verify_step(model, cfg, "initial_python", need_parent=False)
diff --git a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
index 184a97853..d4eba367c 100644
--- a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
+++ b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
@@ -43,8 +43,33 @@ class ConvertQONNXtoFINN(Transformation):
     then the ConvertQuantActToMultiThreshold transformation is used to convert
     the activations.
     If incompatibilities are found a ValueError or RuntimeError is raised.
+
+    The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda`
+    present a way to control which Quant and BinaryQuant nodes in the activation path
+    are converted to MultiThreshold nodes.
+    The filters which are represented by `max_multithreshold_bit_width` and
+    `filter_lambda` are internally connected by an `AND` operation. A warning
+    will be emitted when a Quant node is not converted to a MultiThreshold node.
+
+    :param max_multithreshold_bit_width: The value of max_multithreshold_bit_width is
+    checked against the bit width of any given Quant node and the transformation to a
+    MultiThreshold node is rejected when the bit width of the Quant node is larger
+    than the value of max_multithreshold_bit_width. Defaults to: 4
+    :type max_multithreshold_bit_width: `int`, optional
+    :param filter_lambda: Each candidate Quant and BinaryQuant node is first evaluated
+    by this lambda function. If the function returns False,
+    then the node is not converted to a MultiThreshold node.
+    Defaults to: lambda q_node: True
+    :type filter_lambda: `lambda`, optional
     """
 
+    def __init__(
+        self, max_multithreshold_bit_width=4, filter_lambda=lambda q_node: True
+    ):
+        super().__init__()
+        self.max_multithreshold_bit_width = max_multithreshold_bit_width  # forwarded
+        self._filter_lambda = filter_lambda  # forwarded to the activation pass
+
     def apply(self, model):
         # Gemm operations are not supported by FINN, so we convert them to MatMul
         model = model.transform(GemmToMatMul())
@@ -54,7 +79,12 @@ class ConvertQONNXtoFINN(Transformation):
         # Fold weights
         model = model.transform(FoldQuantWeights())
         # Convert activations
-        model = model.transform(ConvertQuantActToMultiThreshold())
+        model = model.transform(
+            ConvertQuantActToMultiThreshold(
+                max_multithreshold_bit_width=self.max_multithreshold_bit_width,
+                filter_lambda=self._filter_lambda,
+            )
+        )
         # Recompute datatypes
         model = model.transform(InferDataTypes())
 
diff --git a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
index 641cef0c3..cd203e63f 100644
--- a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
+++ b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
@@ -27,12 +27,41 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
+import warnings
+
 from finn.transformation.base import Transformation
 from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler
 
 
 class ConvertQuantActToMultiThreshold(Transformation):
-    """Converts Quant nodes in the activation path to MultiThreshold nodes."""
+    """
+    Converts Quant nodes in the activation path to MultiThreshold nodes.
+
+    The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda`
+    present a way to control which Quant and BinaryQuant nodes in the activation path
+    are converted to MultiThreshold nodes.
+    The filters which are represented by `max_multithreshold_bit_width` and
+    `filter_lambda` are internally connected by an `AND` operation. A warning
+    will be emitted when a Quant node is not converted to a MultiThreshold node.
+
+    :param max_multithreshold_bit_width: The value of max_multithreshold_bit_width is
+    checked against the bit width of any given Quant node and the transformation to a
+    MultiThreshold node is rejected when the bit width of the Quant node is larger
+    than the value of max_multithreshold_bit_width. Defaults to: 4
+    :type max_multithreshold_bit_width: `int`, optional
+    :param filter_lambda: Each candidate Quant and BinaryQuant node is first evaluated
+    by this lambda function. If the function returns False,
+    then the node is not converted to a MultiThreshold node.
+    Defaults to: lambda q_node: True
+    :type filter_lambda: `lambda`, optional
+    """
+
+    def __init__(
+        self, max_multithreshold_bit_width=4, filter_lambda=lambda q_node: True
+    ):
+        super().__init__()
+        self.max_multithreshold_bit_width = max_multithreshold_bit_width  # inclusive
+        self._filter_lambda = filter_lambda  # node -> bool; False skips conversion
 
     def apply(self, model):
         graph = model.graph
@@ -61,6 +90,27 @@ class ConvertQuantActToMultiThreshold(Transformation):
                         "Only Quant nodes with zero-point == 0 are currently supported."
                     )
 
+                # Check if the bit width is low enough
+                bit_width = model.get_initializer(n.input[3])
+                if bit_width is None:
+                    raise ValueError("Quant nodes must have a static bit width.")
+                if bit_width > self.max_multithreshold_bit_width:
+                    warnings.warn(
+                        f'The Quant node with name: "{n.name}" was not converted to a '
+                        f"MultiThreshold node, because its bit width of {bit_width} is "
+                        f"higher than the configured maximum bit width of "
+                        f"{self.max_multithreshold_bit_width}."
+                    )
+                    continue
+                # Check that this node passes the user filter
+                if not self._filter_lambda(n):
+                    warnings.warn(
+                        f'The Quant node with name: "{n.name}" was not converted to a '
+                        f"MultiThreshold node, because the filtering lambda function "
+                        f"returned False for this node."
+                    )
+                    continue
+
                 # Check for possible ambiguity in handler selection
                 valid_predecessors = []
                 for cls in QuantActBaseHandler.__subclasses__():
-- 
GitLab