Skip to content
Snippets Groups Projects
Commit c3083929 authored by auphelia's avatar auphelia
Browse files

[Streamline/Absorb] Reorder insertion/rewire of new trn node

parent 013684a1
No related branches found
No related tags found
No related merge requests found
......@@ -580,7 +580,6 @@ class AbsorbTransposeIntoResize(Transformation):
trans_input = mt_cand.output[0]
trans_output = new_tensor_name
# fix tensor shapes for Resize and Transpose
# n, c, h, w = model.get_tensor_shape(mt_cand.input[0])
n, c, hx, wx = model.get_tensor_shape(mt_cand.output[0])
model.set_tensor_shape(trans_input, (n, hx, wx, c))
model.set_tensor_shape(trans_output, (n, c, hx, wx))
......@@ -591,13 +590,13 @@ class AbsorbTransposeIntoResize(Transformation):
[trans_output],
perm=[0, 3, 1, 2],
)
graph.node.insert(node_ind + 1, new_transpose)
# rewire nodes
final_t_cands = model.find_consumers(mt_cand.output[0])
# rewire next nodes' inputs
for final_t_cand in final_t_cands:
final_t_cand.input[0] = trans_output
mt_cand.output[0] = trans_input
graph.node.insert(node_ind + 1, new_transpose)
graph_modified = True
if graph_modified:
model = model.transform(InferDataTypes())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment