Commit aef84650 authored by Yaman Umuroglu

[Transform] add MoveScalarAddPastConv, use in streamlining

parent 4932f858
@@ -52,6 +52,7 @@ from finn.transformation.streamline.reorder import (
     MoveAddPastMul,
     MoveScalarMulPastMatMul,
     MoveScalarAddPastMatMul,
+    MoveScalarAddPastConv,
 )
 from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
@@ -70,6 +71,7 @@ class Streamline(Transformation):
             ConvertSignToThres(),
             MoveAddPastMul(),
             MoveScalarAddPastMatMul(),
+            MoveScalarAddPastConv(),
             MoveScalarMulPastMatMul(),
             MoveAddPastMul(),
             CollapseRepeatedAdd(),
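The reordering is sound because a Conv without bias is a linear map: adding a scalar a to every input element shifts each output by the constant Conv(a * ones), so Add-then-Conv equals Conv-then-Add with an adjusted add constant. Below is a minimal numpy sketch of this identity (illustration only, not part of the commit; the conv2d_valid helper and the concrete shapes are assumptions for the demo):

import numpy as np

def conv2d_valid(x, w):
    # single-channel 2D convolution, "valid" padding, no bias (hypothetical helper)
    kh, kw = w.shape
    oh, ow = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((oh, ow), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(x[i : i + kh, j : j + kw] * w)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 6))
w = rng.standard_normal((3, 3))
a = 0.5  # the scalar add constant

# conv(x + a) == conv(x) + conv(a * ones): the add moves past the conv
lhs = conv2d_valid(x + a, w)
bias = conv2d_valid(np.full_like(x, a), w)
rhs = conv2d_valid(x, w) + bias
assert np.allclose(lhs, rhs)
# without padding the correction is spatially uniform, so one value
# per output channel (here a * w.sum()) is all that must be stored
assert np.allclose(bias, a * w.sum())

The implementation in reorder.py below uses exactly this trick: it runs the Conv once on a constant-filled input and keeps one value per output channel as the new add constant.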
@@ -31,6 +31,7 @@ from onnx import helper as oh
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.core.onnx_exec import execute_node


 class MoveAddPastMul(Transformation):
@@ -167,3 +168,63 @@ class MoveScalarAddPastMatMul(Transformation):
                     graph_modified = True
         model = model.transform(InferShapes())
         return (model, graph_modified)
+
+
+class MoveScalarAddPastConv(Transformation):
+    """Move scalar add operations past conv operations. We want to have adds
+    next to each other such that they can be collapsed into a single add."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Add":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "Conv":
+                    conv_node = consumer
+                    add_node = n
+                    add_weight_name = n.input[1]
+                    conv_weight_name = consumer.input[1]
+                    conv_in_name = consumer.input[0]
+                    conv_in_shape = model.get_tensor_shape(conv_in_name)
+                    A = model.get_initializer(add_weight_name)
+                    W = model.get_initializer(conv_weight_name)
+                    assert A is not None, "Initializer for add weights is not set."
+                    assert W is not None, "Initializer for conv weights is not set."
+                    start_name = n.input[0]
+                    end_name = consumer.output[0]
+                    conv_out_shape = model.get_tensor_shape(end_name)
+                    if all(x == 1 for x in A.shape):
+                        # create a tensor filled with the add constant, in
+                        # the shape expected by the convolution
+                        conv_in_const = np.zeros(conv_in_shape, dtype=np.float32)
+                        conv_in_const.fill(A.item())
+                        # create an execution context and put in const input
+                        exec_ctx = model.make_empty_exec_context()
+                        exec_ctx[conv_in_name] = conv_in_const
+                        # execute the conv node only
+                        execute_node(conv_node, exec_ctx, model.graph)
+                        # retrieve the conv output
+                        Anew = exec_ctx[end_name]
+                        # strip out repetition
+                        Anew = Anew[0, :, 0, 0].reshape(1, -1, 1, 1)
+                        # update the add weight
+                        model.set_initializer(add_weight_name, Anew)
+                        # rewire add input to be conv input
+                        conv_node.input[0] = start_name
+                        model.set_tensor_shape(start_name, conv_in_shape)
+                        # use old conv input tensor as conv output
+                        conv_node.output[0] = conv_in_name
+                        model.set_tensor_shape(conv_in_name, conv_out_shape)
+                        # use new conv output as new add node input
+                        add_node.input[0] = conv_in_name
+                        # use old conv output as new add node output
+                        add_node.output[0] = end_name
+                        # move add node past conv node
+                        graph.node.remove(add_node)
+                        graph.node.insert(node_ind, add_node)
+                        graph_modified = True
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
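Note that keeping only Anew[0, :, 0, 0] assumes the Conv's response to a constant input is spatially uniform, which holds when padding does not perturb the border outputs. A usage sketch that applies the new transform and checks functional equivalence (illustration only, not part of the commit; "model.onnx" is a hypothetical file containing a scalar Add feeding a Conv, and the tolerance is an arbitrary choice):

import numpy as np
from finn.core.modelwrapper import ModelWrapper
from finn.core.onnx_exec import execute_onnx
from finn.transformation.streamline.reorder import MoveScalarAddPastConv

model = ModelWrapper("model.onnx")  # hypothetical model with Add -> Conv
inp_name = model.graph.input[0].name
x = np.random.randn(*model.get_tensor_shape(inp_name)).astype(np.float32)
input_dict = {inp_name: x}
expected = execute_onnx(model, input_dict)

# apply the transform; the graph output name is preserved by the rewiring
model = model.transform(MoveScalarAddPastConv())
produced = execute_onnx(model, input_dict)

out_name = model.graph.output[0].name
assert np.allclose(expected[out_name], produced[out_name], atol=1e-3)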