Skip to content
Snippets Groups Projects
Commit 3a786fab authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Streamline] handle edge case in AbsorbTransposeIntoMultiThreshold

need to do things different if MultiThreshold is the final node
in the graph (producing top-level output)
parent 30d3fa9d
No related branches found
No related tags found
No related merge requests found
......@@ -321,11 +321,16 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
perms = list(get_by_name(n.attribute, "perm").ints)
if perms == [0, 3, 1, 2]:
mt_cand = model.find_consumer(n.output[0])
if mt_cand.op_type == "MultiThreshold" and not model.is_fork_node(
mt_cand
if (
mt_cand is not None
and mt_cand.op_type == "MultiThreshold"
and not model.is_fork_node(mt_cand)
):
final_t_cand = model.find_consumer(mt_cand.output[0])
if final_t_cand.op_type == "Transpose":
if (
final_t_cand is not None
and final_t_cand.op_type == "Transpose"
):
perms = list(
get_by_name(final_t_cand.attribute, "perm").ints
)
......@@ -356,7 +361,15 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
perm=[0, 3, 1, 2],
)
graph.node.insert(node_ind + 1, new_transpose)
final_t_cand.input[0] = transpose_output
if final_t_cand is not None:
# rewire next node's input
final_t_cand.input[0] = transpose_output
else:
# replace graph top-level output
get_by_name(
model.graph.output, mt_cand.output[0]
).name = transpose_output
model.set_tensor_shape(mt_cand.output[0], mt_ishape)
graph_modified = True
if graph_modified:
model = model.transform(InferDataTypes())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment