Skip to content
Snippets Groups Projects
Commit 246ed946 authored by Lucian Petrica's avatar Lucian Petrica
Browse files

Added checks for data layout

parent 9a5ab142
No related branches found
No related tags found
No related merge requests found
......@@ -33,6 +33,7 @@ from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
from finn.transformation import Transformation
from finn.transformation.general import SortGraph
import finn.core.data_layout as DataLayout
import math
import numpy as np
......@@ -76,6 +77,10 @@ class InsertIODMA(Transformation):
return (model, False)
else:
if final_node.op_type != "IODMA":
# check if tensor is NHWC
assert (
model.get_tensor_layout(graph_out_name) == DataLayout.NHWC
), "Data layout of tensors must be NHWC"
out_shape = model.get_tensor_shape(graph_out_name)
out_dtype = model.get_tensor_datatype(graph_out_name)
# determine the feasible interface width
......@@ -109,6 +114,10 @@ class InsertIODMA(Transformation):
)
model.graph.node.append(dma_node)
if first_node.op_type != "IODMA":
# check if tensor is NHWC
assert (
model.get_tensor_layout(graph_in_name) == DataLayout.NHWC
), "Data layout of tensors must be NHWC"
in_shape = model.get_tensor_shape(graph_in_name)
in_dtype = model.get_tensor_datatype(graph_in_name)
# determine the feasible interface width
......@@ -142,6 +151,10 @@ class InsertIODMA(Transformation):
)
model.graph.node.insert(0, dma_node)
for fc_node in fc_extw_nodes:
# check if tensor is NHWC
assert (
model.get_tensor_layout(fc_node.input[1]) == DataLayout.NHWC
), "Data layout of tensors must be NHWC"
fc_w_name = fc_node.input[1]
w_shape = model.get_tensor_shape(fc_w_name)
w_dtype = model.get_tensor_datatype(fc_w_name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment