From 3319c81edeade8aa8fafc5bb3246d63633b4a384 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Wed, 23 Oct 2019 15:00:16 +0100
Subject: [PATCH] [Transform] complete move_add_past_mul

---
 src/finn/transformation/streamline.py | 38 +++++++++++++++++++++------
 1 file changed, 30 insertions(+), 8 deletions(-)

diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 3d025c97c..39684af25 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -1,3 +1,5 @@
+from onnx import helper as oh
+
 import finn.transformation.infer_shapes as si
 
 
@@ -12,16 +14,36 @@ def move_add_past_mul(model):
         node_ind += 1
         if n.op_type == "Add":
             consumer = model.find_consumer(n.output[0])
-            if consumer.op_type == "Mul":
+            if consumer is not None and consumer.op_type == "Mul":
+                # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
+                # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
                 # assume input 0 is from the previous layer, input 1 is the
                 # trained (constant) parameter
-                mul_param = model.get_initializer(consumer.input[1])
-                add_param = model.get_initializer(n.input[1])
-                assert mul_param is not None
-                assert add_param is not None
-                # TODO compute new param values
-                # TODO make new nodes
-                # TODO mark nodes for removal
+                mul_weight_name = consumer.input[1]
+                add_weight_name = n.input[1]
+                A = model.get_initializer(mul_weight_name)
+                B = model.get_initializer(add_weight_name)
+                assert A is not None
+                assert B is not None
+                start_name = n.input[0]
+                middle_name = n.output[0]
+                end_name = consumer.output[0]
+                # compute new param value for add
+                BA = B * A
+                # make and insert new nodes
+                new_mul = oh.make_node(
+                    "Mul", [start_name, mul_weight_name], [middle_name]
+                )
+                new_add = oh.make_node(
+                    "Add", [middle_name, add_weight_name], [end_name]
+                )
+                graph.node.insert(node_ind, new_mul)
+                graph.node.insert(node_ind + 1, new_add)
+                # replace add value
+                model.set_initializer(add_weight_name, BA)
+                # mark old nodes for removal
+                nodes_to_remove += [n, consumer]
+                graph_modified = True
     # delete marked nodes (batchnorm and (un)squeezing)
     for n in nodes_to_remove:
         graph.node.remove(n)
-- 
GitLab