From efe1a7a395c46a92e999cc4de99dc6daaffdf630 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Wed, 30 Oct 2019 11:42:16 +0000
Subject: [PATCH] [Transform] add move_scalar_add_past_matmul

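Moving a scalar Add past a MatMul relies on the identity
(x + A) @ W = x @ W + (A * ones) @ W: the scalar add weight is
broadcast to a row vector, multiplied into the matmul weight W, and
the Add is re-created after the MatMul. This leaves adds next to each
other so that they can later be collapsed into a single add.

A minimal usage sketch (assuming `model` is the same model wrapper the
other streamline transforms operate on):

    from finn.transformation.streamline import move_scalar_add_past_matmul

    model, graph_modified = move_scalar_add_past_matmul(model)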
---
 src/finn/transformation/streamline.py | 47 ++++++++++++++++++++++++---
 1 file changed, 43 insertions(+), 4 deletions(-)

diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 23f844cd9..7d90df603 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -1,3 +1,4 @@
+import numpy as np
 from onnx import helper as oh
 
 import finn.transformation.infer_shapes as si
@@ -114,10 +115,6 @@ def move_scalar_mul_past_matmul(model):
         if n.op_type == "Mul":
             consumer = model.find_consumer(n.output[0])
             if consumer is not None and consumer.op_type == "MatMul":
-                # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
-                # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
-                # assume input 0 is from the previous layer, input 1 is the
-                # trained (constant) parameter
                 mul_weight_name = n.input[1]
                 matmul_weight_name = consumer.input[1]
                 A = model.get_initializer(mul_weight_name)
@@ -147,3 +144,45 @@ def move_scalar_mul_past_matmul(model):
         graph_modified = True
     model = model.transform_single(si.infer_shapes)
     return (model, graph_modified)
+
+
+def move_scalar_add_past_matmul(model):
+    """Move scalar add operations past matmul operations. We want to have adds
+    next to each other such that they can be collapsed into a single add."""
+    graph = model.graph
+    node_ind = 0
+    graph_modified = False
+    for n in graph.node:
+        node_ind += 1
+        if n.op_type == "Add":
+            consumer = model.find_consumer(n.output[0])
+            if consumer is not None and consumer.op_type == "MatMul":
+                add_weight_name = n.input[1]
+                matmul_weight_name = consumer.input[1]
+                A = model.get_initializer(add_weight_name)
+                W = model.get_initializer(matmul_weight_name)
+                assert A is not None, "Add weight must be a constant initializer"
+                assert W is not None, "MatMul weight must be a constant initializer"
+                start_name = n.input[0]
+                middle_name = n.output[0]
+                end_name = consumer.output[0]
+                if all(x == 1 for x in A.shape):
+                    # if the add is scalar, it can move past the matmul:
+                    # (x + A) @ W == x @ W + (A * ones) @ W, computed below
+                    Anew = np.dot(A * np.ones(W.shape[0], dtype=np.float32), W)
+                    # update the add weight
+                    model.set_initializer(add_weight_name, Anew)
+                    new_matmul = oh.make_node(
+                        "MatMul", [start_name, matmul_weight_name], [middle_name]
+                    )
+                    new_add = oh.make_node(
+                        "Add", [middle_name, add_weight_name], [end_name]
+                    )
+                    graph.node.insert(node_ind, new_matmul)
+                    graph.node.insert(node_ind + 1, new_add)
+                    # remove old nodes
+                    graph.node.remove(n)
+                    graph.node.remove(consumer)
+                    graph_modified = True
+    model = model.transform_single(si.infer_shapes)
+    return (model, graph_modified)
-- 
GitLab