Skip to content
Snippets Groups Projects
Commit c48b2641 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add MoveScalarMulPastConv and call in streamlining

parent aef84650
No related branches found
No related tags found
No related merge requests found
......@@ -53,6 +53,7 @@ from finn.transformation.streamline.reorder import (
MoveScalarMulPastMatMul,
MoveScalarAddPastMatMul,
MoveScalarAddPastConv,
MoveScalarMulPastConv,
)
from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
......@@ -73,6 +74,7 @@ class Streamline(Transformation):
MoveScalarAddPastMatMul(),
MoveScalarAddPastConv(),
MoveScalarMulPastMatMul(),
MoveScalarMulPastConv(),
MoveAddPastMul(),
CollapseRepeatedAdd(),
CollapseRepeatedMul(),
......
......@@ -186,13 +186,10 @@ class MoveScalarAddPastConv(Transformation):
conv_node = consumer
add_node = n
add_weight_name = n.input[1]
conv_weight_name = consumer.input[1]
conv_in_name = consumer.input[0]
conv_in_shape = model.get_tensor_shape(conv_in_name)
A = model.get_initializer(add_weight_name)
W = model.get_initializer(conv_weight_name)
assert A is not None, "Initializer for add weights is not set."
assert W is not None, "Initializer for conv weights is not set."
start_name = n.input[0]
end_name = consumer.output[0]
conv_out_shape = model.get_tensor_shape(end_name)
......@@ -228,3 +225,46 @@ class MoveScalarAddPastConv(Transformation):
graph_modified = True
model = model.transform(InferShapes())
return (model, graph_modified)
class MoveScalarMulPastConv(Transformation):
    """Move scalar mul operations past conv operations. We want to have muls
    next to each other such that they can be collapsed into a single mul."""

    def apply(self, model):
        graph = model.graph
        position = 0
        graph_modified = False
        for node in graph.node:
            position += 1
            # only interested in Mul nodes feeding directly into a Conv
            if node.op_type != "Mul":
                continue
            successor = model.find_consumer(node.output[0])
            if successor is None or successor.op_type != "Conv":
                continue
            scale = model.get_initializer(node.input[1])
            assert scale is not None, "Initializer for mul weights is not set."
            src_name = node.input[0]
            mid_name = successor.input[0]
            mid_shape = model.get_tensor_shape(mid_name)
            dst_name = successor.output[0]
            dst_shape = model.get_tensor_shape(dst_name)
            if all(dim == 1 for dim in scale.shape):
                # a scalar mul commutes with conv, so we can simply swap
                # the order of the two ops:
                # rewire conv to consume the original mul input
                successor.input[0] = src_name
                model.set_tensor_shape(src_name, mid_shape)
                # reuse the old intermediate tensor as the conv output
                successor.output[0] = mid_name
                model.set_tensor_shape(mid_name, dst_shape)
                # mul now consumes the conv output and produces the
                # original graph output tensor
                node.input[0] = mid_name
                node.output[0] = dst_name
                # move the mul node past the conv node in the node list
                graph.node.remove(node)
                graph.node.insert(position, node)
                graph_modified = True
        model = model.transform(InferShapes())
        return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment