From 00c423ddb1b21f215da353aea8908d6c2c1b0417 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Mon, 14 Sep 2020 20:37:24 +0200
Subject: [PATCH] [DataLayout] handle Squeeze/Unsqueeze for InferDataLayouts

---
 src/finn/transformation/infer_data_layouts.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/finn/transformation/infer_data_layouts.py b/src/finn/transformation/infer_data_layouts.py
index e7a6b8823..d07162fa0 100644
--- a/src/finn/transformation/infer_data_layouts.py
+++ b/src/finn/transformation/infer_data_layouts.py
@@ -75,6 +75,17 @@ def _infer_node_data_layout(model, node):
             inp_layout = model.get_tensor_layout(node.input[0])
             out_layout = [inp_layout[i] for i in perm]
             model.set_tensor_layout(node.output[0], out_layout)
+        elif node.op_type == "Unsqueeze":
+            inp_layout = model.get_tensor_layout(node.input[0])
+            # add dummy dimension at the output
+            out_layout = inp_layout + ["x"]
+            model.set_tensor_layout(node.output[0], out_layout)
+        elif node.op_type == "Squeeze":
+            inp_layout = model.get_tensor_layout(node.input[0])
+            assert inp_layout[-1] == "x"
+            # remove dummy dimension
+            out_layout = inp_layout[:-1]
+            model.set_tensor_layout(node.output[0], out_layout)
         else:
             # try to guess based on number of output dims
             for o in node.output:
-- 
GitLab