diff --git a/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py b/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py index 97be7c8c3783f86f47b7aefca646838c582b95c6..a51399a29996f7f9f19e699935179253e020ebfc 100644 --- a/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py @@ -72,9 +72,25 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp): return oshape def get_folded_input_shape(self): + # for correct functionality of the dwc node the + # following must apply: + # if inWidth > outWidth: inWidth % outWidth = 0 + # if inWidth < outWidth: outWidth % inWidth = 0 + iwidth = self.get_nodeattr("inWidth") + owidth = self.get_nodeattr("outWidth") + if iwidth > owidth: + assert ( + iwidth % owidth == 0 + ), """InWidth is bigger than OutWidth and is not divisible by it. + Please adjust PE and SIMD values so that InWidth % OutWidth = 0""" + else: + assert ( + owidth % iwidth == 0 + ), """OutWidth is bigger than InWidth and is not divisible by it. + Please adjust PE and SIMD values so that OutWidth % InWidth = 0""" + ishape = self.get_normal_input_shape() dummy_t = np.random.randn(*ishape) - iwidth = self.get_nodeattr("inWidth") ibits = self.get_input_datatype().bitwidth() assert ( iwidth % ibits == 0 @@ -91,9 +107,25 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp): return dummy_t.shape def get_folded_output_shape(self): + # for correct functionality of the dwc node the + # following must apply: + # if inWidth > outWidth: inWidth % outWidth = 0 + # if inWidth < outWidth: outWidth % inWidth = 0 + iwidth = self.get_nodeattr("inWidth") + owidth = self.get_nodeattr("outWidth") + if iwidth > owidth: + assert ( + iwidth % owidth == 0 + ), """InWidth is bigger than OutWidth and is not divisible by it. + Please adjust PE and SIMD values so that InWidth % OutWidth = 0""" + else: + assert ( + owidth % iwidth == 0 + ), """OutWidth is bigger than InWidth and is not divisible by it. + Please adjust PE and SIMD values so that OutWidth % InWidth = 0""" + oshape = self.get_normal_output_shape() dummy_t = np.random.randn(*oshape) - owidth = self.get_nodeattr("outWidth") obits = self.get_output_datatype().bitwidth() assert ( owidth % obits == 0