From 00c423ddb1b21f215da353aea8908d6c2c1b0417 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Mon, 14 Sep 2020 20:37:24 +0200 Subject: [PATCH] [DataLayout] handle Squeeze/Unsqueeze for InferDataLayouts --- src/finn/transformation/infer_data_layouts.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/finn/transformation/infer_data_layouts.py b/src/finn/transformation/infer_data_layouts.py index e7a6b8823..d07162fa0 100644 --- a/src/finn/transformation/infer_data_layouts.py +++ b/src/finn/transformation/infer_data_layouts.py @@ -75,6 +75,17 @@ def _infer_node_data_layout(model, node): inp_layout = model.get_tensor_layout(node.input[0]) out_layout = [inp_layout[i] for i in perm] model.set_tensor_layout(node.output[0], out_layout) + elif node.op_type == "Unsqueeze": + inp_layout = model.get_tensor_layout(node.input[0]) + # add dummy dimension at the output + out_layout = inp_layout + ["x"] + model.set_tensor_layout(node.output[0], out_layout) + elif node.op_type == "Squeeze": + inp_layout = model.get_tensor_layout(node.input[0]) + assert inp_layout[-1] == "x" + # remove dummy dimension + out_layout = inp_layout[:-1] + model.set_tensor_layout(node.output[0], out_layout) else: # try to guess based on number of output dims for o in node.output: -- GitLab