From 67f89abf0a2cf92d2ba97d4bc5ad48f979c65863 Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Wed, 20 May 2020 18:30:08 +0100
Subject: [PATCH] [Transform] Add new transformation the moves MaxPool Past
 MultiThreshold nodes

---
 src/finn/transformation/streamline/reorder.py | 40 +++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 0c800afa3..09e177450 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -447,3 +447,43 @@ class MoveMulPastFork(MoveOpPastFork):
 class MoveLinearPastFork(MoveOpPastFork):
     def __init__(self):
         super().__init__(["Add", "Mul"])
+
+
+class MoveMaxPoolPastMultiThreshold(Transformation):
+    """Move MaxPool nodes past MultiThreshold nodes on linear segments of the graph."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        nodes = [n for n in graph.node]
+        for n in nodes:
+            node_ind += 1
+            if n.op_type == "MaxPool" and not model.is_fork_node(n):
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "MultiThreshold":
+
+                    # remove old nodes
+                    graph.node.remove(n)
+                    graph.node.remove(consumer)
+
+                    # swap conections
+                    group_in = n.input[0]
+                    # new tensor because dims change
+                    group_middle = model.make_new_valueinfo_name()
+                    group_out = consumer.output[0]
+
+                    consumer.input[0] = group_in
+                    consumer.output[0] = group_middle
+
+                    n.input[0] = group_middle
+                    n.output[0] = group_out
+
+                    # insert them back in
+                    graph.node.insert(node_ind - 1, consumer)
+                    graph.node.insert(node_ind, n)
+
+                    graph_modified = True
+
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
-- 
GitLab