From 5e50f6443395937dedb39cc365b5c304a2ad8672 Mon Sep 17 00:00:00 2001
From: Lucian Petrica <lucianp@xilinx.com>
Date: Mon, 19 Oct 2020 09:48:03 +0000
Subject: [PATCH] Modified DWC instantiation to process all input words in a
 single call, optimizing throughput

---
 .../fpgadataflow/streamingdatawidthconverter_batch.py  | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py b/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py
index 748880400..23a906882 100644
--- a/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py
@@ -214,7 +214,7 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
 
     def defines(self, var):
         numReps = 1
-        numInWords = 1
+        numInWords = int(np.prod(self.get_folded_input_shape()[:-1]))
         inWidth = self.get_nodeattr("inWidth")
         outWidth = self.get_nodeattr("outWidth")
         if outWidth > inWidth:
@@ -451,7 +451,6 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
 
     def lut_estimation(self):
         """Calculates resource estimations for LUTs"""
-        impl = self.get_nodeattr("impl_style")
         inw = self.get_instream_width()
         outw = self.get_outstream_width()
 
@@ -461,7 +460,7 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
         # sometimes withs aren't directly divisible
         # this requires going up from input width to least common multiple
         # then down to output width
-        intw = abs(maxw*minw) // math.gcd(maxw, minw)
+        intw = abs(maxw * minw) // math.gcd(maxw, minw)
 
         # we assume a shift-based implementation
         # even if we don't use LUTs explicitly, we make some unavailable
@@ -471,11 +470,10 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
         cset_luts = 0
 
         if inw != intw:
-            cnt_luts += abs(math.ceil(math.log(inw/intw, 2)))
+            cnt_luts += abs(math.ceil(math.log(inw / intw, 2)))
             cset_luts += intw
         if intw != outw:
             cnt_luts += abs(math.ceil(math.log(intw / outw, 2)))
             cset_luts += outw
 
-        return int(cnt_luts+cset_luts)
-
+        return int(cnt_luts + cset_luts)
-- 
GitLab