From 28f9e437d48d755fdad2894fe32142b35e0716ba Mon Sep 17 00:00:00 2001
From: Lucian Petrica <lucianp@xilinx.com>
Date: Mon, 29 Jun 2020 15:37:04 +0000
Subject: [PATCH] Added NC in the list of allowed layouts

---
 src/finn/transformation/fpgadataflow/insert_iodma.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/insert_iodma.py b/src/finn/transformation/fpgadataflow/insert_iodma.py
index d6812fe19..2a2e056f0 100644
--- a/src/finn/transformation/fpgadataflow/insert_iodma.py
+++ b/src/finn/transformation/fpgadataflow/insert_iodma.py
@@ -80,7 +80,8 @@ class InsertIODMA(Transformation):
                 # check if tensor is NHWC
                 assert (
                     model.get_tensor_layout(graph_out_name) == DataLayout.NHWC
-                ), "Data layout of tensors must be NHWC"
+                    or model.get_tensor_layout(graph_in_name) == DataLayout.NC
+                ), "Data layout of tensors must be NHWC or NC"
                 out_shape = model.get_tensor_shape(graph_out_name)
                 out_dtype = model.get_tensor_datatype(graph_out_name)
                 # determine the feasible interface width
@@ -117,7 +118,8 @@ class InsertIODMA(Transformation):
                 # check if tensor is NHWC
                 assert (
                     model.get_tensor_layout(graph_in_name) == DataLayout.NHWC
-                ), "Data layout of tensors must be NHWC"
+                    or model.get_tensor_layout(graph_in_name) == DataLayout.NC
+                ), "Data layout of tensors must be NHWC or NC"
                 in_shape = model.get_tensor_shape(graph_in_name)
                 in_dtype = model.get_tensor_datatype(graph_in_name)
                 # determine the feasible interface width
@@ -154,7 +156,8 @@ class InsertIODMA(Transformation):
                 # check if tensor is NHWC
                 assert (
                     model.get_tensor_layout(fc_node.input[1]) == DataLayout.NHWC
-                ), "Data layout of tensors must be NHWC"
+                    or model.get_tensor_layout(graph_in_name) == DataLayout.NC
+                ), "Data layout of tensors must be NHWC or NC"
                 fc_w_name = fc_node.input[1]
                 w_shape = model.get_tensor_shape(fc_w_name)
                 w_dtype = model.get_tensor_datatype(fc_w_name)
-- 
GitLab