Skip to content
Snippets Groups Projects
Commit a8bb9754 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Streamline] simplify and improve AbsorbConsecutiveTransposes

parent 8f7a26b6
No related branches found
No related tags found
No related merge requests found
......@@ -473,7 +473,7 @@ class AbsorbConsecutiveTransposes(Transformation):
"""Remove (Transpose -> Transpose) patterns when the input and output
of the pattern have the same layout."""
def Are_opposite_permutations(self, perms1, perms2):
def are_opposite_permutations(self, perms1, perms2):
if len(perms1) != len(perms2):
return False
assert 0 <= max(perms2) < len(perms2), "invalid permutation"
......@@ -488,72 +488,40 @@ class AbsorbConsecutiveTransposes(Transformation):
def apply(self, model):
graph = model.graph
graph_modified = False
for n in graph.node:
if n.op_type == "Transpose":
if model.is_fork_node(n):
next_nodes = model.find_direct_successors(n)
perms1 = list(get_by_name(n.attribute, "perm").ints)
# check if all nodes after fork are opposite transposes
all_opposite_transposes = True
for next_node in next_nodes:
if next_node is not None and next_node.op_type == "Transpose":
perms2 = list(get_by_name(next_node.attribute, "perm").ints)
if not self.Are_opposite_permutations(perms1, perms2):
all_opposite_transposes = False
break
else:
all_opposite_transposes = False
break
if not all_opposite_transposes:
continue
prod = model.find_producer(n.input[0])
for next_node in next_nodes:
# connect next_node's consumer input to n's producer output
# TODO implement this to allow for forks as producers and
# joins as consumers
cons = model.find_consumer(next_node.output[0])
cons.input[0] = prod.output[0]
# remove consumer transpose
graph.node.remove(next_node)
# remove producer transpose
graph.node.remove(n)
graph_modified = True
else:
next_node = model.find_consumer(n.output[0])
for node in graph.node:
if node.op_type == "Transpose":
next_nodes = model.find_consumers(node.output[0])
perms1 = list(get_by_name(node.attribute, "perm").ints)
# check if all nodes after fork are opposite transposes
all_opposite_transposes = True
for next_node in next_nodes:
if next_node is not None and next_node.op_type == "Transpose":
perms1 = list(get_by_name(n.attribute, "perm").ints)
perms2 = list(get_by_name(next_node.attribute, "perm").ints)
if self.Are_opposite_permutations(perms1, perms2):
# connect next_node's consumer input to n's producer output
# TODO implement this to allow for forks as producers
consumers = model.find_direct_successors(next_node)
prod = model.find_producer(n.input[0])
if prod is not None:
for cons in consumers:
for cons_in in cons.input:
if cons_in == next_node.output[0]:
prod.output[0] = cons_in
break
else:
# n.input[0] is top-level graph input
# wire consumers directly to that
for cons in consumers:
for i, iname in enumerate(cons.input):
if iname == next_node.output[0]:
cons.input[i] = n.input[0]
# remove both transposes
graph.node.remove(n)
graph.node.remove(next_node)
if not self.are_opposite_permutations(perms1, perms2):
all_opposite_transposes = False
break
else:
all_opposite_transposes = False
break
if not all_opposite_transposes:
continue
source_tensor = node.input[0]
for next_node in next_nodes:
# connect next_node's consumers' appropriate input to n's input
# TODO how to handle top-level outputs if any?
nextnode_out = next_node.output[0]
assert nextnode_out not in [x.name for x in model.graph.output]
consumers = model.find_consumers(nextnode_out)
for cons in consumers:
for i, iname in enumerate(cons.input):
if iname == nextnode_out:
cons.input[i] = source_tensor
# remove consumer transpose
graph.node.remove(next_node)
# remove producer transpose
graph.node.remove(node)
graph_modified = True
graph_modified = True
if graph_modified:
model = model.transform(InferDataTypes())
return (model, graph_modified)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment