Skip to content
Snippets Groups Projects
Commit 00c423dd authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[DataLayout] handle Squeeze/Unsqueeze for InferDataLayouts

parent d7fe1ef4
No related branches found
No related tags found
No related merge requests found
......@@ -75,6 +75,17 @@ def _infer_node_data_layout(model, node):
inp_layout = model.get_tensor_layout(node.input[0])
out_layout = [inp_layout[i] for i in perm]
model.set_tensor_layout(node.output[0], out_layout)
elif node.op_type == "Unsqueeze":
inp_layout = model.get_tensor_layout(node.input[0])
# add dummy dimension at the output
out_layout = inp_layout + ["x"]
model.set_tensor_layout(node.output[0], out_layout)
elif node.op_type == "Squeeze":
inp_layout = model.get_tensor_layout(node.input[0])
assert inp_layout[-1] == "x"
# remove dummy dimension
out_layout = inp_layout[:-1]
model.set_tensor_layout(node.output[0], out_layout)
else:
# try to guess based on number of output dims
for o in node.output:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment