From 9be84cec793b9e63c9ba40fea863dd08943d55e6 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Fri, 15 Oct 2021 10:12:00 +0100
Subject: [PATCH] Added transformation to extract the bias from a Conv node.

---
 .../qonnx/convert_qonnx_to_finn.py            |  5 +-
 .../transformation/qonnx/extract_conv_bias.py | 94 +++++++++++++++++++
 2 files changed, 98 insertions(+), 1 deletion(-)
 create mode 100644 src/finn/transformation/qonnx/extract_conv_bias.py

diff --git a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
index 5db1f302a..e6bc7b093 100644
--- a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
+++ b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
@@ -31,6 +31,7 @@ from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantIn
 from finn.transformation.base import Transformation
 from finn.transformation.gemm_to_matmul import GemmToMatMul
 from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.qonnx.extract_conv_bias import ExtractBiasFromConv
 from finn.transformation.qonnx.fold_quant_weights import FoldQuantWeights
 from finn.transformation.qonnx.infer_QuantAvgPool2d import AvgPoolAndTruncToQuantAvgPool
 from finn.transformation.qonnx.quant_act_to_multithreshold import (
@@ -73,6 +74,8 @@ class ConvertQONNXtoFINN(Transformation):
         self._filter_lambda = filter_lambda
 
     def apply(self, model):
+        # Extract the bias from Conv nodes into standalone Add nodes
+        model = model.transform(ExtractBiasFromConv())
         # Gemm operations are not supported by FINN, so we convert them to MatMul
         model = model.transform(GemmToMatMul())
         model = model.transform(FoldTransposeIntoQuantInit())
@@ -94,4 +97,4 @@ class ConvertQONNXtoFINN(Transformation):
         # Remove empty padding if it exists
         model = model.transform(RemoveEmptyPadding())
 
-        return (model, False)
+        return model, False
diff --git a/src/finn/transformation/qonnx/extract_conv_bias.py b/src/finn/transformation/qonnx/extract_conv_bias.py
new file mode 100644
index 000000000..c15ef7e81
--- /dev/null
+++ b/src/finn/transformation/qonnx/extract_conv_bias.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2021, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import warnings
+from onnx import TensorProto, helper
+
+from finn.transformation.base import Transformation
+
+# TODO: Move this transformation into finn-base?
+
+
+class ExtractBiasFromConv(Transformation):
+    """
+    Extracts the (optional) bias from a Conv node and inserts it behind the
+    Conv node as a standalone Add node.
+    """
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Conv":
+                # Check if the node has a bias input
+                if len(n.input) > 2:
+                    # Extract bias
+                    bias = model.get_initializer(n.input[2])
+                    if bias is None:
+                        warnings.warn(
+                            f"Could not extract bias from Conv node {n}, "
+                            f"due to missing static initialization."
+                        )
+                        continue
+
+                    # Insert bias as Add node behind the Conv node
+                    out_shape = model.get_tensor_shape(n.output[0])
+                    add_shape = [1] * len(out_shape)
+                    # TODO: this must change to "add_shape[-1] = bias.shape[0]" when
+                    #  channels-last support comes around
+                    add_shape[1] = bias.shape[0]
+                    add_tensor = helper.make_tensor_value_info(
+                        model.make_new_valueinfo_name(),
+                        TensorProto.FLOAT,
+                        add_shape,
+                    )
+                    graph.value_info.append(add_tensor)
+                    model.set_initializer(add_tensor.name, bias.reshape(add_shape))
+
+                    act_add_tensor = helper.make_tensor_value_info(
+                        model.make_new_valueinfo_name(),
+                        TensorProto.FLOAT,
+                        out_shape,
+                    )
+                    graph.value_info.append(act_add_tensor)
+
+                    add_node = helper.make_node(
+                        "Add",
+                        [act_add_tensor.name, add_tensor.name],
+                        [n.output[0]],
+                    )
+                    graph.node.insert(node_ind, add_node)
+
+                    # Repoint Conv output and remove bias tensor
+                    n.output[0] = act_add_tensor.name
+                    n.input.remove(n.input[2])
+
+                    return model, True
+
+        return model, False
-- 
GitLab