From 981f64d3cd19d7e08376b7b820780a8df273e68a Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Fri, 3 Apr 2020 18:55:06 +0100
Subject: [PATCH] [StreamingDWC] Raise exceptions when setting of InWidth and
 OutWidth forbidden

---
 .../streamingdatawidthconverter_batch.py      | 36 +++++++++++++++++--
 1 file changed, 34 insertions(+), 2 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py b/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py
index 97be7c8c3..a51399a29 100644
--- a/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py
@@ -72,9 +72,25 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
         return oshape
 
     def get_folded_input_shape(self):
+        # for correct functionality of the dwc node the
+        # following must apply:
+        # if inWidth > outWidth: inWidth % outWidth = 0
+        # if inWidth < outWidth: outWidth % inWidth = 0
+        iwidth = self.get_nodeattr("inWidth")
+        owidth = self.get_nodeattr("outWidth")
+        if iwidth > owidth:
+            assert (
+                iwidth % owidth == 0
+            ), """InWidth is bigger than OutWidth and is not divisible by it.
+            Please adjust PE and SIMD values so that InWidth % OutWidth = 0"""
+        else:
+            assert (
+                owidth % iwidth == 0
+            ), """OutWidth is bigger than InWidth and is not divisible by it.
+            Please adjust PE and SIMD values so that OutWidth % InWidth = 0"""
+
         ishape = self.get_normal_input_shape()
         dummy_t = np.random.randn(*ishape)
-        iwidth = self.get_nodeattr("inWidth")
         ibits = self.get_input_datatype().bitwidth()
         assert (
             iwidth % ibits == 0
@@ -91,9 +107,25 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
         return dummy_t.shape
 
     def get_folded_output_shape(self):
+        # for correct functionality of the dwc node the
+        # following must apply:
+        # if inWidth > outWidth: inWidth % outWidth = 0
+        # if inWidth < outWidth: outWidth % inWidth = 0
+        iwidth = self.get_nodeattr("inWidth")
+        owidth = self.get_nodeattr("outWidth")
+        if iwidth > owidth:
+            assert (
+                iwidth % owidth == 0
+            ), """InWidth is bigger than OutWidth and is not divisible by it.
+            Please adjust PE and SIMD values so that InWidth % OutWidth = 0"""
+        else:
+            assert (
+                owidth % iwidth == 0
+            ), """OutWidth is bigger than InWidth and is not divisible by it.
+            Please adjust PE and SIMD values so that OutWidth % InWidth = 0"""
+
         oshape = self.get_normal_output_shape()
         dummy_t = np.random.randn(*oshape)
-        owidth = self.get_nodeattr("outWidth")
         obits = self.get_output_datatype().bitwidth()
         assert (
             owidth % obits == 0
-- 
GitLab