diff --git a/src/finn/transformation/infer_data_layouts.py b/src/finn/transformation/infer_data_layouts.py
index e7a6b88239a1735d5379e165333f8356ae6f88a1..d07162fa049bd016e91b8c5b01ea56eda6267655 100644
--- a/src/finn/transformation/infer_data_layouts.py
+++ b/src/finn/transformation/infer_data_layouts.py
@@ -75,6 +75,17 @@ def _infer_node_data_layout(model, node):
             inp_layout = model.get_tensor_layout(node.input[0])
             out_layout = [inp_layout[i] for i in perm]
             model.set_tensor_layout(node.output[0], out_layout)
+        elif node.op_type == "Unsqueeze":
+            inp_layout = model.get_tensor_layout(node.input[0])
+            # add dummy dimension at the output
+            out_layout = inp_layout + ["x"]
+            model.set_tensor_layout(node.output[0], out_layout)
+        elif node.op_type == "Squeeze":
+            inp_layout = model.get_tensor_layout(node.input[0])
+            assert inp_layout[-1] == "x"
+            # remove dummy dimension
+            out_layout = inp_layout[:-1]
+            model.set_tensor_layout(node.output[0], out_layout)
         else:
             # try to guess based on number of output dims
             for o in node.output: