Skip to content
Snippets Groups Projects
Commit 5ced9302 authored by mmrahorovic's avatar mmrahorovic
Browse files

[Transform]: bug fix AbsorbTransposeIntoMultiThreshold to support non-linear graphs

parent beebdd77
No related branches found
No related tags found
No related merge requests found
......@@ -315,7 +315,8 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
nodes = [n for n in model.graph.node]
for n in nodes:
node_ind += 1
if n.op_type == "Transpose" and not model.is_fork_node(n):
perms = list(get_by_name(n.attribute, "perm").ints)
......@@ -326,37 +327,41 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
and mt_cand.op_type == "MultiThreshold"
# and not model.is_fork_node(mt_cand)
):
final_t_cands = model.find_consumers(mt_cand.output[0])
mt_cand_orig_output = mt_cand.output[0]
mt = getCustomOp(mt_cand)
mt.set_nodeattr("data_layout", "NHWC")
# get rid of first tranpose node
# Rewire input of MultiThreshold node
mt_cand.input[0] = n.input[0]
# Make new intermediate tensor
intermediate_tensor_name = model.make_new_valueinfo_name()
intermediate_tensor_shape = model.get_tensor_shape(n.input[0])
intermediate_tensor_finn_dtype = model.get_tensor_datatype(
mt_cand.output[0]
)
# Create a new ValueInfoProto and set the shape
model.set_tensor_shape(
intermediate_tensor_name, intermediate_tensor_shape
)
# Set the tensor layout
model.set_tensor_layout(
intermediate_tensor_name, DataLayout.NHWC
)
# Set the tensor FINN datatype
model.set_tensor_datatype(
intermediate_tensor_name, intermediate_tensor_finn_dtype
)
# Rewire output of MT node
mt_cand.output[0] = intermediate_tensor_name
# Get rid of first transpose node
graph.node.remove(n)
# fix output shape for MultiThreshold
mt_orig_oshape = model.get_tensor_shape(mt_cand.output[0])
mt_ishape = model.get_tensor_shape(mt_cand.input[0])
model.set_tensor_shape(mt_cand.output[0], mt_ishape)
# re-insert Transpose behind MultiThreshold
transpose_output = model.make_new_valueinfo_name()
# Create new Transpose node
new_transpose = oh.make_node(
"Transpose",
[mt_cand.output[0]],
[transpose_output],
[intermediate_tensor_name],
[mt_cand_orig_output],
perm=[0, 3, 1, 2],
)
graph.node.insert(node_ind + 1, new_transpose)
if final_t_cands is not None:
# rewire next nodes' inputs
for final_t_cand in final_t_cands:
final_t_cand.input[0] = transpose_output
else:
# replace graph top-level output
get_by_name(
model.graph.output, mt_cand.output[0]
).name = transpose_output
model.set_tensor_shape(mt_cand.output[0], mt_ishape)
# set value_info shape for transpose output
model.set_tensor_shape(transpose_output, mt_orig_oshape)
graph_modified = True
if graph_modified:
model = model.transform(InferDataTypes())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment