From 28f93d5ac66c0fc09c7af5aa24afb58d3909d7fe Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Tue, 19 Oct 2021 13:33:30 +0100
Subject: [PATCH] Refactored ConvertQuantActToMultiThreshold to use a single
 filter function for node filtering.

---
 src/finn/builder/build_dataflow_steps.py      |  7 +-
 .../qonnx/convert_qonnx_to_finn.py            | 37 ++++----
 .../qonnx/quant_act_to_multithreshold.py      | 88 +++++++++++--------
 3 files changed, 73 insertions(+), 59 deletions(-)

diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py
index 7fce5b35f..96cda8e0e 100644
--- a/src/finn/builder/build_dataflow_steps.py
+++ b/src/finn/builder/build_dataflow_steps.py
@@ -96,6 +96,9 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
 from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
 from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
+from finn.transformation.qonnx.quant_act_to_multithreshold import (
+    default_filter_function_generator,
+)
 from finn.transformation.streamline import Streamline
 from finn.transformation.streamline.reorder import MakeMaxPoolNHWC
 from finn.util.config import extract_model_config_to_json
@@ -208,7 +211,9 @@ def step_qonnx_to_finn(model: ModelWrapper, cfg: DataflowBuildConfig):
     # QONNX to FINN-ONNX
     model = model.transform(
         ConvertQONNXtoFINN(
-            max_multithreshold_bit_width=cfg.max_multithreshold_bit_width
+            filter_function=default_filter_function_generator(
+                max_multithreshold_bit_width=cfg.max_multithreshold_bit_width
+            )
         )
     )
 
diff --git a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
index f2b4cc3ec..b1b1dd0bb 100644
--- a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
+++ b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
@@ -36,6 +36,7 @@ from finn.transformation.qonnx.fold_quant_weights import FoldQuantWeights
 from finn.transformation.qonnx.infer_QuantAvgPool2d import AvgPoolAndTruncToQuantAvgPool
 from finn.transformation.qonnx.quant_act_to_multithreshold import (
     ConvertQuantActToMultiThreshold,
+    default_filter_function_generator,
 )
 from finn.transformation.remove import RemoveEmptyPadding
 
@@ -47,31 +48,28 @@ class ConvertQONNXtoFINN(Transformation):
     the activations.
     If incompatibilities are found a ValueError or RuntimeError is raised.
 
-    The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda`
-    present a way to control which Quant and BipolarQuant nodes in the activation path
-    are converted to MultiThreshold nodes.
-    The filters which are represented by `max_multithreshold_bit_width` and
-    `filter_lambda` are internally connected by an `AND` operation. A warning
-    will be emitted when a Quant node is not converted to a MultiThreshold node.
+    The optional keyword argument `filter_function`
+    presents a way to control which Quant and BipolarQuant nodes in the activation path
+    are converted to MultiThreshold nodes. A warning will be emitted when a Quant node
+    is not converted to a MultiThreshold node.
 
-    :param max_multithreshold_bit_width: The value of max_multithreshold_bit_width is
-    checked against the bit width of any given Quant node and the transformation to a
-    MultiTrheshold node is rejected, when the bitwidth of the Quant node is larger
-    than value of max_multithreshold_bit_with. Defaults to: 8
-    :type max_multithreshold_bit_width: `int`, optional
-    :param filter_lambda: Each candidate Quant and BinaryQant node is first evaluated
-    by this lambda function. If the function returns False,
+    :param filter_function: Each candidate Quant and BipolarQuant node is first
+    evaluated by this function. If the function returns False,
     then the node is not converted to a MultiTrheshold node.
-    Defaults to: lambda q_node: True
-    :type filter_lambda: `lambda`, optional
+    The function is given the model and candidate node as parameters.
+    By default a filter function is used, which disables the conversion of
+    Quant nodes with a bit width larger than 8.
+    Defaults to: default_filter_function_generator(max_multithreshold_bit_width=8)
     """
 
     def __init__(
-        self, max_multithreshold_bit_width=8, filter_lambda=lambda q_node: True
+        self,
+        filter_function=default_filter_function_generator(
+            max_multithreshold_bit_width=8
+        ),
     ):
         super().__init__()
-        self.max_multithreshold_bit_width = max_multithreshold_bit_width
-        self._filter_lambda = filter_lambda
+        self._filter_function = filter_function
 
     def apply(self, model):
         # Extract the bias from Conv node
@@ -86,8 +84,7 @@ class ConvertQONNXtoFINN(Transformation):
         # Convert activations
         model = model.transform(
             ConvertQuantActToMultiThreshold(
-                max_multithreshold_bit_width=self.max_multithreshold_bit_width,
-                filter_lambda=self._filter_lambda,
+                filter_function=self._filter_function,
             )
         )
         # Recompute datatypes
diff --git a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
index 520ed0b82..29ba93dfc 100644
--- a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
+++ b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
@@ -33,35 +33,64 @@ from finn.transformation.base import Transformation
 from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler
 
 
+def default_filter_function_generator(max_multithreshold_bit_width=8):
+    """
+    Generates the default filter function for the ConvertQuantActToMultiThreshold
+    transformation. The returned function disables the conversion of Quant nodes
+    whose bit width exceeds max_multithreshold_bit_width (8 by default).
+
+    This function generator can be used as a template to write custom
+    filter functions.
+    """
+
+    def filter_function(model, q_node):
+        if q_node.op_type == "Quant":
+            bit_width = model.get_initializer(q_node.input[3])
+        elif q_node.op_type == "BipolarQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
+        if bit_width is None:
+            raise ValueError("Quant nodes must have a static bit width.")
+        if bit_width > max_multithreshold_bit_width:
+            warnings.warn(
+                f'The Quant node with name: "{q_node.name}" was not converted to a '
+                f"MultiThreshold node, because its bit width of {bit_width} is "
+                f"higher than the configured maximum bit width of "
+                f"{max_multithreshold_bit_width}."
+            )
+            return False
+        return True
+
+    return filter_function
+
+
 class ConvertQuantActToMultiThreshold(Transformation):
     """
     Converts Quant nodes in the activation path to MultiThreshold nodes.
 
-    The optional keyword arguments `max_multithreshold_bit_width` and `filter_lambda`
-    present a way to control which Quant and BipolarQuant nodes in the activation path
-    are converted to MultiThreshold nodes.
-    The filters which are represented by `max_multithreshold_bit_width` and
-    `filter_lambda` are internally connected by an `AND` operation. A warning
-    will be emitted when a Quant node is not converted to a MultiThreshold node.
-
-    :param max_multithreshold_bit_width: The value of max_multithreshold_bit_width is
-    checked against the bit width of any given Quant node and the transformation to a
-    MultiTrheshold node is rejected, when the bitwidth of the Quant node is larger
-    than value of max_multithreshold_bit_with. Defaults to: 8
-    :type max_multithreshold_bit_width: `int`, optional
-    :param filter_lambda: Each candidate Quant and BinaryQant node is first evaluated
-    by this lambda function. If the function returns False,
+    The optional keyword argument `filter_function`
+    presents a way to control which Quant and BipolarQuant nodes in the activation path
+    are converted to MultiThreshold nodes. A warning will be emitted when a Quant node
+    is not converted to a MultiThreshold node.
+
+    :param filter_function: Each candidate Quant and BipolarQuant node is first
+    evaluated by this function. If the function returns False,
     then the node is not converted to a MultiTrheshold node.
-    Defaults to: lambda q_node: True
-    :type filter_lambda: `lambda`, optional
+    The function is given the model and candidate node as parameters.
+    By default a filter function is used, which disables the conversion of
+    Quant nodes with a bit width larger than 8.
+    Defaults to: default_filter_function_generator(max_multithreshold_bit_width=8)
     """
 
     def __init__(
-        self, max_multithreshold_bit_width=8, filter_lambda=lambda q_node: True
+        self,
+        filter_function=default_filter_function_generator(
+            max_multithreshold_bit_width=8
+        ),
     ):
         super().__init__()
-        self.max_multithreshold_bit_width = max_multithreshold_bit_width
-        self._filter_lambda = filter_lambda
+        self._filter_function = filter_function
 
     def apply(self, model):
         graph = model.graph
@@ -91,28 +120,11 @@ class ConvertQuantActToMultiThreshold(Transformation):
                         "Only Quant nodes with zero-point == 0 are currently supported."
                     )
 
-                # Check if the bit width is low enough
-                if n.op_type == "Quant":
-                    bit_width = model.get_initializer(n.input[3])
-                elif n.op_type == "BipolarQuant":
-                    bit_width = 1.0
-                else:
-                    raise RuntimeError("Got an unexpected quantizer node type")
-                if bit_width is None:
-                    raise ValueError("Quant nodes must have a static bit width.")
-                if bit_width > self.max_multithreshold_bit_width:
-                    warnings.warn(
-                        f'The Quant node with name: "{n.name}" was not converted to a '
-                        f"MultiThreshold node, because its bit width of {bit_width} is "
-                        f"higher than the configured maximum bit width of "
-                        f"{self.max_multithreshold_bit_width}."
-                    )
-                    continue
                 # Check that this node passes the user filter
-                if not self._filter_lambda(n):
+                if not self._filter_function(model, n):
                     warnings.warn(
                         f'The Quant node with name: "{n.name}" was not converted to a '
-                        f"MultiThreshold node, because the filtering lambda function "
+                        f"MultiThreshold node, because the filtering function "
                         f"returned False for this node."
                     )
                     continue
-- 
GitLab
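
Usage sketch for the refactored interface: a filter function is any callable that
takes the model and the candidate Quant/BipolarQuant node and returns a bool. The
sketch below is illustrative only; the function name skip_named_quant, the node
name check, and the model file names are assumptions, while the imports and the
filter_function keyword argument follow this patch.

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.transformation.qonnx.quant_act_to_multithreshold import (
    default_filter_function_generator,
)

# Reuse the default bit width check as a starting point.
_default_filter = default_filter_function_generator(max_multithreshold_bit_width=8)


def skip_named_quant(model, q_node):
    # Illustrative custom filter: keep the default behaviour, but additionally
    # refuse to convert one specific (hypothetical) node by name.
    if q_node.name == "Quant_0":
        return False
    return _default_filter(model, q_node)


model = ModelWrapper("model_qonnx.onnx")  # hypothetical input model
model = model.transform(ConvertQONNXtoFINN(filter_function=skip_named_quant))
model.save("model_finn.onnx")  # hypothetical output path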