diff --git a/src/finn/transformation/fpgadataflow/insert_iodma.py b/src/finn/transformation/fpgadataflow/insert_iodma.py
index d6812fe1999647b50ec59e649130f048ec593675..2a2e056f0c8db8dc1c5320be9c5788941083c07e 100644
--- a/src/finn/transformation/fpgadataflow/insert_iodma.py
+++ b/src/finn/transformation/fpgadataflow/insert_iodma.py
@@ -80,7 +80,8 @@ class InsertIODMA(Transformation):
                 # check if tensor is NHWC
                 assert (
                     model.get_tensor_layout(graph_out_name) == DataLayout.NHWC
-                ), "Data layout of tensors must be NHWC"
+                    or model.get_tensor_layout(graph_in_name) == DataLayout.NC
+                ), "Data layout of tensors must be NHWC or NC"
                 out_shape = model.get_tensor_shape(graph_out_name)
                 out_dtype = model.get_tensor_datatype(graph_out_name)
                 # determine the feasible interface width
@@ -117,7 +118,8 @@ class InsertIODMA(Transformation):
                 # check if tensor is NHWC
                 assert (
                     model.get_tensor_layout(graph_in_name) == DataLayout.NHWC
-                ), "Data layout of tensors must be NHWC"
+                    or model.get_tensor_layout(graph_in_name) == DataLayout.NC
+                ), "Data layout of tensors must be NHWC or NC"
                 in_shape = model.get_tensor_shape(graph_in_name)
                 in_dtype = model.get_tensor_datatype(graph_in_name)
                 # determine the feasible interface width
@@ -154,7 +156,8 @@ class InsertIODMA(Transformation):
                 # check if tensor is NHWC
                 assert (
                     model.get_tensor_layout(fc_node.input[1]) == DataLayout.NHWC
-                ), "Data layout of tensors must be NHWC"
+                    or model.get_tensor_layout(graph_in_name) == DataLayout.NC
+                ), "Data layout of tensors must be NHWC or NC"
                 fc_w_name = fc_node.input[1]
                 w_shape = model.get_tensor_shape(fc_w_name)
                 w_dtype = model.get_tensor_datatype(fc_w_name)