Skip to content
Snippets Groups Projects
Commit 64b10b67 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Absorb] edge-case fixes to AbsorbTransposeIntoResize

parent 54ea3c6f
No related branches found
No related tags found
No related merge requests found
...@@ -563,55 +563,71 @@ class AbsorbTransposeIntoResize(Transformation): ...@@ -563,55 +563,71 @@ class AbsorbTransposeIntoResize(Transformation):
graph = model.graph graph = model.graph
node_ind = 0 node_ind = 0
graph_modified = False graph_modified = False
for n in graph.node: for node in graph.node:
node_ind += 1 node_ind += 1
if n.op_type == "Transpose" and not model.is_fork_node(n): if node.op_type == "Transpose" and not model.is_fork_node(node):
perms = list(get_by_name(n.attribute, "perm").ints) perms = list(get_by_name(node.attribute, "perm").ints)
if perms == [0, 3, 1, 2]: if perms == [0, 3, 1, 2]:
mt_cand = model.find_consumer(n.output[0]) mt_cand = model.find_consumer(node.output[0])
if mt_cand is not None and mt_cand.op_type == "Resize": if mt_cand is not None and mt_cand.op_type == "Resize":
final_t_cands = model.find_consumers(mt_cand.output[0])
mode = get_by_name(mt_cand.attribute, "mode").s.decode("ascii") mode = get_by_name(mt_cand.attribute, "mode").s.decode("ascii")
# skip if mode is not nearest # skip if mode is not nearest
if mode != "nearest": if mode != "nearest":
continue continue
# if sizes specified, turn into scales # if sizes specified, turn into scales
sizes = model.get_initializer(mt_cand.input[3]) if len(mt_cand.input) > 3:
sizes = model.get_initializer(mt_cand.input[3])
else:
sizes = None
if sizes is not None: if sizes is not None:
ishape = model.get_tensor_shape(mt_cand.input[0]) ishape = model.get_tensor_shape(mt_cand.input[0])
ns, cs, hs, ws = sizes / np.asarray(ishape) ns, cs, hs, ws = sizes / np.asarray(ishape)
model.set_initializer( model.set_initializer(
mt_cand.input[2], np.asarray([ns, hs, ws, cs]) mt_cand.input[2], np.asarray([ns, cs, hs, ws])
) )
mt_cand.input.remove(mt_cand.input[3]) mt_cand.input.remove(mt_cand.input[3])
# get rid of first tranpose node # scales already specified, transpose indices to NHWC
mt_cand.input[0] = n.input[0] scales = model.get_initializer(mt_cand.input[2])
graph.node.remove(n) assert scales is not None
# fix output shape for Resize ns, cs, hs, ws = scales
mt_orig_oshape = model.get_tensor_shape(mt_cand.output[0]) model.set_initializer(
mt_ishape = model.get_tensor_shape(mt_cand.input[0]) mt_cand.input[2], np.asarray([ns, hs, ws, cs])
model.set_tensor_shape(mt_cand.output[0], mt_ishape) )
# re-insert Transpose behind Resize # get rid of first tranpose node
transpose_output = model.make_new_valueinfo_name() mt_cand.input[0] = node.input[0]
model.set_tensor_shape(transpose_output, mt_orig_oshape) graph.node.remove(node)
new_transpose = oh.make_node( is_last_node = mt_cand.output[0] in [
"Transpose", x.name for x in model.graph.output
[mt_cand.output[0]], ]
[transpose_output],
perm=[0, 3, 1, 2], new_tensor_name = model.make_new_valueinfo_name()
) if is_last_node:
graph.node.insert(node_ind + 1, new_transpose) trans_input = new_tensor_name
if final_t_cands is not None: trans_output = mt_cand.output[0]
# rewire next nodes' inputs else:
for final_t_cand in final_t_cands: trans_input = mt_cand.output[0]
final_t_cand.input[0] = transpose_output trans_output = new_tensor_name
else: # fix tensor shapes for Resize and Transpose
# replace graph top-level output # n, c, h, w = model.get_tensor_shape(mt_cand.input[0])
get_by_name( n, c, hx, wx = model.get_tensor_shape(mt_cand.output[0])
model.graph.output, mt_cand.output[0] model.set_tensor_shape(trans_input, (n, hx, wx, c))
).name = transpose_output model.set_tensor_shape(trans_output, (n, c, hx, wx))
model.set_tensor_shape(mt_cand.output[0], mt_ishape) # re-insert Transpose behind Resize
graph_modified = True new_transpose = oh.make_node(
"Transpose",
[trans_input],
[trans_output],
perm=[0, 3, 1, 2],
)
graph.node.insert(node_ind + 1, new_transpose)
# rewire nodes
final_t_cands = model.find_consumers(mt_cand.output[0])
if final_t_cands is not None:
# rewire next nodes' inputs
for final_t_cand in final_t_cands:
final_t_cand.input[0] = trans_output
mt_cand.output[0] = trans_input
graph_modified = True
if graph_modified: if graph_modified:
model = model.transform(InferDataTypes()) model = model.transform(InferDataTypes())
return (model, graph_modified) return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment