Skip to content
Snippets Groups Projects
Commit 981f64d3 authored by auphelia's avatar auphelia
Browse files

[StreamingDWC] Raise exceptions when setting of InWidth and OutWidth forbidden

parent 23ea169c
No related branches found
No related tags found
No related merge requests found
......@@ -72,9 +72,25 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
return oshape
def get_folded_input_shape(self):
# for correct functionality of the dwc node the
# following must apply:
# if inWidth > outWidth: inWidth % outWidth = 0
# if inWidth < outWidth: outWidth % inWidth = 0
iwidth = self.get_nodeattr("inWidth")
owidth = self.get_nodeattr("outWidth")
if iwidth > owidth:
assert (
iwidth % owidth == 0
), """InWidth is bigger than OutWidth and is not divisible by it.
Please adjust PE and SIMD values so that InWidth % OutWidth = 0"""
else:
assert (
owidth % iwidth == 0
), """OutWidth is bigger than InWidth and is not divisible by it.
Please adjust PE and SIMD values so that OutWidth % InWidth = 0"""
ishape = self.get_normal_input_shape()
dummy_t = np.random.randn(*ishape)
iwidth = self.get_nodeattr("inWidth")
ibits = self.get_input_datatype().bitwidth()
assert (
iwidth % ibits == 0
......@@ -91,9 +107,25 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
return dummy_t.shape
def get_folded_output_shape(self):
# for correct functionality of the dwc node the
# following must apply:
# if inWidth > outWidth: inWidth % outWidth = 0
# if inWidth < outWidth: outWidth % inWidth = 0
iwidth = self.get_nodeattr("inWidth")
owidth = self.get_nodeattr("outWidth")
if iwidth > owidth:
assert (
iwidth % owidth == 0
), """InWidth is bigger than OutWidth and is not divisible by it.
Please adjust PE and SIMD values so that InWidth % OutWidth = 0"""
else:
assert (
owidth % iwidth == 0
), """OutWidth is bigger than InWidth and is not divisible by it.
Please adjust PE and SIMD values so that OutWidth % InWidth = 0"""
oshape = self.get_normal_output_shape()
dummy_t = np.random.randn(*oshape)
owidth = self.get_nodeattr("outWidth")
obits = self.get_output_datatype().bitwidth()
assert (
owidth % obits == 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment