From a3c41c5dbdaedb1620897d3cc84e97a45940c1bc Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Wed, 13 Oct 2021 11:17:38 +0100
Subject: [PATCH] Preserve shapes and datatypes during gemm to matmul
 conversion.
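
Shape and FINN datatype annotations were previously lost on the tensors
created by the conversion. The outputs of the inserted Transpose, MatMul,
Mul and Add nodes now get shapes derived from the Gemm input shapes, and
non-FLOAT32 datatypes are copied over to the new tensors where needed.
The patch also fixes the beta default, which was mistakenly guarded by
"if alpha is None", and lets the beta * C multiplication write to a fresh
intermediate tensor instead of rewiring the predecessor of C. The trailing
RemoveUnityMul call is commented out in this revision.

For reference, a minimal sketch of how the transformation is typically
applied (assuming the usual FINN ModelWrapper flow; the model paths are
only illustrative):

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.qonnx.gemm_to_matmul import GemmToMatMul

    # Rewrite all Gemm nodes into MatMul/Mul/Add (plus Transpose where
    # transA/transB is set), preserving shape/datatype annotations.
    model = ModelWrapper("model_with_gemm.onnx")
    model = model.transform(GemmToMatMul())
    model.save("model_without_gemm.onnx")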

---
 .../transformation/qonnx/gemm_to_matmul.py    | 55 ++++++++++++++-----
 1 file changed, 41 insertions(+), 14 deletions(-)

diff --git a/src/finn/transformation/qonnx/gemm_to_matmul.py b/src/finn/transformation/qonnx/gemm_to_matmul.py
index 6debb169f..4e1da1af6 100644
--- a/src/finn/transformation/qonnx/gemm_to_matmul.py
+++ b/src/finn/transformation/qonnx/gemm_to_matmul.py
@@ -30,13 +30,19 @@ import numpy as np
 import warnings
 from onnx import TensorProto, helper
 
+from finn.core.datatype import DataType
 from finn.transformation.base import Transformation
 from finn.util.basic import get_by_name
 
 
 class GemmToMatMul(Transformation):
     """
-    Converts Gemm op into a MatMul and an Add op.
+    Converts Gemm nodes into MatMul and Add nodes.
+    This transformation is built to support version 9 of the Gemm node, as
+    documented here: https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Gemm-9
+    However, earlier and later versions of the node are likely to work as well.
+    Explicitly not supported are the optional input C of versions >=11 and
+    the broadcast attribute of versions <=6.
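+    In terms of the ONNX specification, Gemm computes
+    Y = alpha * A' * B' + beta * C, where A' = transpose(A) if transA
+    else A (and analogously for B'), so the conversion emits Transpose,
+    MatMul, Mul and Add nodes as needed.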
     """
 
     def apply(self, model):
@@ -46,17 +52,18 @@ class GemmToMatMul(Transformation):
             node_ind += 1
             if n.op_type == "Gemm":
                 running_node_index = node_ind
-                predecessors = model.find_direct_predecessors(n)
 
                 # Transpose A?
                 transA = get_by_name(n.attribute, "transA")
                 if transA is not None and transA.i:
                     # Insert transpose node
+                    shape = model.get_tensor_shape(n.input[0])
+                    if shape is not None:
+                        shape = tuple(reversed(shape))
                     inp_trans_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        None,
-                        # [1024,1000],
+                        shape,
                     )
                     graph.value_info.append(inp_trans_out)
                     inp_trans_node = helper.make_node(
@@ -64,6 +71,9 @@ class GemmToMatMul(Transformation):
                     )
                     graph.node.insert(running_node_index, inp_trans_node)
                     running_node_index += 1
+                    # Copy over the datatype
+                    dt = model.get_tensor_datatype(n.input[0])
+                    if dt != DataType["FLOAT32"]:
+                        model.set_tensor_datatype(inp_trans_out.name, dt)
 
                     n.input[0] = inp_trans_out.name
 
@@ -71,11 +81,13 @@ class GemmToMatMul(Transformation):
                 transB = get_by_name(n.attribute, "transB")
                 if transB is not None and transB.i:
                     # Insert transpose node
+                    shape = model.get_tensor_shape(n.input[1])
+                    if shape is not None:
+                        shape = tuple(reversed(shape))
                     inp_trans_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        None,
-                        # [1024,1000],
+                        shape,
                     )
                     graph.value_info.append(inp_trans_out)
                     inp_trans_node = helper.make_node(
@@ -83,6 +95,10 @@ class GemmToMatMul(Transformation):
                     )
                     graph.node.insert(running_node_index, inp_trans_node)
                     running_node_index += 1
+                    # Copy over the datatype
+                    dt = model.get_tensor_datatype(n.input[1])
+                    if dt != DataType["FLOAT32"]:
+                        model.set_tensor_datatype(inp_trans_out.name, dt)
 
                     n.input[1] = inp_trans_out.name
 
@@ -108,10 +124,16 @@ class GemmToMatMul(Transformation):
                 graph.value_info.append(mul_tensor)
                 model.set_initializer(mul_tensor.name, alpha)
 
+                A_shape = model.get_tensor_shape(n.input[0])
+                B_shape = model.get_tensor_shape(n.input[1])
+                if A_shape is not None and B_shape is not None:
+                    shape = [A_shape[0], B_shape[1]]
+                else:
+                    shape = None
                 act_mul_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    shape,
                 )
                 graph.value_info.append(act_mul_tensor)
                 mul_node = helper.make_node(
@@ -126,7 +148,7 @@ class GemmToMatMul(Transformation):
 
                 # Other branch: Insert Mul: beta * C
                 beta = get_by_name(n.attribute, "beta")
-                if alpha is None:
+                if beta is None:
                     beta = np.array(1.0)
                 else:
                     beta = np.array(beta.f)
@@ -138,26 +160,31 @@ class GemmToMatMul(Transformation):
                 graph.value_info.append(mul_tensor)
                 model.set_initializer(mul_tensor.name, beta)
 
+                C_shape = model.get_tensor_shape(n.input[2])
                 act_mul_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    C_shape,
                 )
                 graph.value_info.append(act_mul_tensor)
                 mul_node = helper.make_node(
                     "Mul",
-                    [act_mul_tensor.name, mul_tensor.name],
-                    [n.input[2]],
+                    [n.input[2], mul_tensor.name],
+                    [act_mul_tensor.name],
                 )
                 graph.node.insert(running_node_index, mul_node)
                 running_node_index += 1
-                predecessors[2].output[0] = act_mul_tensor.name
+                # Copy over the datatype
+                dt = model.get_tensor_datatype(n.input[2])
+                if dt != DataType["FLOAT32"]:
+                    model.set_tensor_datatype(act_mul_tensor.name, dt)
+                n.input[2] = act_mul_tensor.name
 
                 # Insert Add: ((A*B) * alpha) + (beta * C)
+                shape = model.get_tensor_shape(mul_node_main_branch.input[0])
                 act_add_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    shape,
                 )
                 graph.value_info.append(act_add_tensor)
                 mul_node_main_branch.output[0] = act_add_tensor.name
@@ -174,7 +201,7 @@ class GemmToMatMul(Transformation):
                 graph.node.remove(n)
 
                 # Remove potential unity multiplications from alpha and beta attributes
-                model = model.transform(RemoveUnityMul())
+                # model = model.transform(RemoveUnityMul())
 
                 return model, True
 
-- 
GitLab