diff --git a/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py b/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py index 23c1779a27c123583c0c8af5f53d022d03e78126..27238fc2a8764ea1fe357ffae4e884429af3e13e 100644 --- a/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py @@ -76,24 +76,30 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp): oshape = self.get_nodeattr("shape") return oshape + def check_divisible_iowidths(self): + impl_style = self.get_nodeattr("impl_style") + if impl_style == "hls": + # when using impl_style = hls must have the following + # if inWidth > outWidth: inWidth % outWidth = 0 + # if inWidth < outWidth: outWidth % inWidth = 0 + iwidth = self.get_nodeattr("inWidth") + owidth = self.get_nodeattr("outWidth") + if iwidth > owidth: + assert ( + iwidth % owidth == 0 + ), """DWC InWidth is bigger than OutWidth and is not divisible by it. + Please adjust PE and SIMD values so that InWidth % OutWidth = 0 + or alternatively use impl_style = vivado""" + else: + assert ( + owidth % iwidth == 0 + ), """DWC OutWidth is bigger than InWidth and is not divisible by it. + Please adjust PE and SIMD values so that OutWidth % InWidth = 0 + or alternatively use impl_style = vivado""" + def get_folded_input_shape(self): - # for correct functionality of the dwc node the - # following must apply: - # if inWidth > outWidth: inWidth % outWidth = 0 - # if inWidth < outWidth: outWidth % inWidth = 0 + self.check_divisible_iowidths() iwidth = self.get_nodeattr("inWidth") - owidth = self.get_nodeattr("outWidth") - if iwidth > owidth: - assert ( - iwidth % owidth == 0 - ), """InWidth is bigger than OutWidth and is not divisible by it. - Please adjust PE and SIMD values so that InWidth % OutWidth = 0""" - else: - assert ( - owidth % iwidth == 0 - ), """OutWidth is bigger than InWidth and is not divisible by it. - Please adjust PE and SIMD values so that OutWidth % InWidth = 0""" - ishape = self.get_normal_input_shape() dummy_t = np.random.randn(*ishape) ibits = self.get_input_datatype().bitwidth() @@ -112,23 +118,8 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp): return dummy_t.shape def get_folded_output_shape(self): - # for correct functionality of the dwc node the - # following must apply: - # if inWidth > outWidth: inWidth % outWidth = 0 - # if inWidth < outWidth: outWidth % inWidth = 0 - iwidth = self.get_nodeattr("inWidth") + self.check_divisible_iowidths() owidth = self.get_nodeattr("outWidth") - if iwidth > owidth: - assert ( - iwidth % owidth == 0 - ), """InWidth is bigger than OutWidth and is not divisible by it. - Please adjust PE and SIMD values so that InWidth % OutWidth = 0""" - else: - assert ( - owidth % iwidth == 0 - ), """OutWidth is bigger than InWidth and is not divisible by it. - Please adjust PE and SIMD values so that OutWidth % InWidth = 0""" - oshape = self.get_normal_output_shape() dummy_t = np.random.randn(*oshape) obits = self.get_output_datatype().bitwidth() @@ -307,14 +298,17 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp): def execute_node(self, context, graph): mode = self.get_nodeattr("exec_mode") + impl_style = self.get_nodeattr("impl_style") node = self.onnx_node exp_shape = self.get_normal_input_shape() folded_ishape = self.get_folded_input_shape() # TODO ensure codegen dir exists if mode == "cppsim": + assert impl_style == "hls", "DWC cppsim only possible when impl_style==hls" code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") elif mode == "rtlsim": + assert impl_style == "hls", "DWC rtlsim only possible when impl_style==hls" code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") else: raise Exception(