From a3c41c5dbdaedb1620897d3cc84e97a45940c1bc Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Wed, 13 Oct 2021 11:17:38 +0100
Subject: [PATCH] Preserve shapes and datatypes during gemm to matmul
 conversion.

---
 .../transformation/qonnx/gemm_to_matmul.py | 55 ++++++++++++++-----
 1 file changed, 41 insertions(+), 14 deletions(-)

diff --git a/src/finn/transformation/qonnx/gemm_to_matmul.py b/src/finn/transformation/qonnx/gemm_to_matmul.py
index 6debb169f..4e1da1af6 100644
--- a/src/finn/transformation/qonnx/gemm_to_matmul.py
+++ b/src/finn/transformation/qonnx/gemm_to_matmul.py
@@ -30,13 +30,19 @@ import numpy as np
 import warnings
 from onnx import TensorProto, helper
 
+from finn.core.datatype import DataType
 from finn.transformation.base import Transformation
 from finn.util.basic import get_by_name
 
 
 class GemmToMatMul(Transformation):
     """
-    Converts Gemm op into a MatMul and an Add op.
+    Converts Gemm nodes into a MatMul and an Add node.
+    This transformation is built to support version 9 of the Gemm node, as
+    documented here: https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Gemm-9
+    However, earlier and later versions of the node are likely to work as well.
+    Explicitly not supported is the optionality of input C in versions >=11 and
+    the broadcast attribute of version <=6.
     """
 
     def apply(self, model):
@@ -46,17 +52,18 @@ class GemmToMatMul(Transformation):
             node_ind += 1
             if n.op_type == "Gemm":
                 running_node_index = node_ind
-                predecessors = model.find_direct_predecessors(n)
 
                 # Transpose A?
                 transA = get_by_name(n.attribute, "transA")
                 if transA is not None and transA.i:
                     # Insert transpose node
+                    shape = model.get_tensor_shape(n.input[0])
+                    if shape is not None:
+                        shape = tuple(reversed(shape))
                     inp_trans_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        None,
-                        # [1024,1000],
+                        shape,
                     )
                     graph.value_info.append(inp_trans_out)
                     inp_trans_node = helper.make_node(
@@ -64,6 +71,9 @@ class GemmToMatMul(Transformation):
                     )
                     graph.node.insert(running_node_index, inp_trans_node)
                     running_node_index += 1
+                    dt = model.get_tensor_datatype(n.input[0])
+                    if dt != DataType["FLOAT32"]:
+                        model.set_tensor_datatype(inp_trans_out.name, dt)
 
                     n.input[0] = inp_trans_out.name
 
@@ -71,11 +81,13 @@ class GemmToMatMul(Transformation):
                 transB = get_by_name(n.attribute, "transB")
                 if transB is not None and transB.i:
                     # Insert transpose node
+                    shape = model.get_tensor_shape(n.input[1])
+                    if shape is not None:
+                        shape = tuple(reversed(shape))
                     inp_trans_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        None,
-                        # [1024,1000],
+                        shape,
                     )
                     graph.value_info.append(inp_trans_out)
                     inp_trans_node = helper.make_node(
@@ -83,6 +95,10 @@ class GemmToMatMul(Transformation):
                     )
                     graph.node.insert(running_node_index, inp_trans_node)
                     running_node_index += 1
+                    # Copy over the datatype
+                    dt = model.get_tensor_datatype(n.input[1])
+                    if dt != DataType["FLOAT32"]:
+                        model.set_tensor_datatype(inp_trans_out.name, dt)
 
                     n.input[1] = inp_trans_out.name
 
@@ -108,10 +124,16 @@ class GemmToMatMul(Transformation):
                 graph.value_info.append(mul_tensor)
                 model.set_initializer(mul_tensor.name, alpha)
 
+                A_shape = model.get_tensor_shape(n.input[0])
+                B_shape = model.get_tensor_shape(n.input[1])
+                if A_shape is not None and B_shape is not None:
+                    shape = [A_shape[0], B_shape[1]]
+                else:
+                    shape = None
                 act_mul_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    shape,
                 )
                 graph.value_info.append(act_mul_tensor)
                 mul_node = helper.make_node(
@@ -126,7 +148,7 @@ class GemmToMatMul(Transformation):
 
                 # Other branch: Insert Mul: beta * C
                 beta = get_by_name(n.attribute, "beta")
-                if alpha is None:
+                if beta is None:
                     beta = np.array(1.0)
                 else:
                     beta = np.array(beta.f)
@@ -138,26 +160,31 @@ class GemmToMatMul(Transformation):
                 graph.value_info.append(mul_tensor)
                 model.set_initializer(mul_tensor.name, beta)
 
+                C_shape = model.get_tensor_shape(n.input[2])
                 act_mul_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    C_shape,
                 )
                 graph.value_info.append(act_mul_tensor)
                 mul_node = helper.make_node(
                     "Mul",
-                    [act_mul_tensor.name, mul_tensor.name],
-                    [n.input[2]],
+                    [n.input[2], mul_tensor.name],
+                    [act_mul_tensor.name],
                 )
                 graph.node.insert(running_node_index, mul_node)
                 running_node_index += 1
-                predecessors[2].output[0] = act_mul_tensor.name
+                dt = model.get_tensor_datatype(n.input[2])
+                if dt != DataType["FLOAT32"]:
+                    model.set_tensor_datatype(act_mul_tensor.name, dt)
+                n.input[2] = act_mul_tensor.name
 
                 # Insert Add: ((A*B) * alpha) + (beta * C)
+                shape = model.get_tensor_shape(mul_node_main_branch.input[0])
                 act_add_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    shape,
                 )
                 graph.value_info.append(act_add_tensor)
                 mul_node_main_branch.output[0] = act_add_tensor.name
@@ -174,7 +201,7 @@ class GemmToMatMul(Transformation):
                 graph.node.remove(n)
 
                 # Remove potential unity multiplications from alpha and beta attributes
-                model = model.transform(RemoveUnityMul())
+                # model = model.transform(RemoveUnityMul())
 
                 return model, True
 
--
GitLab
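
For reference, the node sequence produced by GemmToMatMul realizes the Gemm-9
result Y = alpha * (A' x B') + beta * C as a MatMul, a Mul by alpha, a Mul of C
by beta, and a final Add. A minimal NumPy sketch of that equivalence follows;
the shapes and constants are illustrative only and are not taken from the patch.

import numpy as np

# Gemm-9 semantics: Y = alpha * (A' @ B') + beta * C, where A' and B' are the
# inputs after the optional transA / transB transposes (both 0 here).
alpha, beta = 0.5, 2.0
A = np.random.rand(3, 4).astype(np.float32)  # (M, K)
B = np.random.rand(4, 5).astype(np.float32)  # (K, N)
C = np.random.rand(3, 5).astype(np.float32)  # (M, N)

gemm_out = alpha * (A @ B) + beta * C

# Same result via the decomposed graph: MatMul -> Mul(alpha), Mul(beta) on C, Add.
decomposed = np.add((A @ B) * alpha, C * beta)
assert np.allclose(gemm_out, decomposed)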
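
A minimal sketch of applying the transformation through FINN's ModelWrapper;
the file names below are placeholders for any ONNX/QONNX model containing Gemm
nodes and are not part of the patch.

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.gemm_to_matmul import GemmToMatMul

# "model_with_gemm.onnx" is a placeholder path for the model to convert.
model = ModelWrapper("model_with_gemm.onnx")

# With this change, the inserted Transpose/MatMul/Mul/Add nodes keep the tensor
# shapes and FINN datatypes of the original Gemm inputs instead of leaving them unset.
model = model.transform(GemmToMatMul())
model.save("model_with_matmul.onnx")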