Skip to content
Snippets Groups Projects
Commit 28f9e437 authored by Lucian Petrica's avatar Lucian Petrica
Browse files

Added NC in the list of allowed layouts

parent 8315f708
No related branches found
No related tags found
No related merge requests found
......@@ -80,7 +80,8 @@ class InsertIODMA(Transformation):
# check if tensor is NHWC
assert (
model.get_tensor_layout(graph_out_name) == DataLayout.NHWC
), "Data layout of tensors must be NHWC"
or model.get_tensor_layout(graph_in_name) == DataLayout.NC
), "Data layout of tensors must be NHWC or NC"
out_shape = model.get_tensor_shape(graph_out_name)
out_dtype = model.get_tensor_datatype(graph_out_name)
# determine the feasible interface width
......@@ -117,7 +118,8 @@ class InsertIODMA(Transformation):
# check if tensor is NHWC
assert (
model.get_tensor_layout(graph_in_name) == DataLayout.NHWC
), "Data layout of tensors must be NHWC"
or model.get_tensor_layout(graph_in_name) == DataLayout.NC
), "Data layout of tensors must be NHWC or NC"
in_shape = model.get_tensor_shape(graph_in_name)
in_dtype = model.get_tensor_datatype(graph_in_name)
# determine the feasible interface width
......@@ -154,7 +156,8 @@ class InsertIODMA(Transformation):
# check if tensor is NHWC
assert (
model.get_tensor_layout(fc_node.input[1]) == DataLayout.NHWC
), "Data layout of tensors must be NHWC"
or model.get_tensor_layout(graph_in_name) == DataLayout.NC
), "Data layout of tensors must be NHWC or NC"
fc_w_name = fc_node.input[1]
w_shape = model.get_tensor_shape(fc_w_name)
w_dtype = model.get_tensor_datatype(fc_w_name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment