From acdaebe2ebe40fd0df16c3862e116a0f87475337 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Thu, 7 Oct 2021 17:04:31 +0100
Subject: [PATCH] Added support for converting Gemm to MatMul.

---
 .../qonnx/convert_qonnx_to_finn.py            |   5 +
 .../transformation/qonnx/gemm_to_matmul.py    | 218 ++++++++++++++++++
 2 files changed, 223 insertions(+)
 create mode 100644 src/finn/transformation/qonnx/gemm_to_matmul.py

diff --git a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
index 0c606e36c..f4c69ca53 100644
--- a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
+++ b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
@@ -27,12 +27,14 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from onnx import TensorProto, helper
+from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
 
 import finn.core.onnx_exec as oxe
 from finn.core.datatype import DataType
 from finn.transformation.base import Transformation
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.qonnx.gemm_to_matmul import GemmToMatMul
 from finn.transformation.qonnx.qonnx_activation_handlers import QuantActBaseHandler
 from finn.util.basic import get_by_name
 
@@ -46,6 +48,9 @@ class ConvertQONNXtoFINN(Transformation):
     """
 
     def apply(self, model):
+        # Gemm operations are not supported by FINN, so we convert them to MatMul
+        model = model.transform(GemmToMatMul())
+        model = model.transform(FoldTransposeIntoQuantInit())
         # Make sure the datatypes exist, these are required for folding the weights
         model = model.transform(InferDataTypes())
         # Fold weights
diff --git a/src/finn/transformation/qonnx/gemm_to_matmul.py b/src/finn/transformation/qonnx/gemm_to_matmul.py
new file mode 100644
index 000000000..6debb169f
--- /dev/null
+++ b/src/finn/transformation/qonnx/gemm_to_matmul.py
@@ -0,0 +1,218 @@
+# Copyright (c) 2021, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+import warnings
+from onnx import TensorProto, helper
+
+from finn.transformation.base import Transformation
+from finn.util.basic import get_by_name
+
+
class GemmToMatMul(Transformation):
    """
    Converts a Gemm op into a MatMul, (up to) two Mul ops and an Add op.

    Gemm computes Y = alpha * A' * B' + beta * C, where A' and B' are A and B
    optionally transposed according to the transA/transB attributes.
    The replacement subgraph is:

        A -[Transpose?]-> MatMul -> Mul(alpha) -> Add -> Y
        B -[Transpose?]-----^         Mul(beta) ---^
                                         ^-- C

    If the optional C input is absent, the beta-Mul and Add are not created.
    Unity multiplications (alpha == 1.0 or beta == 1.0) are cleaned up
    afterwards via the RemoveUnityMul transformation.
    """

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        for n in graph.node:
            node_ind += 1
            if n.op_type == "Gemm":
                # Index at which replacement nodes are inserted, directly
                # after the Gemm node; incremented after every insertion
                # so the new nodes stay in topological order.
                running_node_index = node_ind

                # Insert a Transpose for A if the transA attribute is set.
                # No perm attribute is given, so all axes are reversed
                # (correct for the 2D inputs Gemm operates on).
                transA = get_by_name(n.attribute, "transA")
                if transA is not None and transA.i:
                    inp_trans_out = helper.make_tensor_value_info(
                        model.make_new_valueinfo_name(),
                        TensorProto.FLOAT,
                        None,
                    )
                    graph.value_info.append(inp_trans_out)
                    inp_trans_node = helper.make_node(
                        "Transpose", [n.input[0]], [inp_trans_out.name]
                    )
                    graph.node.insert(running_node_index, inp_trans_node)
                    running_node_index += 1
                    n.input[0] = inp_trans_out.name

                # Insert a Transpose for B if the transB attribute is set
                transB = get_by_name(n.attribute, "transB")
                if transB is not None and transB.i:
                    inp_trans_out = helper.make_tensor_value_info(
                        model.make_new_valueinfo_name(),
                        TensorProto.FLOAT,
                        None,
                    )
                    graph.value_info.append(inp_trans_out)
                    inp_trans_node = helper.make_node(
                        "Transpose", [n.input[1]], [inp_trans_out.name]
                    )
                    graph.node.insert(running_node_index, inp_trans_node)
                    running_node_index += 1
                    n.input[1] = inp_trans_out.name

                # Insert MatMul: A * B. It initially takes over the Gemm
                # output; the tensor gets rewired below.
                matMul_node = helper.make_node(
                    "MatMul", [n.input[0], n.input[1]], [n.output[0]]
                )
                graph.node.insert(running_node_index, matMul_node)
                # Re-fetch the node: insert() copies the protobuf, and we
                # need a handle to the copy that lives in the graph.
                matMul_node = graph.node[running_node_index]
                running_node_index += 1

                # Insert Mul: (A*B) * alpha; alpha defaults to 1.0 per the
                # ONNX Gemm specification.
                alpha = get_by_name(n.attribute, "alpha")
                alpha = np.array(1.0) if alpha is None else np.array(alpha.f)
                mul_tensor = helper.make_tensor_value_info(
                    model.make_new_valueinfo_name(),
                    TensorProto.FLOAT,
                    None,
                )
                graph.value_info.append(mul_tensor)
                model.set_initializer(mul_tensor.name, alpha)

                act_mul_tensor = helper.make_tensor_value_info(
                    model.make_new_valueinfo_name(),
                    TensorProto.FLOAT,
                    None,
                )
                graph.value_info.append(act_mul_tensor)
                mul_node = helper.make_node(
                    "Mul",
                    [act_mul_tensor.name, mul_tensor.name],
                    [n.output[0]],
                )
                graph.node.insert(running_node_index, mul_node)
                mul_node_main_branch = graph.node[running_node_index]
                running_node_index += 1
                matMul_node.output[0] = act_mul_tensor.name

                # The C input of Gemm is optional; without it the result is
                # just alpha * A * B and no beta branch or Add is required.
                if len(n.input) > 2 and n.input[2] != "":
                    # Other branch: Insert Mul: beta * C.
                    # Bug fix vs. original patch: the presence check must
                    # test `beta`, not `alpha` (beta also defaults to 1.0).
                    beta = get_by_name(n.attribute, "beta")
                    beta = np.array(1.0) if beta is None else np.array(beta.f)
                    beta_tensor = helper.make_tensor_value_info(
                        model.make_new_valueinfo_name(),
                        TensorProto.FLOAT,
                        None,
                    )
                    graph.value_info.append(beta_tensor)
                    model.set_initializer(beta_tensor.name, beta)

                    beta_mul_tensor = helper.make_tensor_value_info(
                        model.make_new_valueinfo_name(),
                        TensorProto.FLOAT,
                        None,
                    )
                    graph.value_info.append(beta_mul_tensor)
                    # Consume C directly instead of rewiring its producer
                    # (the original indexed predecessors[2], which crashes
                    # when C is a graph input or an initializer).
                    beta_mul_node = helper.make_node(
                        "Mul",
                        [n.input[2], beta_tensor.name],
                        [beta_mul_tensor.name],
                    )
                    graph.node.insert(running_node_index, beta_mul_node)
                    running_node_index += 1

                    # Insert Add: ((A*B) * alpha) + (beta * C)
                    act_add_tensor = helper.make_tensor_value_info(
                        model.make_new_valueinfo_name(),
                        TensorProto.FLOAT,
                        None,
                    )
                    graph.value_info.append(act_add_tensor)
                    # The alpha-Mul now feeds the Add instead of the output
                    mul_node_main_branch.output[0] = act_add_tensor.name
                    add_node = helper.make_node(
                        "Add",
                        [act_add_tensor.name, beta_mul_tensor.name],
                        [n.output[0]],
                    )
                    graph.node.insert(running_node_index, add_node)
                    running_node_index += 1

                # Delete the now-replaced Gemm node
                graph.node.remove(n)

                # Remove potential unity multiplications from alpha and beta
                model = model.transform(RemoveUnityMul())

                return model, True

        return model, False
+
+
class RemoveUnityMul(Transformation):
    """
    Removes multiplication nodes, which have a unity initializer.
    """

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        for node in graph.node:
            node_ind += 1
            # Guard clauses: only Mul nodes whose second input is a
            # static initializer consisting entirely of ones qualify.
            if node.op_type != "Mul":
                continue
            scale = model.get_initializer(node.input[1])
            if scale is None:
                continue
            if not (scale.flatten() == 1.0).all():
                continue

            preds = model.find_direct_predecessors(node)
            succs = model.find_direct_successors(node)
            if preds is not None:
                # Rewire the producers to emit this node's output tensor,
                # then drop the identity Mul.
                for pred in preds:
                    pred.output[0] = node.output[0]
                graph.node.remove(node)
                return model, True
            if succs is not None:
                # No producer (input comes from the graph boundary):
                # point the consumers at this node's input instead.
                for succ in succs:
                    succ.input[0] = node.input[0]
                graph.node.remove(node)
                return model, True
            # Neither producers nor consumers exist; removal would leave
            # a dangling tensor, so leave the node in place and warn.
            warnings.warn(
                f"Can't remove empty unity multiplication node {node}, "
                f"due to no available successors or predecessors."
            )
        return model, False
-- 
GitLab