Skip to content
Snippets Groups Projects
Commit 2724eef9 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[DataLayout] change strategy in AbsorbTransposeIntoMultiThreshold

parent da18bc95
No related branches found
No related tags found
No related merge requests found
......@@ -308,8 +308,8 @@ class Absorb1BitMulIntoConv(Transformation):
class AbsorbTransposeIntoMultiThreshold(Transformation):
"""Change (NCHWTranspose -> MultiThreshold -> NHWCTranspose) to (MultiThreshold)
with NHWC mode. For (NCHWTranspose -> MultiThreshold) move Transpose past MT."""
"""For (NCHWTranspose -> MultiThreshold) move Transpose past MultiThreshold
and set its data_layout mode to NHWC."""
def apply(self, model):
graph = model.graph
......@@ -324,53 +324,39 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
if (
mt_cand is not None
and mt_cand.op_type == "MultiThreshold"
and not model.is_fork_node(mt_cand)
# and not model.is_fork_node(mt_cand)
):
final_t_cand = model.find_consumer(mt_cand.output[0])
if (
final_t_cand is not None
and final_t_cand.op_type == "Transpose"
):
perms = list(
get_by_name(final_t_cand.attribute, "perm").ints
)
if perms == [0, 2, 3, 1]:
mt = getCustomOp(mt_cand)
mt.set_nodeattr("data_layout", "NHWC")
# get rid of tranpose nodes, wire MT directly
mt_cand.input[0] = n.input[0]
mt_cand.output[0] = final_t_cand.output[0]
graph.node.remove(n)
graph.node.remove(final_t_cand)
graph_modified = True
final_t_cands = model.find_consumers(mt_cand.output[0])
mt = getCustomOp(mt_cand)
mt.set_nodeattr("data_layout", "NHWC")
# get rid of first tranpose node
mt_cand.input[0] = n.input[0]
graph.node.remove(n)
# fix output shape for MultiThreshold
mt_orig_oshape = model.get_tensor_shape(mt_cand.output[0])
mt_ishape = model.get_tensor_shape(mt_cand.input[0])
model.set_tensor_shape(mt_cand.output[0], mt_ishape)
# re-insert Transpose behind MultiThreshold
transpose_output = model.make_new_valueinfo_name()
model.set_tensor_shape(transpose_output, mt_orig_oshape)
new_transpose = oh.make_node(
"Transpose",
[mt_cand.output[0]],
[transpose_output],
perm=[0, 3, 1, 2],
)
graph.node.insert(node_ind + 1, new_transpose)
if final_t_cands is not None:
# rewire next nodes' inputs
for final_t_cand in final_t_cands:
final_t_cand.input[0] = transpose_output
else:
mt = getCustomOp(mt_cand)
mt.set_nodeattr("data_layout", "NHWC")
# get rid of first tranpose node
mt_cand.input[0] = n.input[0]
graph.node.remove(n)
# fix output shape for MultiThreshold
mt_ishape = model.get_tensor_shape(mt_cand.input[0])
# replace graph top-level output
get_by_name(
model.graph.output, mt_cand.output[0]
).name = transpose_output
model.set_tensor_shape(mt_cand.output[0], mt_ishape)
# re-insert Transpose behind MultiThreshold
transpose_output = model.make_new_valueinfo_name()
new_transpose = oh.make_node(
"Transpose",
[mt_cand.output[0]],
[transpose_output],
perm=[0, 3, 1, 2],
)
graph.node.insert(node_ind + 1, new_transpose)
if final_t_cand is not None:
# rewire next node's input
final_t_cand.input[0] = transpose_output
else:
# replace graph top-level output
get_by_name(
model.graph.output, mt_cand.output[0]
).name = transpose_output
model.set_tensor_shape(mt_cand.output[0], mt_ishape)
graph_modified = True
graph_modified = True
if graph_modified:
model = model.transform(InferDataTypes())
return (model, graph_modified)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment