diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 92945a9eff1cc45ce295ccd76b40a39b429f45f8..5795d9d71b771f86d03188c320c5ecfe706f5050 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -31,6 +31,8 @@ from onnx import helper as oh from finn.core.datatype import DataType from finn.transformation import Transformation +from finn.util.basic import get_by_name +from finn.custom_op.registry import getCustomOp class AbsorbAddIntoMultiThreshold(Transformation): @@ -227,3 +229,53 @@ class Absorb1BitMulIntoConv(Transformation): graph.node.remove(consumer) graph_modified = True return (model, graph_modified) + + +class AbsorbTransposeIntoMultiThreshold(Transformation): + """Change (NHWCTranpose -> MultiThreshold -> NCHWTranspose) to (MultiThreshold) + with NHWC mode.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Transpose": + perms = list(get_by_name(n.attribute, "perm").ints) + if perms == [0, 3, 1, 2]: + mt_cand = model.find_consumer(n.output[0]) + if mt_cand.op_type == "MultiThreshold": + final_t_cand = model.find_consumer(mt_cand.output[0]) + if final_t_cand.op_type == "Transpose": + perms = list( + get_by_name(final_t_cand.attribute, "perm").ints + ) + if perms == [0, 2, 3, 1]: + mt = getCustomOp(mt_cand) + mt.set_nodeattr("data_layout", "NHWC") + # get rid of tranpose nodes, wire MT directly + mt_cand.input[0] = n.input[0] + mt_cand.output[0] = final_t_cand.output[0] + graph.node.remove(n) + graph.node.remove(final_t_cand) + graph_modified = True + elif final_t_cand.op_type == "Reshape": + oshape = model.get_tensor_shape(final_t_cand.output[0]) + if len(oshape) == 2: + # transition to FC part, can still use NHWC + mt = getCustomOp(mt_cand) + mt.set_nodeattr("data_layout", "NHWC") + # get rid of first tranpose node + mt_cand.input[0] = n.input[0] + # fix output shape for MultiThreshold + mt_ishape = model.get_tensor_shape(mt_cand.input[0]) + (b, h, w, c) = mt_ishape + assert ( + h == 1 and w == 1 + ), """Untested spatial dim + in conv->fc transition, proceed with caution!""" + model.set_tensor_shape(mt_cand.output[0], mt_ishape) + graph.node.remove(n) + graph_modified = True + return (model, graph_modified)