From 0bdecff51a68eef89eb50ce941e3da1c5ea81144 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Wed, 20 May 2020 02:16:26 +0100
Subject: [PATCH] [Transform] bugfixes to data layout inference transform

---
 src/finn/transformation/infer_data_layouts.py | 52 +++++++++----------
 1 file changed, 24 insertions(+), 28 deletions(-)

diff --git a/src/finn/transformation/infer_data_layouts.py b/src/finn/transformation/infer_data_layouts.py
index bb06f6abc..9ac75578f 100644
--- a/src/finn/transformation/infer_data_layouts.py
+++ b/src/finn/transformation/infer_data_layouts.py
@@ -33,33 +33,29 @@ import warnings
 from finn.util.basic import get_by_name
 
 
-def _dims_to_layout(node, ndims):
-    if node.domain == "finn":
-        if node.op_type == "MultiThreshold":
-            mt_inst = registry.getCustomOp(node)
-            layout = mt_inst.get_nodeattr("data_layout")
-            if ndims == 2:
-                return DataLayout.NC
-            elif layout == "NHWC" and ndims == 4:
-                return DataLayout.NHWC
-            elif layout == "NCHW" and ndims == 4:
-                return DataLayout.NCHW
-            else:
-                return DataLayout.UNKNOWN
-        else:
-            if ndims == 2:
-                return DataLayout.NC
-            elif ndims == 4:
-                return DataLayout.NHWC
-            else:
-                return DataLayout.UNKNOWN
+def _dims_to_layout(model, node, ndims):
+    if ndims == 2:
+        return DataLayout.NC
     else:
-        if ndims == 2:
-            return DataLayout.NC
-        elif ndims == 4:
-            return DataLayout.NCHW
+        if node.domain == "finn":
+            if node.op_type == "MultiThreshold":
+                mt_inst = registry.getCustomOp(node)
+                layout = mt_inst.get_nodeattr("data_layout")
+                if layout == "NHWC" and ndims == 4:
+                    return DataLayout.NHWC
+                elif layout == "NCHW" and ndims == 4:
+                    return DataLayout.NCHW
+                else:
+                    return DataLayout.UNKNOWN
+            else:
+                if ndims == 4:
+                    return DataLayout.NHWC
+                else:
+                    return DataLayout.UNKNOWN
         else:
-            return DataLayout.UNKNOWN
+            # propagate input layout to output
+            # TODO this won't work for concat, squeeze/unsqueeze/reshape...
+            return model.get_tensor_layout(node.input[0])
 
 
 def _infer_node_data_layout(model, node):
@@ -70,20 +66,20 @@ def _infer_node_data_layout(model, node):
         # try to guess based on number of output dims
         for o in node.output:
             ndims = len(model.get_tensor_shape(o))
-            new_layout = _dims_to_layout(node, ndims)
+            new_layout = _dims_to_layout(model, node, ndims)
             model.set_tensor_layout(o, new_layout)
     else:
         if node.op_type == "Transpose":
             # grab input annotation and switch it around using perm
             perm = get_by_name(node.attribute, "perm").ints
             inp_layout = model.get_tensor_layout(node.input[0])
-            out_layout = [x for _, x in sorted(zip(perm, inp_layout))]
+            out_layout = [inp_layout[i] for i in perm]
             model.set_tensor_layout(node.output[0], out_layout)
         else:
             # try to guess based on number of output dims
             for o in node.output:
                 ndims = len(model.get_tensor_shape(o))
-                model.set_tensor_layout(o, _dims_to_layout(node, ndims))
+                model.set_tensor_layout(o, _dims_to_layout(model, node, ndims))
     # compare old and new output dtypes to see if anything changed
     new_layouts = list(map(lambda x: model.get_tensor_layout(x), node.output))
     graph_modified = new_layouts != old_layouts
-- 
GitLab
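
Two behavioral fixes are bundled in this patch. First, _dims_to_layout now receives the model, so standard (non-FINN) ONNX nodes with more than two dimensions propagate the layout of their first input instead of assuming NCHW; as the new TODO notes, this simple forwarding is still wrong for shape-changing ops like Concat, Squeeze/Unsqueeze and Reshape. Second, the Transpose handling previously applied the inverse of perm to the input layout annotation, whereas output dimension i must take its meaning from input dimension perm[i]. Below is a minimal standalone sketch of that second fix; the function names are illustrative only, and plain characters stand in for FINN's DataLayout annotation entries.

# Standalone sketch (illustrative names, not part of the patch): contrasts
# the old and new Transpose layout propagation.

def old_transpose_layout(perm, inp_layout):
    # buggy: sorting by perm applies the *inverse* permutation to the layout
    return [x for _, x in sorted(zip(perm, inp_layout))]

def new_transpose_layout(perm, inp_layout):
    # fixed: output dim i takes its meaning from input dim perm[i]
    return [inp_layout[i] for i in perm]

perm = [0, 2, 3, 1]  # the NCHW -> NHWC transpose
inp_layout = list("NCHW")
print("".join(old_transpose_layout(perm, inp_layout)))  # NWCH (wrong)
print("".join(new_transpose_layout(perm, inp_layout)))  # NHWC (correct)

For perm = (0, 2, 3, 1) the old comprehension turned an NCHW annotation into NWCH, while the fixed version yields the expected NHWC.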