Skip to content
Snippets Groups Projects
Commit efe1a7a3 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add move_scalar_add_past_matmul

parent d9cf7042
No related branches found
No related tags found
No related merge requests found
import numpy as np
from onnx import helper as oh
import finn.transformation.infer_shapes as si
......@@ -114,10 +115,6 @@ def move_scalar_mul_past_matmul(model):
if n.op_type == "Mul":
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MatMul":
# have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
# want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
# assume input 0 is from the previous layer, input 1 is the
# trained (constant) parameter
mul_weight_name = n.input[1]
matmul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
......@@ -147,3 +144,45 @@ def move_scalar_mul_past_matmul(model):
graph_modified = True
model = model.transform_single(si.infer_shapes)
return (model, graph_modified)
def move_scalar_add_past_matmul(model):
"""Move scalar add operations past matmul operations. We want to have adds
next to each other such that they can be collapsed into a single add."""
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Add":
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MatMul":
add_weight_name = n.input[1]
matmul_weight_name = consumer.input[1]
A = model.get_initializer(add_weight_name)
W = model.get_initializer(matmul_weight_name)
assert A is not None
assert W is not None
start_name = n.input[0]
middle_name = n.output[0]
end_name = consumer.output[0]
if all(x == 1 for x in A.shape):
# if the add is scalar, we can move it past the matmul
# by taking it past the matmul with a dot product
Anew = np.dot(A * np.ones(W.shape[0], dtype=np.float32), W)
# update the add weight
model.set_initializer(add_weight_name, Anew)
new_matmul = oh.make_node(
"MatMul", [start_name, matmul_weight_name], [middle_name]
)
new_add = oh.make_node(
"Add", [middle_name, add_weight_name], [end_name]
)
graph.node.insert(node_ind, new_matmul)
graph.node.insert(node_ind + 1, new_add)
# remove old nodes
graph.node.remove(n)
graph.node.remove(consumer)
graph_modified = True
model = model.transform_single(si.infer_shapes)
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment