From de37b2a6c84eff2c9d5913bbeb18ad25c45a3eed Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Mon, 23 Mar 2020 00:30:43 +0000
Subject: [PATCH] [Transform] add AbsorbTransposeIntoMultiThreshold

---
 src/finn/transformation/streamline/absorb.py | 52 ++++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 92945a9ef..5795d9d71 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -31,6 +31,8 @@ from onnx import helper as oh
 
 from finn.core.datatype import DataType
 from finn.transformation import Transformation
+from finn.util.basic import get_by_name
+from finn.custom_op.registry import getCustomOp
 
 
 class AbsorbAddIntoMultiThreshold(Transformation):
@@ -227,3 +229,53 @@ class Absorb1BitMulIntoConv(Transformation):
                             graph.node.remove(consumer)
                             graph_modified = True
         return (model, graph_modified)
+
+
+class AbsorbTransposeIntoMultiThreshold(Transformation):
+    """Change (NHWCTranpose -> MultiThreshold -> NCHWTranspose) to (MultiThreshold)
+    with NHWC mode."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Transpose":
+                perms = list(get_by_name(n.attribute, "perm").ints)
+                if perms == [0, 3, 1, 2]:
+                    mt_cand = model.find_consumer(n.output[0])
+                    if mt_cand.op_type == "MultiThreshold":
+                        final_t_cand = model.find_consumer(mt_cand.output[0])
+                        if final_t_cand.op_type == "Transpose":
+                            perms = list(
+                                get_by_name(final_t_cand.attribute, "perm").ints
+                            )
+                            if perms == [0, 2, 3, 1]:
+                                mt = getCustomOp(mt_cand)
+                                mt.set_nodeattr("data_layout", "NHWC")
+                                # get rid of tranpose nodes, wire MT directly
+                                mt_cand.input[0] = n.input[0]
+                                mt_cand.output[0] = final_t_cand.output[0]
+                                graph.node.remove(n)
+                                graph.node.remove(final_t_cand)
+                                graph_modified = True
+                        elif final_t_cand.op_type == "Reshape":
+                            oshape = model.get_tensor_shape(final_t_cand.output[0])
+                            if len(oshape) == 2:
+                                # transition to FC part, can still use NHWC
+                                mt = getCustomOp(mt_cand)
+                                mt.set_nodeattr("data_layout", "NHWC")
+                                # get rid of first tranpose node
+                                mt_cand.input[0] = n.input[0]
+                                # fix output shape for MultiThreshold
+                                mt_ishape = model.get_tensor_shape(mt_cand.input[0])
+                                (b, h, w, c) = mt_ishape
+                                assert (
+                                    h == 1 and w == 1
+                                ), """Untested spatial dim
+                                in conv->fc transition, proceed with caution!"""
+                                model.set_tensor_shape(mt_cand.output[0], mt_ishape)
+                                graph.node.remove(n)
+                                graph_modified = True
+        return (model, graph_modified)
-- 
GitLab