From 2724eef995b775ed7b404d1ba61604e94d9dc630 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Tue, 31 Aug 2021 00:14:03 +0200 Subject: [PATCH] [DataLayout] change strategy in AbsorbTransposeIntoMultiThreshold --- src/finn/transformation/streamline/absorb.py | 78 ++++++++------------ 1 file changed, 32 insertions(+), 46 deletions(-) diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 9902bdc34..e174759f0 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -308,8 +308,8 @@ class Absorb1BitMulIntoConv(Transformation): class AbsorbTransposeIntoMultiThreshold(Transformation): - """Change (NCHWTranspose -> MultiThreshold -> NHWCTranspose) to (MultiThreshold) - with NHWC mode. For (NCHWTranspose -> MultiThreshold) move Transpose past MT.""" + """For (NCHWTranspose -> MultiThreshold) move Transpose past MultiThreshold + and set its data_layout mode to NHWC.""" def apply(self, model): graph = model.graph @@ -324,53 +324,39 @@ class AbsorbTransposeIntoMultiThreshold(Transformation): if ( mt_cand is not None and mt_cand.op_type == "MultiThreshold" - and not model.is_fork_node(mt_cand) + # and not model.is_fork_node(mt_cand) ): - final_t_cand = model.find_consumer(mt_cand.output[0]) - if ( - final_t_cand is not None - and final_t_cand.op_type == "Transpose" - ): - perms = list( - get_by_name(final_t_cand.attribute, "perm").ints - ) - if perms == [0, 2, 3, 1]: - mt = getCustomOp(mt_cand) - mt.set_nodeattr("data_layout", "NHWC") - # get rid of tranpose nodes, wire MT directly - mt_cand.input[0] = n.input[0] - mt_cand.output[0] = final_t_cand.output[0] - graph.node.remove(n) - graph.node.remove(final_t_cand) - graph_modified = True + final_t_cands = model.find_consumers(mt_cand.output[0]) + mt = getCustomOp(mt_cand) + mt.set_nodeattr("data_layout", "NHWC") + # get rid of first tranpose node + mt_cand.input[0] = n.input[0] + graph.node.remove(n) + # fix output shape for MultiThreshold + mt_orig_oshape = model.get_tensor_shape(mt_cand.output[0]) + mt_ishape = model.get_tensor_shape(mt_cand.input[0]) + model.set_tensor_shape(mt_cand.output[0], mt_ishape) + # re-insert Transpose behind MultiThreshold + transpose_output = model.make_new_valueinfo_name() + model.set_tensor_shape(transpose_output, mt_orig_oshape) + new_transpose = oh.make_node( + "Transpose", + [mt_cand.output[0]], + [transpose_output], + perm=[0, 3, 1, 2], + ) + graph.node.insert(node_ind + 1, new_transpose) + if final_t_cands is not None: + # rewire next nodes' inputs + for final_t_cand in final_t_cands: + final_t_cand.input[0] = transpose_output else: - mt = getCustomOp(mt_cand) - mt.set_nodeattr("data_layout", "NHWC") - # get rid of first tranpose node - mt_cand.input[0] = n.input[0] - graph.node.remove(n) - # fix output shape for MultiThreshold - mt_ishape = model.get_tensor_shape(mt_cand.input[0]) + # replace graph top-level output + get_by_name( + model.graph.output, mt_cand.output[0] + ).name = transpose_output model.set_tensor_shape(mt_cand.output[0], mt_ishape) - # re-insert Transpose behind MultiThreshold - transpose_output = model.make_new_valueinfo_name() - new_transpose = oh.make_node( - "Transpose", - [mt_cand.output[0]], - [transpose_output], - perm=[0, 3, 1, 2], - ) - graph.node.insert(node_ind + 1, new_transpose) - if final_t_cand is not None: - # rewire next node's input - final_t_cand.input[0] = transpose_output - else: - # replace graph top-level output - get_by_name( - model.graph.output, mt_cand.output[0] - ).name = transpose_output - model.set_tensor_shape(mt_cand.output[0], mt_ishape) - graph_modified = True + graph_modified = True if graph_modified: model = model.transform(InferDataTypes()) return (model, graph_modified) -- GitLab