Commit a3c41c5d authored by Hendrik Borras

Preserve shapes and datatypes during gemm to matmul conversion.

parent 1b5d1fb4
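
This commit touches FINN's GemmToMatMul transformation, which rewrites an ONNX Gemm node (Y = alpha * A' @ B' + beta * C, with A and B optionally transposed) into Transpose, MatMul, Mul and Add nodes; the change additionally annotates the tensors created by that rewrite with shapes and FINN datatypes. As a self-contained illustration of the rewrite itself, the following numpy sketch (not part of the commit; all names and values are made up for the example) checks that the decomposed form matches the Gemm definition:

import numpy as np

# Illustration only: the decomposition produced by GemmToMatMul must compute
# the same result as the original Gemm node, i.e. alpha * A' @ B' + beta * C.
rng = np.random.default_rng(0)
alpha, beta = 0.5, 2.0
trans_a, trans_b = 1, 0                                # Gemm attributes transA / transB
A = rng.standard_normal((6, 4)).astype(np.float32)    # stored as (K, M) because transA=1
B = rng.standard_normal((6, 8)).astype(np.float32)    # (K, N)
C = rng.standard_normal((4, 8)).astype(np.float32)    # (M, N)

# Reference: the Gemm-9 definition
a = A.T if trans_a else A
b = B.T if trans_b else B
gemm_out = alpha * (a @ b) + beta * C

# Decomposed form, mirroring the Transpose -> MatMul -> Mul(alpha) main branch
# and the Mul(beta) side branch that are joined by an Add node
decomposed = ((a @ b) * np.array(alpha)) + (C * np.array(beta))

assert np.allclose(gemm_out, decomposed)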
@@ -30,13 +30,19 @@ import numpy as np
 import warnings
 from onnx import TensorProto, helper
+from finn.core.datatype import DataType
 from finn.transformation.base import Transformation
 from finn.util.basic import get_by_name
 class GemmToMatMul(Transformation):
     """
-    Converts Gemm op into a MatMul and an Add op.
+    Converts Gemm nodes into a MatMul and an Add nodes.
+    This transformation is built to support version 9 of the Gemm node, as
+    documented here: https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Gemm-9
+    However, earlier and later versions of the node are likely to work as well.
+    Explicitly not supported is the optionality of input C in versions >=11 and
+    the broadcast attribute of version <=6.
     """
     def apply(self, model):
@@ -46,17 +52,18 @@ class GemmToMatMul(Transformation):
             node_ind += 1
             if n.op_type == "Gemm":
                 running_node_index = node_ind
-                predecessors = model.find_direct_predecessors(n)
                 # Transpose A?
                 transA = get_by_name(n.attribute, "transA")
                 if transA is not None and transA.i:
                     # Insert transpose node
+                    shape = model.get_tensor_shape(n.input[0])
+                    if shape is not None:
+                        shape = tuple(reversed(shape))
                     inp_trans_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        None,
-                        # [1024,1000],
+                        shape,
                     )
                     graph.value_info.append(inp_trans_out)
                     inp_trans_node = helper.make_node(
@@ -64,6 +71,9 @@ class GemmToMatMul(Transformation):
                     )
                     graph.node.insert(running_node_index, inp_trans_node)
                     running_node_index += 1
+                    dt = model.get_tensor_datatype(n.input[0])
+                    if dt != DataType["FLOAT32"]:
+                        model.set_tensor_datatype(inp_trans_out.name, dt)
                     n.input[0] = inp_trans_out.name
@@ -71,11 +81,13 @@ class GemmToMatMul(Transformation):
                 transB = get_by_name(n.attribute, "transB")
                 if transB is not None and transB.i:
                     # Insert transpose node
+                    shape = model.get_tensor_shape(n.input[1])
+                    if shape is not None:
+                        shape = tuple(reversed(shape))
                     inp_trans_out = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
                         TensorProto.FLOAT,
-                        None,
-                        # [1024,1000],
+                        shape,
                     )
                     graph.value_info.append(inp_trans_out)
                     inp_trans_node = helper.make_node(
@@ -83,6 +95,10 @@ class GemmToMatMul(Transformation):
                     )
                     graph.node.insert(running_node_index, inp_trans_node)
                     running_node_index += 1
+                    # Copy over the datatype
+                    dt = model.get_tensor_datatype(n.input[1])
+                    if dt != DataType["FLOAT32"]:
+                        model.set_tensor_datatype(inp_trans_out.name, dt)
                     n.input[1] = inp_trans_out.name
@@ -108,10 +124,16 @@ class GemmToMatMul(Transformation):
                 graph.value_info.append(mul_tensor)
                 model.set_initializer(mul_tensor.name, alpha)
+                A_shape = model.get_tensor_shape(n.input[0])
+                B_shape = model.get_tensor_shape(n.input[1])
+                if A_shape is not None and B_shape is not None:
+                    shape = [A_shape[0], B_shape[1]]
+                else:
+                    shape = None
                 act_mul_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    shape,
                 )
                 graph.value_info.append(act_mul_tensor)
                 mul_node = helper.make_node(
@@ -126,7 +148,7 @@ class GemmToMatMul(Transformation):
                 # Other branch: Insert Mul: beta * C
                 beta = get_by_name(n.attribute, "beta")
-                if alpha is None:
+                if beta is None:
                     beta = np.array(1.0)
                 else:
                     beta = np.array(beta.f)
@@ -138,26 +160,31 @@ class GemmToMatMul(Transformation):
                 graph.value_info.append(mul_tensor)
                 model.set_initializer(mul_tensor.name, beta)
+                C_shape = model.get_tensor_shape(n.input[2])
                 act_mul_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    C_shape,
                 )
                 graph.value_info.append(act_mul_tensor)
                 mul_node = helper.make_node(
                     "Mul",
-                    [act_mul_tensor.name, mul_tensor.name],
-                    [n.input[2]],
+                    [n.input[2], mul_tensor.name],
+                    [act_mul_tensor.name],
                 )
                 graph.node.insert(running_node_index, mul_node)
                 running_node_index += 1
-                predecessors[2].output[0] = act_mul_tensor.name
+                dt = model.get_tensor_datatype(n.input[2])
+                if dt != DataType["FLOAT32"]:
+                    model.set_tensor_datatype(act_mul_tensor.name, dt)
+                n.input[2] = act_mul_tensor.name
                 # Insert Add: ((A*B) * alpha) + (beta * C)
+                shape = model.get_tensor_shape(mul_node_main_branch.input[0])
                 act_add_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
-                    None,
+                    shape,
                 )
                 graph.value_info.append(act_add_tensor)
                 mul_node_main_branch.output[0] = act_add_tensor.name
@@ -174,7 +201,7 @@ class GemmToMatMul(Transformation):
                 graph.node.remove(n)
         # Remove potential unity multiplications from alpha and beta attributes
-        model = model.transform(RemoveUnityMul())
+        # model = model.transform(RemoveUnityMul())
         return model, True
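
For context, applying the transformation typically looks like the sketch below. The ModelWrapper, transform and save calls are standard FINN usage, but the gemm_to_matmul module path and the file names are assumptions for illustration rather than details taken from this commit.

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.gemm_to_matmul import GemmToMatMul  # assumed module path
from finn.transformation.infer_shapes import InferShapes

# Hypothetical input model containing a Gemm node
model = ModelWrapper("model_with_gemm.onnx")
model = model.transform(GemmToMatMul())
# With this commit the intermediate tensors created by the rewrite already
# carry shape and FINN datatype annotations; running shape inference
# afterwards remains a common follow-up step.
model = model.transform(InferShapes())
model.save("model_with_matmul.onnx")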