Skip to content
Snippets Groups Projects
Commit 0bdecff5 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] bugfixes to data layout inference transform

parent dd1348e5
No related branches found
No related tags found
No related merge requests found
......@@ -33,33 +33,29 @@ import warnings
from finn.util.basic import get_by_name
def _dims_to_layout(node, ndims):
if node.domain == "finn":
if node.op_type == "MultiThreshold":
mt_inst = registry.getCustomOp(node)
layout = mt_inst.get_nodeattr("data_layout")
if ndims == 2:
return DataLayout.NC
elif layout == "NHWC" and ndims == 4:
return DataLayout.NHWC
elif layout == "NCHW" and ndims == 4:
return DataLayout.NCHW
else:
return DataLayout.UNKNOWN
else:
if ndims == 2:
return DataLayout.NC
elif ndims == 4:
return DataLayout.NHWC
else:
return DataLayout.UNKNOWN
def _dims_to_layout(model, node, ndims):
if ndims == 2:
return DataLayout.NC
else:
if ndims == 2:
return DataLayout.NC
elif ndims == 4:
return DataLayout.NCHW
if node.domain == "finn":
if node.op_type == "MultiThreshold":
mt_inst = registry.getCustomOp(node)
layout = mt_inst.get_nodeattr("data_layout")
if layout == "NHWC" and ndims == 4:
return DataLayout.NHWC
elif layout == "NCHW" and ndims == 4:
return DataLayout.NCHW
else:
return DataLayout.UNKNOWN
else:
if ndims == 4:
return DataLayout.NHWC
else:
return DataLayout.UNKNOWN
else:
return DataLayout.UNKNOWN
# propagate input layout to output
# TODO this won't work for concat, squeeze/unsqueeze/reshape...
return model.get_tensor_layout(node.input[0])
def _infer_node_data_layout(model, node):
......@@ -70,20 +66,20 @@ def _infer_node_data_layout(model, node):
# try to guess based on number of output dims
for o in node.output:
ndims = len(model.get_tensor_shape(o))
new_layout = _dims_to_layout(node, ndims)
new_layout = _dims_to_layout(model, node, ndims)
model.set_tensor_layout(o, new_layout)
else:
if node.op_type == "Transpose":
# grab input annotation and switch it around using perm
perm = get_by_name(node.attribute, "perm").ints
inp_layout = model.get_tensor_layout(node.input[0])
out_layout = [x for _, x in sorted(zip(perm, inp_layout))]
out_layout = [inp_layout[i] for i in perm]
model.set_tensor_layout(node.output[0], out_layout)
else:
# try to guess based on number of output dims
for o in node.output:
ndims = len(model.get_tensor_shape(o))
model.set_tensor_layout(o, _dims_to_layout(node, ndims))
model.set_tensor_layout(o, _dims_to_layout(model, node, ndims))
# compare old and new output dtypes to see if anything changed
new_layouts = list(map(lambda x: model.get_tensor_layout(x), node.output))
graph_modified = new_layouts != old_layouts
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment