From aef84650f4e793cb061e8ee625a90a5e15d3f750 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sat, 21 Mar 2020 00:01:55 +0000
Subject: [PATCH] [Transform] add MoveScalarAddPastConv, use in streamlining

---
 .../transformation/streamline/__init__.py     |  2 +
 src/finn/transformation/streamline/reorder.py | 61 +++++++++++++++++++
 2 files changed, 63 insertions(+)

diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py
index e11454dd6..a1c7f1ab3 100644
--- a/src/finn/transformation/streamline/__init__.py
+++ b/src/finn/transformation/streamline/__init__.py
@@ -52,6 +52,7 @@ from finn.transformation.streamline.reorder import (
     MoveAddPastMul,
     MoveScalarMulPastMatMul,
     MoveScalarAddPastMatMul,
+    MoveScalarAddPastConv,
 )
 
 from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
@@ -70,6 +71,7 @@ class Streamline(Transformation):
             ConvertSignToThres(),
             MoveAddPastMul(),
             MoveScalarAddPastMatMul(),
+            MoveScalarAddPastConv(),
             MoveScalarMulPastMatMul(),
             MoveAddPastMul(),
             CollapseRepeatedAdd(),
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index db55dc202..aa1010245 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -31,6 +31,7 @@ from onnx import helper as oh
 
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.core.onnx_exec import execute_node
 
 
 class MoveAddPastMul(Transformation):
@@ -167,3 +168,63 @@ class MoveScalarAddPastMatMul(Transformation):
                         graph_modified = True
         model = model.transform(InferShapes())
         return (model, graph_modified)
+
+
+class MoveScalarAddPastConv(Transformation):
+    """Move scalar add operations past conv operations. We want to have adds
+    next to each other such that they can be collapsed into a single add."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Add":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "Conv":
+                    conv_node = consumer
+                    add_node = n
+                    add_weight_name = n.input[1]
+                    conv_weight_name = consumer.input[1]
+                    conv_in_name = consumer.input[0]
+                    conv_in_shape = model.get_tensor_shape(conv_in_name)
+                    A = model.get_initializer(add_weight_name)
+                    W = model.get_initializer(conv_weight_name)
+                    assert A is not None, "Initializer for add weights is not set."
+                    assert W is not None, "Initializer for conv weights is not set."
+                    start_name = n.input[0]
+                    end_name = consumer.output[0]
+                    conv_out_shape = model.get_tensor_shape(end_name)
+                    if all(x == 1 for x in A.shape):
+                        # create a tensor filled with the add constant, in
+                        # the shape expected by the convolution
+                        conv_in_const = np.zeros(conv_in_shape, dtype=np.float32)
+                        conv_in_const.fill(A.item())
+                        # create an execution context and put in const input
+                        exec_ctx = model.make_empty_exec_context()
+                        exec_ctx[conv_in_name] = conv_in_const
+                        # execute the conv node only
+                        execute_node(conv_node, exec_ctx, model.graph)
+                        # retrieve the conv output
+                        Anew = exec_ctx[end_name]
+                        # strip out repetition
+                        Anew = Anew[0, :, 0, 0].reshape(1, -1, 1, 1)
+                        # update the add weight
+                        model.set_initializer(add_weight_name, Anew)
+                        # rewire add input to be conv input
+                        conv_node.input[0] = start_name
+                        model.set_tensor_shape(start_name, conv_in_shape)
+                        # use old conv input tensor as conv output
+                        conv_node.output[0] = conv_in_name
+                        model.set_tensor_shape(conv_in_name, conv_out_shape)
+                        # use new conv output as new add node input
+                        add_node.input[0] = conv_in_name
+                        # use old conv output as new add node output
+                        add_node.output[0] = end_name
+                        # move add node past conv node
+                        graph.node.remove(add_node)
+                        graph.node.insert(node_ind, add_node)
+                        graph_modified = True
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
-- 
GitLab