Skip to content
Snippets Groups Projects
Commit cbdc575e authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Streamline] allow MoveOpPastFork to copy attributes, impl for Transpose

parent 4b79ea38
No related branches found
No related tags found
No related merge requests found
......@@ -728,9 +728,10 @@ class MoveOpPastFork(Transformation):
can be merged with nodes in the branches
"""
def __init__(self, op_name_list):
def __init__(self, op_name_list, get_attrs_fxn=lambda x: {}):
super().__init__()
self.ops_to_move = op_name_list
self.get_attrs_fxn = get_attrs_fxn
def apply(self, model):
graph = model.graph
......@@ -747,9 +748,10 @@ class MoveOpPastFork(Transformation):
# Restrict this transform to operations with constant parameters
# Assuming parameters is in input 1
op_init_param = model.get_initializer(n.input[1])
if op_init_param is None:
continue
if len(n.input) > 1:
op_init_param = model.get_initializer(n.input[1])
else:
op_init_param = None
# Check case when branches are empty and go
# to the same node
......@@ -766,16 +768,20 @@ class MoveOpPastFork(Transformation):
for consumer_node in consumers[1:]:
# create new node
new_param_name = model.make_new_valueinfo_name()
new_output_tensor_name = model.make_new_valueinfo_name()
if op_init_param is None:
new_inp_list = [n.input[0]]
else:
new_param_name = model.make_new_valueinfo_name()
new_inp_list = [n.input[0], new_param_name]
model.set_initializer(new_param_name, op_init_param)
attrs = self.get_attrs_fxn(n)
# TODO use copy of original node instead to get attrs?
new_node = oh.make_node(
n.op_type,
[n.input[0], new_param_name],
[new_output_tensor_name],
n.op_type, new_inp_list, [new_output_tensor_name], **attrs
)
graph.node.insert(node_ind, new_node)
node_ind += 1
model.set_initializer(new_param_name, op_init_param)
# change consumer input tensor
graph.node.remove(consumer_node)
......@@ -811,6 +817,13 @@ class MoveLinearPastFork(MoveOpPastFork):
super().__init__(["Add", "Mul"])
class MoveTransposePastFork(MoveOpPastFork):
def __init__(self):
super().__init__(
["Transpose"], lambda x: {"perm": get_by_name(x.attribute, "perm").ints}
)
class MoveMaxPoolPastMultiThreshold(Transformation):
"""Move MaxPool nodes past MultiThreshold nodes on linear segments of the graph."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment