Skip to content
Snippets Groups Projects
Commit 67f89abf authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[Transform] Add new transformation the moves MaxPool Past MultiThreshold nodes

parent 255cec17
No related branches found
No related tags found
No related merge requests found
......@@ -447,3 +447,43 @@ class MoveMulPastFork(MoveOpPastFork):
class MoveLinearPastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Add", "Mul"])
class MoveMaxPoolPastMultiThreshold(Transformation):
"""Move MaxPool nodes past MultiThreshold nodes on linear segments of the graph."""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
nodes = [n for n in graph.node]
for n in nodes:
node_ind += 1
if n.op_type == "MaxPool" and not model.is_fork_node(n):
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MultiThreshold":
# remove old nodes
graph.node.remove(n)
graph.node.remove(consumer)
# swap conections
group_in = n.input[0]
# new tensor because dims change
group_middle = model.make_new_valueinfo_name()
group_out = consumer.output[0]
consumer.input[0] = group_in
consumer.output[0] = group_middle
n.input[0] = group_middle
n.output[0] = group_out
# insert them back in
graph.node.insert(node_ind - 1, consumer)
graph.node.insert(node_ind, n)
graph_modified = True
model = model.transform(InferShapes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment