diff --git a/src/finn/custom_op/fpgadataflow/iodma.py b/src/finn/custom_op/fpgadataflow/iodma.py index 97d37cd99aad227aa3ec36fc2b06d83cb4171aca..1e342902a5288c0b275dc98697e0155865c18b28 100644 --- a/src/finn/custom_op/fpgadataflow/iodma.py +++ b/src/finn/custom_op/fpgadataflow/iodma.py @@ -177,14 +177,18 @@ class IODMA(HLSCustomOp): def get_instream_width(self): if self.get_nodeattr("direction") == "in": return self.get_nodeattr("intfWidth") - else: + elif self.get_nodeattr("direction") == "out": return self.get_nodeattr("streamWidth") + else: + raise ValueError("Invalid IODMA direction, please set to in or out") def get_outstream_width(self): if self.get_nodeattr("direction") == "out": return self.get_nodeattr("intfWidth") - else: + elif self.get_nodeattr("direction") == "in": return self.get_nodeattr("streamWidth") + else: + raise ValueError("Invalid IODMA direction, please set to in or out") def get_number_output_values(self): oshape = self.get_normal_output_shape() @@ -227,9 +231,11 @@ class IODMA(HLSCustomOp): else: func = "Mem2Stream_Batch" dwc_func = "WidthAdjustedOutputStream" - else: + elif direction == "out": func = "Stream2Mem_Batch" dwc_func = "WidthAdjustedInputStream" + else: + raise ValueError("Invalid IODMA direction, please set to in or out") # define templates for instantiation dma_inst_template = func + "<DataWidth1, NumBytes1>(%s, %s, numReps);" dwc_inst_template = dwc_func + "<%d, %d, %d> %s(%s, numReps);" @@ -269,11 +275,13 @@ class IODMA(HLSCustomOp): "void %s(%s *in0, hls::stream<%s > &out, unsigned int numReps)" % (self.onnx_node.name, packed_hls_type_in, packed_hls_type_out) ] - else: + elif direction == "out": self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ "void %s(hls::stream<%s > &in0, %s *out, unsigned int numReps)" % (self.onnx_node.name, packed_hls_type_in, packed_hls_type_out) ] + else: + raise ValueError("Invalid IODMA direction, please set to in or out") def pragmas(self): self.code_gen_dict["$PRAGMAS$"] = [ @@ -293,7 +301,7 @@ class IODMA(HLSCustomOp): self.code_gen_dict["$PRAGMAS$"].append( "#pragma HLS INTERFACE axis port=out" ) - else: + elif direction == "out": self.code_gen_dict["$PRAGMAS$"].append( "#pragma HLS INTERFACE axis port=in0" ) @@ -303,6 +311,8 @@ class IODMA(HLSCustomOp): self.code_gen_dict["$PRAGMAS$"].append( "#pragma HLS INTERFACE s_axilite port=out bundle=control" ) + else: + raise ValueError("Invalid IODMA direction, please set to in or out") self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS DATAFLOW") def execute_node(self, context, graph):