Skip to content
Snippets Groups Projects
Commit de37b2a6 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add AbsorbTransposeIntoMultiThreshold

parent cc91f877
No related branches found
No related tags found
No related merge requests found
......@@ -31,6 +31,8 @@ from onnx import helper as oh
from finn.core.datatype import DataType
from finn.transformation import Transformation
from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
class AbsorbAddIntoMultiThreshold(Transformation):
......@@ -227,3 +229,53 @@ class Absorb1BitMulIntoConv(Transformation):
graph.node.remove(consumer)
graph_modified = True
return (model, graph_modified)
class AbsorbTransposeIntoMultiThreshold(Transformation):
"""Change (NHWCTranpose -> MultiThreshold -> NCHWTranspose) to (MultiThreshold)
with NHWC mode."""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Transpose":
perms = list(get_by_name(n.attribute, "perm").ints)
if perms == [0, 3, 1, 2]:
mt_cand = model.find_consumer(n.output[0])
if mt_cand.op_type == "MultiThreshold":
final_t_cand = model.find_consumer(mt_cand.output[0])
if final_t_cand.op_type == "Transpose":
perms = list(
get_by_name(final_t_cand.attribute, "perm").ints
)
if perms == [0, 2, 3, 1]:
mt = getCustomOp(mt_cand)
mt.set_nodeattr("data_layout", "NHWC")
# get rid of tranpose nodes, wire MT directly
mt_cand.input[0] = n.input[0]
mt_cand.output[0] = final_t_cand.output[0]
graph.node.remove(n)
graph.node.remove(final_t_cand)
graph_modified = True
elif final_t_cand.op_type == "Reshape":
oshape = model.get_tensor_shape(final_t_cand.output[0])
if len(oshape) == 2:
# transition to FC part, can still use NHWC
mt = getCustomOp(mt_cand)
mt.set_nodeattr("data_layout", "NHWC")
# get rid of first tranpose node
mt_cand.input[0] = n.input[0]
# fix output shape for MultiThreshold
mt_ishape = model.get_tensor_shape(mt_cand.input[0])
(b, h, w, c) = mt_ishape
assert (
h == 1 and w == 1
), """Untested spatial dim
in conv->fc transition, proceed with caution!"""
model.set_tensor_shape(mt_cand.output[0], mt_ishape)
graph.node.remove(n)
graph_modified = True
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment