From 2724eef995b775ed7b404d1ba61604e94d9dc630 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Tue, 31 Aug 2021 00:14:03 +0200
Subject: [PATCH] [DataLayout] change strategy in
 AbsorbTransposeIntoMultiThreshold

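Instead of special-casing (NCHWTranspose -> MultiThreshold ->
NHWCTranspose), always move the incoming Transpose past the
MultiThreshold: set the node's data_layout to NHWC, remove the
upstream Transpose, fix the MultiThreshold output shape and re-insert
a perm=[0, 3, 1, 2] Transpose behind it, rewiring all consumers (or
the graph top-level output) to the new Transpose. The fork-node
restriction on the MultiThreshold is dropped, since every consumer
gets rewired.

Schematically:

    before: NCHWTranspose -> MultiThreshold(NCHW) -> consumers
    after:  MultiThreshold(NHWC) -> Transpose(perm=[0,3,1,2]) -> consumers

A minimal usage sketch (file names below are placeholders), assuming
the ModelWrapper API at this revision:

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.streamline.absorb import (
        AbsorbTransposeIntoMultiThreshold,
    )

    # placeholder input model: any FINN-ready ONNX graph containing
    # Transpose -> MultiThreshold pairs
    model = ModelWrapper("model.onnx")
    model = model.transform(AbsorbTransposeIntoMultiThreshold())
    model.save("model_absorbed.onnx")  # placeholder output path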
---
 src/finn/transformation/streamline/absorb.py | 78 ++++++++------------
 1 file changed, 32 insertions(+), 46 deletions(-)

diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 9902bdc34..e174759f0 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -308,8 +308,8 @@ class Absorb1BitMulIntoConv(Transformation):
 
 
 class AbsorbTransposeIntoMultiThreshold(Transformation):
-    """Change (NCHWTranspose -> MultiThreshold -> NHWCTranspose) to (MultiThreshold)
-    with NHWC mode. For (NCHWTranspose -> MultiThreshold) move Transpose past MT."""
+    """For (NCHWTranspose -> MultiThreshold) move Transpose past MultiThreshold
+    and set its data_layout mode to NHWC."""
 
     def apply(self, model):
         graph = model.graph
@@ -324,53 +324,39 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
                     if (
                         mt_cand is not None
                         and mt_cand.op_type == "MultiThreshold"
-                        and not model.is_fork_node(mt_cand)
+                        # fork nodes are allowed here: all consumers are rewired below
                     ):
-                        final_t_cand = model.find_consumer(mt_cand.output[0])
-                        if (
-                            final_t_cand is not None
-                            and final_t_cand.op_type == "Transpose"
-                        ):
-                            perms = list(
-                                get_by_name(final_t_cand.attribute, "perm").ints
-                            )
-                            if perms == [0, 2, 3, 1]:
-                                mt = getCustomOp(mt_cand)
-                                mt.set_nodeattr("data_layout", "NHWC")
-                                # get rid of tranpose nodes, wire MT directly
-                                mt_cand.input[0] = n.input[0]
-                                mt_cand.output[0] = final_t_cand.output[0]
-                                graph.node.remove(n)
-                                graph.node.remove(final_t_cand)
-                                graph_modified = True
+                        final_t_cands = model.find_consumers(mt_cand.output[0])
+                        mt = getCustomOp(mt_cand)
+                        mt.set_nodeattr("data_layout", "NHWC")
+                        # get rid of first transpose node
+                        mt_cand.input[0] = n.input[0]
+                        graph.node.remove(n)
+                        # fix output shape for MultiThreshold
+                        mt_orig_oshape = model.get_tensor_shape(mt_cand.output[0])
+                        mt_ishape = model.get_tensor_shape(mt_cand.input[0])
+                        model.set_tensor_shape(mt_cand.output[0], mt_ishape)
+                        # re-insert Transpose behind MultiThreshold
+                        transpose_output = model.make_new_valueinfo_name()
+                        model.set_tensor_shape(transpose_output, mt_orig_oshape)
+                        new_transpose = oh.make_node(
+                            "Transpose",
+                            [mt_cand.output[0]],
+                            [transpose_output],
+                            perm=[0, 3, 1, 2],
+                        )
+                        graph.node.insert(node_ind + 1, new_transpose)
+                        if final_t_cands is not None:
+                            # rewire next nodes' inputs
+                            for final_t_cand in final_t_cands:
+                                final_t_cand.input[0] = transpose_output
                         else:
-                            mt = getCustomOp(mt_cand)
-                            mt.set_nodeattr("data_layout", "NHWC")
-                            # get rid of first tranpose node
-                            mt_cand.input[0] = n.input[0]
-                            graph.node.remove(n)
-                            # fix output shape for MultiThreshold
-                            mt_ishape = model.get_tensor_shape(mt_cand.input[0])
+                            # replace graph top-level output
+                            get_by_name(
+                                model.graph.output, mt_cand.output[0]
+                            ).name = transpose_output
                             model.set_tensor_shape(mt_cand.output[0], mt_ishape)
-                            # re-insert Transpose behind MultiThreshold
-                            transpose_output = model.make_new_valueinfo_name()
-                            new_transpose = oh.make_node(
-                                "Transpose",
-                                [mt_cand.output[0]],
-                                [transpose_output],
-                                perm=[0, 3, 1, 2],
-                            )
-                            graph.node.insert(node_ind + 1, new_transpose)
-                            if final_t_cand is not None:
-                                # rewire next node's input
-                                final_t_cand.input[0] = transpose_output
-                            else:
-                                # replace graph top-level output
-                                get_by_name(
-                                    model.graph.output, mt_cand.output[0]
-                                ).name = transpose_output
-                                model.set_tensor_shape(mt_cand.output[0], mt_ishape)
-                            graph_modified = True
+                        graph_modified = True
         if graph_modified:
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
-- 
GitLab