From 0bdecff51a68eef89eb50ce941e3da1c5ea81144 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Wed, 20 May 2020 02:16:26 +0100
Subject: [PATCH] [Transform] bugfixes to data layout inference transform

---
 src/finn/transformation/infer_data_layouts.py | 52 +++++++++----------
 1 file changed, 24 insertions(+), 28 deletions(-)

diff --git a/src/finn/transformation/infer_data_layouts.py b/src/finn/transformation/infer_data_layouts.py
index bb06f6abc..9ac75578f 100644
--- a/src/finn/transformation/infer_data_layouts.py
+++ b/src/finn/transformation/infer_data_layouts.py
@@ -33,33 +33,29 @@ import warnings
 from finn.util.basic import get_by_name
 
 
-def _dims_to_layout(node, ndims):
-    if node.domain == "finn":
-        if node.op_type == "MultiThreshold":
-            mt_inst = registry.getCustomOp(node)
-            layout = mt_inst.get_nodeattr("data_layout")
-            if ndims == 2:
-                return DataLayout.NC
-            elif layout == "NHWC" and ndims == 4:
-                return DataLayout.NHWC
-            elif layout == "NCHW" and ndims == 4:
-                return DataLayout.NCHW
-            else:
-                return DataLayout.UNKNOWN
-        else:
-            if ndims == 2:
-                return DataLayout.NC
-            elif ndims == 4:
-                return DataLayout.NHWC
-            else:
-                return DataLayout.UNKNOWN
+def _dims_to_layout(model, node, ndims):
+    if ndims == 2:
+        return DataLayout.NC
     else:
-        if ndims == 2:
-            return DataLayout.NC
-        elif ndims == 4:
-            return DataLayout.NCHW
+        if node.domain == "finn":
+            if node.op_type == "MultiThreshold":
+                mt_inst = registry.getCustomOp(node)
+                layout = mt_inst.get_nodeattr("data_layout")
+                if layout == "NHWC" and ndims == 4:
+                    return DataLayout.NHWC
+                elif layout == "NCHW" and ndims == 4:
+                    return DataLayout.NCHW
+                else:
+                    return DataLayout.UNKNOWN
+            else:
+                if ndims == 4:
+                    return DataLayout.NHWC
+                else:
+                    return DataLayout.UNKNOWN
         else:
-            return DataLayout.UNKNOWN
+            # propagate input layout to output
+            # TODO this won't work for concat, squeeze/unsqueeze/reshape...
+            return model.get_tensor_layout(node.input[0])
 
 
 def _infer_node_data_layout(model, node):
@@ -70,20 +66,20 @@ def _infer_node_data_layout(model, node):
         # try to guess based on number of output dims
         for o in node.output:
             ndims = len(model.get_tensor_shape(o))
-            new_layout = _dims_to_layout(node, ndims)
+            new_layout = _dims_to_layout(model, node, ndims)
             model.set_tensor_layout(o, new_layout)
     else:
         if node.op_type == "Transpose":
             # grab input annotation and switch it around using perm
             perm = get_by_name(node.attribute, "perm").ints
             inp_layout = model.get_tensor_layout(node.input[0])
-            out_layout = [x for _, x in sorted(zip(perm, inp_layout))]
+            out_layout = [inp_layout[i] for i in perm]
             model.set_tensor_layout(node.output[0], out_layout)
         else:
             # try to guess based on number of output dims
             for o in node.output:
                 ndims = len(model.get_tensor_shape(o))
-                model.set_tensor_layout(o, _dims_to_layout(node, ndims))
+                model.set_tensor_layout(o, _dims_to_layout(model, node, ndims))
     # compare old and new output dtypes to see if anything changed
     new_layouts = list(map(lambda x: model.get_tensor_layout(x), node.output))
     graph_modified = new_layouts != old_layouts
-- 
GitLab