Commit a3c41c5d authored by Hendrik Borras

Preserve shapes and datatypes during gemm to matmul conversion.

parent 1b5d1fb4
@@ -30,13 +30,19 @@
 import numpy as np
 import warnings
 from onnx import TensorProto, helper
+from finn.core.datatype import DataType
 from finn.transformation.base import Transformation
 from finn.util.basic import get_by_name


 class GemmToMatMul(Transformation):
     """
-    Converts Gemm op into a MatMul and an Add op.
+    Converts Gemm nodes into a MatMul and an Add node.
+    This transformation is built to support version 9 of the Gemm node, as
+    documented here: https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Gemm-9
+    However, earlier and later versions of the node are likely to work as well.
+    Explicitly not supported are the optionality of input C in versions >=11 and
+    the broadcast attribute of versions <=6.
     """

     def apply(self, model):
@@ -46,17 +52,18 @@ class GemmToMatMul(Transformation):
             node_ind += 1
             if n.op_type == "Gemm":
                 running_node_index = node_ind
-                predecessors = model.find_direct_predecessors(n)

                 # Transpose A?
                 transA = get_by_name(n.attribute, "transA")
                 if transA is not None and transA.i:
                     # Insert transpose node
+                    shape = model.get_tensor_shape(n.input[0])
+                    if shape is not None:
+                        shape = tuple(reversed(shape))
                     inp_trans_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        None,
-                        # [1024,1000],
+                        shape,
                     )
                     graph.value_info.append(inp_trans_out)
                     inp_trans_node = helper.make_node(
@@ -64,6 +71,9 @@ class GemmToMatMul(Transformation):
                     )
                     graph.node.insert(running_node_index, inp_trans_node)
                     running_node_index += 1
+                    dt = model.get_tensor_datatype(n.input[0])
+                    if dt != DataType["FLOAT32"]:
+                        model.set_tensor_datatype(inp_trans_out.name, dt)
                     n.input[0] = inp_trans_out.name
@@ -71,11 +81,13 @@ class GemmToMatMul(Transformation):
                 transB = get_by_name(n.attribute, "transB")
                 if transB is not None and transB.i:
                     # Insert transpose node
+                    shape = model.get_tensor_shape(n.input[1])
+                    if shape is not None:
+                        shape = tuple(reversed(shape))
                     inp_trans_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        None,
-                        # [1024,1000],
+                        shape,
                     )
                     graph.value_info.append(inp_trans_out)
                     inp_trans_node = helper.make_node(
@@ -83,6 +95,10 @@ class GemmToMatMul(Transformation):
                     )
                     graph.node.insert(running_node_index, inp_trans_node)
                     running_node_index += 1
+                    # Copy over the datatype
+                    dt = model.get_tensor_datatype(n.input[1])
+                    if dt != DataType["FLOAT32"]:
+                        model.set_tensor_datatype(inp_trans_out.name, dt)
                     n.input[1] = inp_trans_out.name
@@ -108,10 +124,16 @@ class GemmToMatMul(Transformation):
                 graph.value_info.append(mul_tensor)
                 model.set_initializer(mul_tensor.name, alpha)
+                A_shape = model.get_tensor_shape(n.input[0])
+                B_shape = model.get_tensor_shape(n.input[1])
+                if A_shape is not None and B_shape is not None:
+                    shape = [A_shape[0], B_shape[1]]
+                else:
+                    shape = None
                 act_mul_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    shape,
                 )
                 graph.value_info.append(act_mul_tensor)
                 mul_node = helper.make_node(
@@ -126,7 +148,7 @@ class GemmToMatMul(Transformation):
                 # Other branch: Insert Mul: beta * C
                 beta = get_by_name(n.attribute, "beta")
-                if alpha is None:
+                if beta is None:
                     beta = np.array(1.0)
                 else:
                     beta = np.array(beta.f)
@@ -138,26 +160,31 @@ class GemmToMatMul(Transformation):
                 graph.value_info.append(mul_tensor)
                 model.set_initializer(mul_tensor.name, beta)
+                C_shape = model.get_tensor_shape(n.input[2])
                 act_mul_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    C_shape,
                 )
                 graph.value_info.append(act_mul_tensor)
                 mul_node = helper.make_node(
                     "Mul",
-                    [act_mul_tensor.name, mul_tensor.name],
-                    [n.input[2]],
+                    [n.input[2], mul_tensor.name],
+                    [act_mul_tensor.name],
                 )
                 graph.node.insert(running_node_index, mul_node)
                 running_node_index += 1
-                predecessors[2].output[0] = act_mul_tensor.name
+                dt = model.get_tensor_datatype(n.input[2])
+                if dt != DataType["FLOAT32"]:
+                    model.set_tensor_datatype(act_mul_tensor.name, dt)
+                n.input[2] = act_mul_tensor.name

                 # Insert Add: ((A*B) * alpha) + (beta * C)
+                shape = model.get_tensor_shape(mul_node_main_branch.input[0])
                 act_add_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    shape,
                 )
                 graph.value_info.append(act_add_tensor)
                 mul_node_main_branch.output[0] = act_add_tensor.name
@@ -174,7 +201,7 @@ class GemmToMatMul(Transformation):
                 graph.node.remove(n)

         # Remove potential unity multiplications from alpha and beta attributes
-        model = model.transform(RemoveUnityMul())
+        # model = model.transform(RemoveUnityMul())

         return model, True
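For reference, the rewritten graph implements the Gemm-9 semantics cited in the new docstring, Y = alpha * (A' @ B') + beta * C with A' and B' optionally transposed, as separate Transpose, MatMul, Mul, and Add nodes. Below is a minimal numpy sketch of the equivalent computation; it is illustrative only and not part of the commit, and the function name is made up:

import numpy as np

def gemm_decomposed(A, B, C, alpha=1.0, beta=1.0, transA=0, transB=0):
    # Optional transposes become the inserted Transpose nodes.
    if transA:
        A = A.T
    if transB:
        B = B.T
    ab = A @ B       # the MatMul node that replaces Gemm
    ab = ab * alpha  # Mul on the main branch (a unity Mul when alpha == 1)
    c = C * beta     # Mul on the C branch (a unity Mul when beta == 1)
    return ab + c    # the final Add node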
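A hypothetical usage sketch, assuming the repository's module layout (finn.transformation.gemm_to_matmul) and illustrative file names; the transformation is applied through ModelWrapper.transform(), the same mechanism the diff itself uses for RemoveUnityMul():

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.gemm_to_matmul import GemmToMatMul  # assumed module path

model = ModelWrapper("model_with_gemm.onnx")  # hypothetical input model
model = model.transform(GemmToMatMul())
# With this commit, shape and FINN datatype annotations on the Gemm inputs
# carry over to the tensors created for the inserted Transpose and Mul nodes.
model.save("model_without_gemm.onnx")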