diff --git a/src/finn/transformation/infer_data_layouts.py b/src/finn/transformation/infer_data_layouts.py index e7a6b88239a1735d5379e165333f8356ae6f88a1..d07162fa049bd016e91b8c5b01ea56eda6267655 100644 --- a/src/finn/transformation/infer_data_layouts.py +++ b/src/finn/transformation/infer_data_layouts.py @@ -75,6 +75,17 @@ def _infer_node_data_layout(model, node): inp_layout = model.get_tensor_layout(node.input[0]) out_layout = [inp_layout[i] for i in perm] model.set_tensor_layout(node.output[0], out_layout) + elif node.op_type == "Unsqueeze": + inp_layout = model.get_tensor_layout(node.input[0]) + # add dummy dimension at the output + out_layout = inp_layout + ["x"] + model.set_tensor_layout(node.output[0], out_layout) + elif node.op_type == "Squeeze": + inp_layout = model.get_tensor_layout(node.input[0]) + assert inp_layout[-1] == "x" + # remove dummy dimension + out_layout = inp_layout[:-1] + model.set_tensor_layout(node.output[0], out_layout) else: # try to guess based on number of output dims for o in node.output: