From 5f3cf6520121894a00c32baf5f6d4aef42d1bd14 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Thu, 14 Oct 2021 16:01:56 +0100
Subject: [PATCH] Added support for BinaryQuant conversion.
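
BinaryQuant nodes are now accepted alongside Quant nodes in
FoldQuantWeights, the QONNX activation handlers and
ConvertQuantActToMultiThreshold. A BinaryQuant node is treated as a
1-bit, signed quantizer and mapped to the BIPOLAR FINN datatype.

Minimal sketch of how the updated transformations might be applied,
assuming the standard finn-base ModelWrapper workflow (the model file
name below is illustrative):

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.infer_shapes import InferShapes
    from finn.transformation.qonnx.fold_quant_weights import FoldQuantWeights
    from finn.transformation.qonnx.quant_act_to_multithreshold import (
        ConvertQuantActToMultiThreshold,
    )

    # Load a QONNX model that contains BinaryQuant nodes
    model = ModelWrapper("binary_quant_net.onnx")
    model = model.transform(InferShapes())
    # Fold BinaryQuant weight quantizers into the weight initializers
    model = model.transform(FoldQuantWeights())
    # Convert BinaryQuant activation quantizers to MultiThreshold nodes
    model = model.transform(ConvertQuantActToMultiThreshold())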

---
 docker/Dockerfile.finn                        |   6 +-
 .../qonnx/fold_quant_weights.py               |  47 +++--
 .../qonnx/qonnx_activation_handlers.py        | 167 +++++++++++-------
 .../qonnx/quant_act_to_multithreshold.py      |  14 +-
 tests/transformation/test_qonnx_to_finn.py    |   7 +-
 5 files changed, 156 insertions(+), 85 deletions(-)

diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn
index dfb6b8658..3856065f2 100644
--- a/docker/Dockerfile.finn
+++ b/docker/Dockerfile.finn
@@ -86,10 +86,10 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="138f177a2729334444d5f58d96c32a1bf4d4b6c2"
-ARG QONNX_COMMIT="02b15f56d199576cefe9aff672d5e009349e402a"
+ARG FINN_BASE_COMMIT="ec3997c3f4276f7746bfd08a1a9508bd02a132fa"
+ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
-ARG BREVITAS_COMMIT="462f86cdc60f9915baf13afd1676fb21da44c2ee"
+ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
 ARG PYVERILATOR_COMMIT="0c3eb9343500fc1352a02c020a736c8c2db47e8e"
 ARG CNPY_COMMIT="4e8810b1a8637695171ed346ce68f6984e585ef4"
 ARG HLSLIB_COMMIT="fbb07135b3d991602e8abe3f2c51212c11fd392b"
diff --git a/src/finn/transformation/qonnx/fold_quant_weights.py b/src/finn/transformation/qonnx/fold_quant_weights.py
index f7c1725e6..7dd942307 100644
--- a/src/finn/transformation/qonnx/fold_quant_weights.py
+++ b/src/finn/transformation/qonnx/fold_quant_weights.py
@@ -31,10 +31,24 @@ from onnx import TensorProto, helper
 
 import finn.core.onnx_exec as oxe
 from finn.core.datatype import DataType
+from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
 from finn.transformation.infer_shapes import InferShapes
 
 
+def get_dtype(bit_width: int, signed: bool) -> DataType:
+    """Map bit width and signedness to a FINN DataType (1 bit -> BIPOLAR)."""
+    bit_width = int(bit_width)
+    signed = bool(signed)
+    if bit_width == 1:
+        finn_dt = DataType["BIPOLAR"]
+    elif signed:
+        finn_dt = DataType["INT" + str(bit_width)]
+    else:
+        finn_dt = DataType["UINT" + str(bit_width)]
+    return finn_dt
+
+
 class FoldQuantWeights(Transformation):
     """Merges Quant nodes, which are used as weights into the initializer
     of the weight tensor.
@@ -47,7 +61,7 @@ class FoldQuantWeights(Transformation):
         execution_context = model.make_empty_exec_context()
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Quant":
+            if n.op_type == "Quant" or n.op_type == "BinaryQuant":
                 node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input))
                 node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits))
                 node_out = n.output[0]
@@ -55,18 +69,25 @@ class FoldQuantWeights(Transformation):
                 ishape = model.get_tensor_shape(n.input[0])
                 is_const_shape = (n.op_type == "Shape") and (ishape is not None)
                 if is_all_constant_inputs or is_const_shape:
-                    if not model.get_initializer(n.input[2]) == 0:
+                    if (
+                        n.op_type == "Quant"
+                        and not model.get_initializer(n.input[2]) == 0
+                    ):
                         raise ValueError(
                             "Only Quant nodes with zero-point == 0 "
                             "are currently supported."
                         )
+                    scale = model.get_initializer(n.input[1])
+                    unity_scale = (scale.flatten() == 1.0).all()
                     # this node has no dynamic inputs, only constant ones -- so we can
                     # do constant folding.
                     oxe.execute_node(n, execution_context, graph)
                     q_node_output = execution_context[node_out]
-                    # Check if the datatype can be directly constant folded
-                    dtype = model.get_tensor_datatype(n.output[0])
-                    if "SCALED" in dtype.name:
+                    # Check whether we can directly constant fold (unity scale)
+                    if unity_scale:
+                        # use the execution result as an initializer
+                        model.set_initializer(node_out, q_node_output)
+                    else:
                         # Reshape scale for Conv if required
                         if model.is_fork_node(n):
                             raise RuntimeError(
@@ -95,14 +116,23 @@ class FoldQuantWeights(Transformation):
                                 f"at node: {target_node}."
                             )
 
-                        # For buth mul and Add:
+                        # For both Mul and Add:
                         # Move the scale factor behind the next operator
                         scale = model.get_initializer(n.input[1])
                         new_initializer = q_node_output / scale
                         # Round, to correct for floating point errors
                         new_initializer = np.round(new_initializer)
                         model.set_initializer(node_out, new_initializer)
-                        new_dtype = DataType[dtype.name.replace("SCALED", "")]
+                        if n.op_type == "Quant":
+                            bit_width = model.get_initializer(n.input[3])
+                            q_inst = getCustomOp(n)
+                            signed = q_inst.get_nodeattr("signed")
+                        elif n.op_type == "BinaryQuant":
+                            bit_width = 1.0
+                            signed = True
+                        else:
+                            raise RuntimeError("Got an unexpected quantizer node type")
+                        new_dtype = get_dtype(bit_width, signed)
                         model.set_tensor_datatype(node_out, new_dtype)
 
                         if target_node.op_type == "Conv" and len(scale.shape) > 0:
@@ -175,9 +205,6 @@ class FoldQuantWeights(Transformation):
                             )
                             graph.node.insert(node_ind, div_node)
 
-                    else:
-                        # use the execution result as an initializer
-                        model.set_initializer(node_out, q_node_output)
                     # remove old node
                     graph.node.remove(n)
                     graph_modified = True
diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index 3b1d84611..0814aecab 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -160,7 +160,7 @@ class QuantActBaseHandler(ABC):
         # behind the MultiTrheshold nodes.
         bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
         scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
-        if scale_scalar and bias_scalar and False:
+        if scale_scalar and bias_scalar and self._q_node.op_type == "BinaryQuant":
             # Get Quant parameters
             mul_scale = np.atleast_1d(mul_scale)
             # ONNX only accepts 64bit floats as attributes
@@ -173,10 +173,11 @@ class QuantActBaseHandler(ABC):
             # FINN applies scale first then bias,
             # which is the other way around in Brevitas,
             # we thus need to adjust the bias in the MultiThreshold node
-            mt_inst.set_nodeattr("out_bias", adder_bias[0] * mul_scale[0])
+            finn_bias = adder_bias[0] * mul_scale[0]
+            mt_inst.set_nodeattr("out_bias", finn_bias)
 
             # If the bias and scale are integers, then the output will be as well.
-            if adder_bias % 1 == 0 and mul_scale % 1 == 0:
+            if finn_bias % 1 == 0 and mul_scale % 1 == 0:
                 mt_inst.set_nodeattr("out_dtype", out_dtype)
         else:
             # Set datatype
@@ -289,19 +290,24 @@ class QuantReluHandler(QuantActBaseHandler):
     ]
 
     def _check_compatibility(self):
-        q_inst = getCustomOp(self._q_node)
-        narrow = q_inst.get_nodeattr("narrow")
-        signed = q_inst.get_nodeattr("signed")
-        if signed or narrow:
-            raise ValueError(
-                "FINN only supports unsigned and non-narrow Quant nodes "
-                "for Relu activations."
-            )
-        if not self._model.get_initializer(self._q_node.input[2]) == 0:
-            raise ValueError(
-                "Only Quant nodes with zero-point == 0 "
-                "are currently supported for ReLu activations."
-            )
+        if self._q_node.op_type == "Quant":
+            q_inst = getCustomOp(self._q_node)
+            narrow = q_inst.get_nodeattr("narrow")
+            signed = q_inst.get_nodeattr("signed")
+            if signed or narrow:
+                raise ValueError(
+                    "FINN only supports unsigned and non-narrow Quant nodes "
+                    "for Relu activations."
+                )
+            if not self._model.get_initializer(self._q_node.input[2]) == 0:
+                raise ValueError(
+                    "Only Quant nodes with zero-point == 0 "
+                    "are currently supported for ReLu activations."
+                )
+        elif self._q_node.op_type == "BinaryQuant":
+            return
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
 
     def _calculate_act_bias(self):
         # No bias allowed for Relu activations, see: https://github.com/Xilinx/
@@ -312,7 +318,12 @@ class QuantReluHandler(QuantActBaseHandler):
 
     def _calculate_thresholds(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
         quant_scale = self._model.get_initializer(self._q_node.input[1])
         # q_inst = getCustomOp(self._q_node)
         # narrow = q_inst.get_nodeattr("narrow")
@@ -388,27 +399,42 @@ class QuantIdentityHandler(QuantActBaseHandler):
 
     def _check_compatibility(self):
         # Gather parameters to check
-        q_inst = getCustomOp(self._q_node)
-        signed = q_inst.get_nodeattr("signed")
-        if not signed:
-            raise ValueError(
-                "FINN only supports signed Quant nodes for identity activations."
-            )
-        if not self._model.get_initializer(self._q_node.input[2]) == 0:
-            raise ValueError(
-                "Only Quant nodes with zero-point == 0 "
-                "are currently supported for identity activations."
-            )
+        if self._q_node.op_type == "Quant":
+            q_inst = getCustomOp(self._q_node)
+            signed = q_inst.get_nodeattr("signed")
+            if not signed:
+                raise ValueError(
+                    "FINN only supports signed Quant nodes for identity activations."
+                )
+            if not self._model.get_initializer(self._q_node.input[2]) == 0:
+                raise ValueError(
+                    "Only Quant nodes with zero-point == 0 "
+                    "are currently supported for identity activations."
+                )
+        elif self._q_node.op_type == "BinaryQuant":
+            quant_scale = self._model.get_initializer(self._q_node.input[1])
+            if (quant_scale.flatten().shape[0] != 1) or quant_scale.flatten()[0] != 1.0:
+                raise ValueError(
+                    "FINN only supports Bipolar identity activations "
+                    "without per-channel scaling and a scale factor of 1."
+                )
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
 
     def _calculate_act_bias(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
         q_inst = getCustomOp(self._q_node)
-        narrow = q_inst.get_nodeattr("narrow")
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+            narrow = q_inst.get_nodeattr("narrow")
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
         # Calculate bias, see: https://github.com/Xilinx/brevitas/blob/
         # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
         # onnx/finn/handler/act.py#L64
-        if bit_width == 1:
+        if bit_width == 1.0:
             bias = np.array([-0.5])
         else:
             if narrow:
@@ -420,45 +446,62 @@ class QuantIdentityHandler(QuantActBaseHandler):
 
     def _calculate_thresholds(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
         quant_scale = self._model.get_initializer(self._q_node.input[1])
         q_inst = getCustomOp(self._q_node)
-        narrow = q_inst.get_nodeattr("narrow")
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+            narrow = q_inst.get_nodeattr("narrow")
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
 
         # Calculate thersholds, see: https://github.com/Xilinx/brevitas/
         # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
         # export/onnx/finn/handler/act.py#L76
-        if narrow:
-            num_distinct_values = 2 ** bit_width - 1
+        if bit_width == 1.0:
+            # Bipolar quantization has a single threshold at zero
+            thresholds = np.zeros((1, 1))
+            return thresholds
         else:
-            num_distinct_values = 2 ** bit_width
-
-        num_thresholds = int(num_distinct_values - 1)
-        flat_scale = quant_scale.flatten()
-        num_scale_channels = flat_scale.shape[0]
-        step = np.abs(flat_scale)
-        half_step = step / 2.0
-        thresholds = np.empty((num_scale_channels, num_thresholds))
-        # compute the value of the smallest threshold, we'll neg-bias all
-        # generated thresholds by this much
-        min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
-        if not narrow:
-            min_threshold -= step
-        for c in range(num_scale_channels):
-            for t in range(num_thresholds):
-                thresholds[c][t] = min_threshold[c] + step[c] * t
-
-        # ToDo: The index 1 needs to be changed to -1 for the channels last format
-        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
-        final_shape = (num_output_channels, num_thresholds)
-        if thresholds.shape != final_shape:
-            thresholds = np.broadcast_to(thresholds, final_shape)
-
-        return thresholds
+            if narrow:
+                num_distinct_values = 2 ** bit_width - 1
+            else:
+                num_distinct_values = 2 ** bit_width
+
+            num_thresholds = int(num_distinct_values - 1)
+            flat_scale = quant_scale.flatten()
+            num_scale_channels = flat_scale.shape[0]
+            step = np.abs(flat_scale)
+            half_step = step / 2.0
+            thresholds = np.empty((num_scale_channels, num_thresholds))
+            # compute the value of the smallest threshold, we'll neg-bias all
+            # generated thresholds by this much
+            min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
+            if not narrow:
+                min_threshold -= step
+            for c in range(num_scale_channels):
+                for t in range(num_thresholds):
+                    thresholds[c][t] = min_threshold[c] + step[c] * t
+
+            # ToDo: The index 1 needs to be changed to -1 for the channels-last format
+            num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[
+                1
+            ]
+            final_shape = (num_output_channels, num_thresholds)
+            if thresholds.shape != final_shape:
+                thresholds = np.broadcast_to(thresholds, final_shape)
+
+            return thresholds
 
     def _calculate_act_scale(self):
         # Gather parameters
-        bit_width = self._model.get_initializer(self._q_node.input[3])
+        if self._q_node.op_type == "Quant":
+            bit_width = self._model.get_initializer(self._q_node.input[3])
+        elif self._q_node.op_type == "BinaryQuant":
+            bit_width = 1.0
+        else:
+            raise RuntimeError("Got an unexpected quantizer node type")
         quant_scale = self._model.get_initializer(self._q_node.input[1])
         # Calculate scale, see: https://github.com/Xilinx/brevitas/
         # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
@@ -466,12 +509,10 @@ class QuantIdentityHandler(QuantActBaseHandler):
         if bit_width != 1:
             scale = quant_scale
         else:
-            # ToDo: This needs testing and/or rewriting when the BinarayQuant op
-            #  comes around
             assert (
                 quant_scale.flatten().shape[0] == 1
             ), "Unsupported BIPOLAR per channel scale"
-            assert quant_scale.flatten().item() == 1.0, "Unsupported BIPOLAR scale != 1"
+            assert quant_scale.flatten()[0] == 1.0, "Unsupported BIPOLAR scale != 1"
             scale = quant_scale * 2
         return scale
 
diff --git a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
index 24b185796..90ccebe8b 100644
--- a/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
+++ b/src/finn/transformation/qonnx/quant_act_to_multithreshold.py
@@ -70,7 +70,7 @@ class ConvertQuantActToMultiThreshold(Transformation):
 
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Quant":
+            if n.op_type == "Quant" or n.op_type == "BinaryQuant":
                 # Check that the node is in the activation path
                 inp = model.get_initializer(n.input[0])
                 out = model.get_initializer(n.output[0])
@@ -83,15 +83,21 @@ class ConvertQuantActToMultiThreshold(Transformation):
                     predecessor_op_type = predecessor
                 if model.is_fork_node(n):
                     raise ValueError(
-                        "Forking Quant nodes are not currently supported by FINN."
+                        "Forking Quant/BinaryQuant nodes are currently "
+                        "not supported by FINN."
                     )
-                if not model.get_initializer(n.input[2]) == 0:
+                if n.op_type == "Quant" and not model.get_initializer(n.input[2]) == 0:
                     raise ValueError(
                         "Only Quant nodes with zero-point == 0 are currently supported."
                     )
 
                 # Check if the bit width is low enough
-                bit_width = model.get_initializer(n.input[3])
+                if n.op_type == "Quant":
+                    bit_width = model.get_initializer(n.input[3])
+                elif n.op_type == "BinaryQuant":
+                    bit_width = 1.0
+                else:
+                    raise RuntimeError("Got an unexpected quantizer node type")
                 if bit_width is None:
                     raise ValueError("Quant nodes must have a static bit width.")
                 if bit_width > self.max_multithreshold_bit_width:
diff --git a/tests/transformation/test_qonnx_to_finn.py b/tests/transformation/test_qonnx_to_finn.py
index db2292319..5041b409f 100644
--- a/tests/transformation/test_qonnx_to_finn.py
+++ b/tests/transformation/test_qonnx_to_finn.py
@@ -89,6 +89,7 @@ def analysis_testing_for_no_quant_nodes(model):
 
 # ToDo: Add KWS networks, when they are ready to be added to finn-examples.
 # ToDo: Add RadioML_VGG10, if possible
+# This test currently takes about 4 minutes and 42 seconds to run.
 @pytest.mark.parametrize("abits", [1, 2])
 @pytest.mark.parametrize("wbits", [1, 2])
 @pytest.mark.parametrize("model_name", ["TFC", "SFC", "LFC", "CNV", "mobilenet"])
@@ -97,13 +98,9 @@ def test_QONNX_to_FINN(model_name, wbits, abits):
         pytest.skip("No wbits > abits cases at the moment")
     if model_name == "LFC" and wbits == 2 and abits == 2:
         pytest.skip("No LFC-w2a2 present at the moment")
-    if model_name == "mobilenet" and wbits < 2 and abits < 2:
+    if model_name == "mobilenet" and (wbits != 2 or abits != 2):
         pytest.skip("Mobilenet only runs at W2A2, though it's technically W4A4.")
 
-    # ToDo: Remove the following restriction when QONNX supports binary operations.
-    if wbits == 1 or abits == 1:
-        pytest.skip("wbits == 1 or abits == 1 is currently not supported by QONNX.")
-
     brev_model, in_shape, input_tensor = get_brev_model_and_sample_inputs(
         model_name, wbits, abits
     )
-- 
GitLab