From 67f89abf0a2cf92d2ba97d4bc5ad48f979c65863 Mon Sep 17 00:00:00 2001 From: Tobi-Alonso <tobi.alonso@gmail.com> Date: Wed, 20 May 2020 18:30:08 +0100 Subject: [PATCH] [Transform] Add new transformation the moves MaxPool Past MultiThreshold nodes --- src/finn/transformation/streamline/reorder.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 0c800afa3..09e177450 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -447,3 +447,43 @@ class MoveMulPastFork(MoveOpPastFork): class MoveLinearPastFork(MoveOpPastFork): def __init__(self): super().__init__(["Add", "Mul"]) + + +class MoveMaxPoolPastMultiThreshold(Transformation): + """Move MaxPool nodes past MultiThreshold nodes on linear segments of the graph.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + nodes = [n for n in graph.node] + for n in nodes: + node_ind += 1 + if n.op_type == "MaxPool" and not model.is_fork_node(n): + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "MultiThreshold": + + # remove old nodes + graph.node.remove(n) + graph.node.remove(consumer) + + # swap conections + group_in = n.input[0] + # new tensor because dims change + group_middle = model.make_new_valueinfo_name() + group_out = consumer.output[0] + + consumer.input[0] = group_in + consumer.output[0] = group_middle + + n.input[0] = group_middle + n.output[0] = group_out + + # insert them back in + graph.node.insert(node_ind - 1, consumer) + graph.node.insert(node_ind, n) + + graph_modified = True + + model = model.transform(InferShapes()) + return (model, graph_modified) -- GitLab