Commit 52302ca6 authored by auphelia

[Transform] Add trafo to move channelwise mul past dw conv

parent dde30255
@@ -338,6 +338,60 @@ class MoveScalarMulPastConv(Transformation):
        return (model, graph_modified)


class MoveMulPastDWConv(Transformation):
    """Move channelwise mul operations past depthwise conv operations. We want to have muls
    next to each other such that they can be collapsed into a single mul."""

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        graph_modified = False
        for n in graph.node:
            node_ind += 1
            if (
                n.op_type == "Mul"
                and not model.is_fork_node(n)
                and not model.is_join_node(n)
            ):
                consumer = model.find_consumer(n.output[0])
                if (
                    consumer is not None
                    and consumer.op_type == "Conv"
                    and not model.is_join_node(consumer)
                ):
                    mul_weight_name = n.input[1]
                    A = model.get_initializer(mul_weight_name)
                    assert A is not None, "Initializer for mul weights is not set."
                    conv_node = consumer
                    mul_node = n
                    start_name = mul_node.input[0]
                    conv_in_name = conv_node.input[0]
                    conv_in_shape = model.get_tensor_shape(conv_in_name)
                    ifm_ch = conv_in_shape[1]
                    group_attribute = get_by_name(consumer.attribute, "group").i
                    conv_out_name = conv_node.output[0]
                    conv_out_shape = model.get_tensor_shape(conv_out_name)
                    if np.prod(A.shape) == ifm_ch == group_attribute:
                        # if the mul is channelwise and the conv is depthwise,
                        # we can simply swap the order of ops
                        # rewire mul input to be conv input
                        conv_node.input[0] = start_name
                        model.set_tensor_shape(start_name, conv_in_shape)
                        # use old conv input tensor as conv output
                        conv_node.output[0] = conv_in_name
                        model.set_tensor_shape(conv_in_name, conv_out_shape)
                        # use new conv output as new mul node input
                        mul_node.input[0] = conv_in_name
                        # use old conv output as new mul node output
                        mul_node.output[0] = conv_out_name
                        # move mul node past conv node
                        graph.node.remove(mul_node)
                        graph.node.insert(node_ind, mul_node)
                        graph_modified = True
        model = model.transform(InferShapes())
        return (model, graph_modified)
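
The swap above is valid because a channelwise scale commutes with a depthwise convolution: each output channel depends only on the matching input channel, so scaling that channel before or after the conv gives the same result. A minimal NumPy check of this identity (illustrative only, not part of the commit; the naive depthwise_conv2d helper is hypothetical):

import numpy as np

def depthwise_conv2d(x, w):
    # naive depthwise conv, stride 1, no padding
    # x: (N, C, H, W), w: (C, 1, kH, kW)
    N, C, H, W = x.shape
    _, _, kH, kW = w.shape
    out = np.zeros((N, C, H - kH + 1, W - kW + 1))
    for c in range(C):
        for i in range(H - kH + 1):
            for j in range(W - kW + 1):
                out[:, c, i, j] = np.sum(
                    x[:, c, i : i + kH, j : j + kW] * w[c, 0], axis=(1, 2)
                )
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8))
w = rng.standard_normal((3, 1, 3, 3))
a = rng.standard_normal((1, 3, 1, 1))  # channelwise mul parameter

# mul -> depthwise conv gives the same result as depthwise conv -> mul
assert np.allclose(depthwise_conv2d(a * x, w), a * depthwise_conv2d(x, w))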
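
A usage sketch for the new transformation, assuming FINN's ModelWrapper API (module paths and file names are placeholders and may differ between FINN versions):

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.streamline.reorder import MoveMulPastDWConv

# "model.onnx" is a placeholder for an ONNX model that contains a
# channelwise Mul feeding a depthwise Conv
model = ModelWrapper("model.onnx")

# after the transform, the Mul sits behind the depthwise Conv, so it can
# later be collapsed with any adjacent Mul nodes
model = model.transform(MoveMulPastDWConv())
model.save("model_reordered.onnx")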

class MoveLinearPastEltwiseAdd(Transformation):
    """Move linear operations (mul, add) past elementwise add operations where possible.
    Specifically, matches and transforms the following patterns: