From e9aa670f167e361f910834fd37d857225d9bdb0f Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Thu, 25 Jun 2020 10:45:17 +0100
Subject: [PATCH] [Streamline] Add MoveFlattenPastAffine operation

---
 src/finn/transformation/streamline/reorder.py | 72 +++++++++++++++++++
 1 file changed, 72 insertions(+)

diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index b46b82c77..638c6ad5d 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -32,6 +32,7 @@ from onnx import helper as oh
 
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
@@ -597,3 +598,74 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
 
         model = model.transform(InferShapes())
         return (model, graph_modified)
+
+
+class MoveFlattenPastAffine(Transformation):
+    """Moves a node that implements a (1, -1) reshape past a MatMul, Mul or Add node."""
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        node_ind = 0
+        for n in graph.node:
+            node_ind += 1
+            if (
+                n.op_type == "Flatten"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
+                consumer = model.find_consumer(n.output[0])
+                if (
+                    consumer is not None
+                    and (
+                        consumer.op_type == "MatMul"
+                        or consumer.op_type == "Mul"
+                        or consumer.op_type == "Add"
+                    )
+                    and not model.is_join_node(consumer)
+                ):
+                    # move flatten past operation and rewire tensors
+                    start_name = n.input[0]
+                    middle_name = n.output[0]
+                    end_name = consumer.output[0]
+                    op_param_name = consumer.input[1]
+                    A = model.get_initializer(op_param_name)
+                    if A is None:
+                        warnings.warn("Param is not constant, skipping")
+                        continue
+                    op_in_dt = model.get_tensor_datatype(consumer.input[0])
+                    op_out_dt = model.get_tensor_datatype(consumer.output[0])
+                    start_shape = model.get_tensor_shape(start_name)
+                    dummy_in = np.random.uniform(low=0, high=1, size=(start_shape))
+
+                    if consumer.op_type == "MatMul":
+                        dummy_out = np.matmul(dummy_in, A)
+                    elif consumer.op_type == "Mul":
+                        dummy_out = dummy_in * A
+                    elif consumer.op_type == "Add":
+                        dummy_out = dummy_in + A
+
+                    new_op = oh.make_node(
+                        consumer.op_type,
+                        [start_name, op_param_name],
+                        [middle_name],
+                        name=consumer.name,
+                    )
+                    new_flatten = oh.make_node("Flatten", [middle_name], [end_name])
+                    graph.node.insert(node_ind, new_op)
+                    graph.node.insert(node_ind + 1, new_flatten)
+                    model.set_tensor_shape(middle_name, dummy_out.shape)
+                    # because a flatten node doesn't change the datatype we need
+                    # only the datatype of the op node
+                    model.set_tensor_datatype(start_name, op_in_dt)
+                    model.set_tensor_datatype(middle_name, op_out_dt)
+                    model.set_tensor_datatype(end_name, op_out_dt)
+                    # remove old nodes
+                    graph.node.remove(n)
+                    graph.node.remove(consumer)
+                    graph_modified = True
+
+        model = model.transform(InferShapes())
+        model = model.transform(InferDataTypes())
+
+        return (model, graph_modified)
-- 
GitLab