Skip to content
Snippets Groups Projects
Commit 2efefaa1 authored by auphelia's avatar auphelia
Browse files

[Streamline] Add trafo to move transpose node past scalar mul

parent d000c97d
No related branches found
No related tags found
No related merge requests found
......@@ -597,3 +597,50 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
model = model.transform(InferShapes())
return (model, graph_modified)
class MoveTransposePastScalarMul(Transformation):
"""Moves a Transpose node past a scalar Mul node"""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if (
n.op_type == "Transpose"
and not model.is_fork_node(n)
and not model.is_join_node(n)
):
consumer = model.find_consumer(n.output[0])
if (
consumer is not None
and consumer.op_type == "Mul"
and not model.is_join_node(consumer)
):
mul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
transp_node = n
mul_node = consumer
start_name = transp_node.input[0]
middle_name = transp_node.output[0]
end_name = mul_node.output[0]
transp_in_shape = model.get_tensor_shape(start_name)
transp_out_shape = model.get_tensor_shape(middle_name)
if all(x == 1 for x in A.shape):
# if the mul is scalar, we can simply swap the order of ops
# rewire transpose input to be mul input
mul_node.input[0] = start_name
model.set_tensor_shape(start_name, transp_in_shape)
mul_node.output[0] = middle_name
model.set_tensor_shape(middle_name, transp_in_shape)
transp_node.input[0] = middle_name
transp_node.output[0] = end_name
model.set_tensor_shape(end_name, transp_out_shape)
graph.node.remove(transp_node)
graph.node.insert(node_ind, transp_node)
graph_modified = True
model = model.transform(InferShapes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment