From 9a5ab1422f9e18cc0b204d295615187c7e16529d Mon Sep 17 00:00:00 2001
From: Lucian Petrica <lucianp@xilinx.com>
Date: Mon, 29 Jun 2020 14:03:25 +0000
Subject: [PATCH] Added explicit checks for direction values

---
 src/finn/custom_op/fpgadataflow/iodma.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/iodma.py b/src/finn/custom_op/fpgadataflow/iodma.py
index 97d37cd99..1e342902a 100644
--- a/src/finn/custom_op/fpgadataflow/iodma.py
+++ b/src/finn/custom_op/fpgadataflow/iodma.py
@@ -177,14 +177,18 @@ class IODMA(HLSCustomOp):
     def get_instream_width(self):
         if self.get_nodeattr("direction") == "in":
             return self.get_nodeattr("intfWidth")
-        else:
+        elif self.get_nodeattr("direction") == "out":
             return self.get_nodeattr("streamWidth")
+        else:
+            raise ValueError("Invalid IODMA direction, please set to in or out")
 
     def get_outstream_width(self):
         if self.get_nodeattr("direction") == "out":
             return self.get_nodeattr("intfWidth")
-        else:
+        elif self.get_nodeattr("direction") == "in":
             return self.get_nodeattr("streamWidth")
+        else:
+            raise ValueError("Invalid IODMA direction, please set to in or out")
 
     def get_number_output_values(self):
         oshape = self.get_normal_output_shape()
@@ -227,9 +231,11 @@ class IODMA(HLSCustomOp):
             else:
                 func = "Mem2Stream_Batch"
             dwc_func = "WidthAdjustedOutputStream"
-        else:
+        elif direction == "out":
             func = "Stream2Mem_Batch"
             dwc_func = "WidthAdjustedInputStream"
+        else:
+            raise ValueError("Invalid IODMA direction, please set to in or out")
         # define templates for instantiation
         dma_inst_template = func + "<DataWidth1, NumBytes1>(%s, %s, numReps);"
         dwc_inst_template = dwc_func + "<%d, %d, %d> %s(%s, numReps);"
@@ -269,11 +275,13 @@ class IODMA(HLSCustomOp):
                 "void %s(%s *in0, hls::stream<%s > &out, unsigned int numReps)"
                 % (self.onnx_node.name, packed_hls_type_in, packed_hls_type_out)
             ]
-        else:
+        elif direction == "out":
             self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
                 "void %s(hls::stream<%s > &in0, %s *out, unsigned int numReps)"
                 % (self.onnx_node.name, packed_hls_type_in, packed_hls_type_out)
             ]
+        else:
+            raise ValueError("Invalid IODMA direction, please set to in or out")
 
     def pragmas(self):
         self.code_gen_dict["$PRAGMAS$"] = [
@@ -293,7 +301,7 @@ class IODMA(HLSCustomOp):
             self.code_gen_dict["$PRAGMAS$"].append(
                 "#pragma HLS INTERFACE axis port=out"
             )
-        else:
+        elif direction == "out":
             self.code_gen_dict["$PRAGMAS$"].append(
                 "#pragma HLS INTERFACE axis port=in0"
             )
@@ -303,6 +311,8 @@ class IODMA(HLSCustomOp):
             self.code_gen_dict["$PRAGMAS$"].append(
                 "#pragma HLS INTERFACE s_axilite port=out bundle=control"
             )
+        else:
+            raise ValueError("Invalid IODMA direction, please set to in or out")
         self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS DATAFLOW")
 
     def execute_node(self, context, graph):
-- 
GitLab