Skip to content
Snippets Groups Projects
Commit 0215a3e9 authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[Transfom] Add Transform to move operations past graph forks

parent b83183bd
No related branches found
No related tags found
No related merge requests found
......@@ -360,3 +360,90 @@ class MakeMaxPoolNHWC(Transformation):
graph.node.insert(node_ind - 1, consumer)
graph_modified = True
return (model, graph_modified)
class MoveOpPastFork(Transformation):
"""Move node operations past graph forks. Used when a node before a fork
can be merged with nodes in the branches
"""
def __init__(self, op_name_list):
super().__init__()
self.ops_to_move = op_name_list
def apply(self, model):
graph = model.graph
graph_modified = False
nodes = [n for n in graph.node]
node_ind = 0
for n in nodes:
node_ind += 1
if (
n.op_type in self.ops_to_move
and model.is_fork_node(n)
and not model.is_join_node(n)
):
# Restrict this transform to operations with constant parameters
# Assuming parameters is in input 1
op_init_param = model.get_initializer(n.input[1])
if op_init_param is None:
continue
# Check case when branches are empty and go
# to the same node
consumers = model.find_consumers(n.output[0])
unique_consumer = True
for consum_node in consumers[1:]:
if consumers[0] != consum_node:
unique_consumer = False
break
if unique_consumer:
continue
for consumer_node in consumers[1:]:
# create new node
new_param_name = model.make_new_valueinfo_name()
new_output_tensor_name = model.make_new_valueinfo_name()
new_node = oh.make_node(
n.op_type,
[n.input[0], new_param_name],
[new_output_tensor_name],
)
graph.node.insert(node_ind, new_node)
node_ind += 1
model.set_initializer(new_param_name, op_init_param)
# change consumer input tensor
graph.node.remove(consumer_node)
for idx, consumer_input in enumerate(consumer_node.input):
if consumer_input == n.output[0]:
consumer_node.input[idx] = new_output_tensor_name
break
else:
raise Exception(
"Consumer should have the current node output as input"
)
graph.node.insert(node_ind, consumer_node)
graph_modified = True
model = model.transform(InferShapes())
return (model, graph_modified)
class MoveAddPastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Add"])
class MoveMulPastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Mul"])
class MoveLinearPastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Add", "Mul"])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment