From 246ed946a41c0a35ec89a950168daca517ec7909 Mon Sep 17 00:00:00 2001 From: Lucian Petrica <lucianp@xilinx.com> Date: Mon, 29 Jun 2020 14:24:22 +0000 Subject: [PATCH] Added checks for data layout --- .../transformation/fpgadataflow/insert_iodma.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/finn/transformation/fpgadataflow/insert_iodma.py b/src/finn/transformation/fpgadataflow/insert_iodma.py index e3808114a..d6812fe19 100644 --- a/src/finn/transformation/fpgadataflow/insert_iodma.py +++ b/src/finn/transformation/fpgadataflow/insert_iodma.py @@ -33,6 +33,7 @@ from finn.util.basic import get_by_name from finn.custom_op.registry import getCustomOp from finn.transformation import Transformation from finn.transformation.general import SortGraph +import finn.core.data_layout as DataLayout import math import numpy as np @@ -76,6 +77,10 @@ class InsertIODMA(Transformation): return (model, False) else: if final_node.op_type != "IODMA": + # check if tensor is NHWC + assert ( + model.get_tensor_layout(graph_out_name) == DataLayout.NHWC + ), "Data layout of tensors must be NHWC" out_shape = model.get_tensor_shape(graph_out_name) out_dtype = model.get_tensor_datatype(graph_out_name) # determine the feasible interface width @@ -109,6 +114,10 @@ class InsertIODMA(Transformation): ) model.graph.node.append(dma_node) if first_node.op_type != "IODMA": + # check if tensor is NHWC + assert ( + model.get_tensor_layout(graph_in_name) == DataLayout.NHWC + ), "Data layout of tensors must be NHWC" in_shape = model.get_tensor_shape(graph_in_name) in_dtype = model.get_tensor_datatype(graph_in_name) # determine the feasible interface width @@ -142,6 +151,10 @@ class InsertIODMA(Transformation): ) model.graph.node.insert(0, dma_node) for fc_node in fc_extw_nodes: + # check if tensor is NHWC + assert ( + model.get_tensor_layout(fc_node.input[1]) == DataLayout.NHWC + ), "Data layout of tensors must be NHWC" fc_w_name = fc_node.input[1] w_shape = model.get_tensor_shape(fc_w_name) w_dtype = model.get_tensor_datatype(fc_w_name) -- GitLab