From 246ed946a41c0a35ec89a950168daca517ec7909 Mon Sep 17 00:00:00 2001
From: Lucian Petrica <lucianp@xilinx.com>
Date: Mon, 29 Jun 2020 14:24:22 +0000
Subject: [PATCH] Added checks for data layout

---
 .../transformation/fpgadataflow/insert_iodma.py     | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/finn/transformation/fpgadataflow/insert_iodma.py b/src/finn/transformation/fpgadataflow/insert_iodma.py
index e3808114a..d6812fe19 100644
--- a/src/finn/transformation/fpgadataflow/insert_iodma.py
+++ b/src/finn/transformation/fpgadataflow/insert_iodma.py
@@ -33,6 +33,7 @@ from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
 from finn.transformation import Transformation
 from finn.transformation.general import SortGraph
+import finn.core.data_layout as DataLayout
 import math
 import numpy as np
 
@@ -76,6 +77,10 @@ class InsertIODMA(Transformation):
             return (model, False)
         else:
             if final_node.op_type != "IODMA":
+                # check if tensor is NHWC
+                assert (
+                    model.get_tensor_layout(graph_out_name) == DataLayout.NHWC
+                ), "Data layout of tensors must be NHWC"
                 out_shape = model.get_tensor_shape(graph_out_name)
                 out_dtype = model.get_tensor_datatype(graph_out_name)
                 # determine the feasible interface width
@@ -109,6 +114,10 @@ class InsertIODMA(Transformation):
                 )
                 model.graph.node.append(dma_node)
             if first_node.op_type != "IODMA":
+                # check if tensor is NHWC
+                assert (
+                    model.get_tensor_layout(graph_in_name) == DataLayout.NHWC
+                ), "Data layout of tensors must be NHWC"
                 in_shape = model.get_tensor_shape(graph_in_name)
                 in_dtype = model.get_tensor_datatype(graph_in_name)
                 # determine the feasible interface width
@@ -142,6 +151,10 @@ class InsertIODMA(Transformation):
                 )
                 model.graph.node.insert(0, dma_node)
             for fc_node in fc_extw_nodes:
+                # check if tensor is NHWC
+                assert (
+                    model.get_tensor_layout(fc_node.input[1]) == DataLayout.NHWC
+                ), "Data layout of tensors must be NHWC"
                 fc_w_name = fc_node.input[1]
                 w_shape = model.get_tensor_shape(fc_w_name)
                 w_dtype = model.get_tensor_datatype(fc_w_name)
-- 
GitLab