From aef84650f4e793cb061e8ee625a90a5e15d3f750 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sat, 21 Mar 2020 00:01:55 +0000
Subject: [PATCH] [Transform] add MoveScalarAddPastConv, use in streamlining

---
 .../transformation/streamline/__init__.py     |  2 +
 src/finn/transformation/streamline/reorder.py | 61 +++++++++++++++++++
 2 files changed, 63 insertions(+)

diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py
index e11454dd6..a1c7f1ab3 100644
--- a/src/finn/transformation/streamline/__init__.py
+++ b/src/finn/transformation/streamline/__init__.py
@@ -52,6 +52,7 @@ from finn.transformation.streamline.reorder import (
     MoveAddPastMul,
     MoveScalarMulPastMatMul,
     MoveScalarAddPastMatMul,
+    MoveScalarAddPastConv,
 )
 from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
 
@@ -70,6 +71,7 @@ class Streamline(Transformation):
             ConvertSignToThres(),
             MoveAddPastMul(),
             MoveScalarAddPastMatMul(),
+            MoveScalarAddPastConv(),
             MoveScalarMulPastMatMul(),
             MoveAddPastMul(),
             CollapseRepeatedAdd(),
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index db55dc202..aa1010245 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -31,6 +31,7 @@ from onnx import helper as oh
 
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.core.onnx_exec import execute_node
 
 
 class MoveAddPastMul(Transformation):
@@ -167,3 +168,63 @@ class MoveScalarAddPastMatMul(Transformation):
                     graph_modified = True
         model = model.transform(InferShapes())
         return (model, graph_modified)
+
+
+class MoveScalarAddPastConv(Transformation):
+    """Move scalar add operations past conv operations. We want to have adds
+    next to each other such that they can be collapsed into a single add."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Add":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "Conv":
+                    conv_node = consumer
+                    add_node = n
+                    add_weight_name = n.input[1]
+                    conv_weight_name = consumer.input[1]
+                    conv_in_name = consumer.input[0]
+                    conv_in_shape = model.get_tensor_shape(conv_in_name)
+                    A = model.get_initializer(add_weight_name)
+                    W = model.get_initializer(conv_weight_name)
+                    assert A is not None, "Initializer for add weights is not set."
+                    assert W is not None, "Initializer for conv weights is not set."
+                    start_name = n.input[0]
+                    end_name = consumer.output[0]
+                    conv_out_shape = model.get_tensor_shape(end_name)
+                    if all(x == 1 for x in A.shape):
+                        # create a tensor filled with the add constant, in
+                        # the shape expected by the convolution
+                        conv_in_const = np.zeros(conv_in_shape, dtype=np.float32)
+                        conv_in_const.fill(A.item())
+                        # create an execution context and put in const input
+                        exec_ctx = model.make_empty_exec_context()
+                        exec_ctx[conv_in_name] = conv_in_const
+                        # execute the conv node only
+                        execute_node(conv_node, exec_ctx, model.graph)
+                        # retrieve the conv output
+                        Anew = exec_ctx[end_name]
+                        # strip out repetition
+                        Anew = Anew[0, :, 0, 0].reshape(1, -1, 1, 1)
+                        # update the add weight
+                        model.set_initializer(add_weight_name, Anew)
+                        # rewire add input to be conv input
+                        conv_node.input[0] = start_name
+                        model.set_tensor_shape(start_name, conv_in_shape)
+                        # use old conv input tensor as conv output
+                        conv_node.output[0] = conv_in_name
+                        model.set_tensor_shape(conv_in_name, conv_out_shape)
+                        # use new conv output as new add node input
+                        add_node.input[0] = conv_in_name
+                        # use old conv output as new add node output
+                        add_node.output[0] = end_name
+                        # move add node past conv node
+                        graph.node.remove(add_node)
+                        graph.node.insert(node_ind, add_node)
+                        graph_modified = True
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
--
GitLab
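
Aside, not part of the patch above: the rewrite is sound because convolution
is linear, so Conv(x + c) = Conv(x) + Conv(C), where C is an input-shaped
tensor filled with the scalar c. When Conv(C) is spatially uniform (e.g. no
padding), one value per output channel suffices, which is what the transform
extracts at [0, :, 0, 0] after running execute_node on the constant-filled
input. Below is a minimal numpy sketch of that identity; conv2d_valid is an
illustrative helper (not a FINN API), assuming NCHW layout, stride 1 and no
padding.

import numpy as np

def conv2d_valid(x, w):
    # naive NCHW convolution, stride 1, no padding (illustrative helper)
    n, c_in, h, wd = x.shape
    c_out, _, kh, kw = w.shape
    out = np.zeros((n, c_out, h - kh + 1, wd - kw + 1), dtype=x.dtype)
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = x[:, :, i : i + kh, j : j + kw]
            # contract over (c_in, kh, kw) to get one value per output channel
            out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3]))
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8)).astype(np.float32)
w = rng.standard_normal((4, 3, 3, 3)).astype(np.float32)
c = np.float32(0.5)  # the scalar contributed by the Add node

# original order: scalar Add, then Conv
before = conv2d_valid(x + c, w)

# rewritten order: Conv, then channelwise Add. The new Add weight is the
# Conv evaluated on a constant-filled input, sampled at one spatial position
# (spatially uniform here because there is no padding).
Anew = conv2d_valid(np.full_like(x, c), w)[0, :, 0, 0].reshape(1, -1, 1, 1)
after = conv2d_valid(x, w) + Anew

assert np.allclose(before, after, atol=1e-4)

If the Conv uses zero padding, the border outputs of Conv(C) differ from the
interior, so a single sample at [0, :, 0, 0] is no longer representative; the
rewrite as implemented assumes a Conv whose output is spatially uniform for a
constant input.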