From 6ea013d52eaf062bc29780662c1f98925853b256 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sun, 26 Apr 2020 23:19:34 +0100
Subject: [PATCH] [StreamingFC] fix STREAM_DEPTH for decoupled mode

---
 .../fpgadataflow/streamingfclayer_batch.py      | 17 +++++------------
 src/finn/custom_op/fpgadataflow/templates.py    |  2 +-
 2 files changed, 6 insertions(+), 13 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index db9fc05bd..7408b119c 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -987,24 +987,17 @@ class StreamingFCLayer_Batch(HLSCustomOp):
             self.code_gen_dict["$LAYER_NAME$"] = [
                 "{}_{}".format(self.onnx_node.name, self.onnx_node.name)
             ]
-            # make instream width a multiple of 8 for axi interface
-            in_width = self.get_instream_width()
-            if in_width % 8 != 0:
-                in_width = math.floor(in_width / 8) + 8
+            # make instream width a multiple of 8 for AXI stream interface
+            in_width = roundup_to_integer_multiple(self.get_instream_width(), 8)
             self.code_gen_dict["$IN_RANGE$"] = ["[{}:0]".format(in_width - 1)]
             self.code_gen_dict["$OUT_RANGE$"] = [
                 "[{}:0]".format(self.get_outstream_width() - 1)
             ]
-            # make weight stream width a multiple of 8 for axi interface
-            weight_width = self.get_weightstream_width()
-            if weight_width % 8 != 0:
-                weight_width = math.floor(weight_width / 8) + 8
+            # make weight stream width a multiple of 8 for AXI stream interface
+            weight_width = roundup_to_integer_multiple(self.get_weightstream_width(), 8)
             self.code_gen_dict["$WEIGHT_RANGE$"] = ["[{}:0]".format(weight_width - 1)]
             self.code_gen_dict["$WEIGHT_WIDTH$"] = [str(weight_width)]
-            mw = self.get_nodeattr("MW")
-            mh = self.get_nodeattr("MH")
-            depth = int(mw * mh)
-            self.code_gen_dict["$WEIGHT_DEPTH$"] = [str(depth)]
+            self.code_gen_dict["$WSTREAM_DEPTH$"] = [str(self.calc_wmem())]
             self.code_gen_dict["$MEM_DEPTH$"] = [
                 str(roundup_to_integer_multiple(self.calc_wmem(), 1024))
             ]
diff --git a/src/finn/custom_op/fpgadataflow/templates.py b/src/finn/custom_op/fpgadataflow/templates.py
index 3e5205d9e..f6114ea02 100644
--- a/src/finn/custom_op/fpgadataflow/templates.py
+++ b/src/finn/custom_op/fpgadataflow/templates.py
@@ -196,7 +196,7 @@ memstream
 .STRM5_WIDTH($WEIGHT_WIDTH$),
 
 //depths per stream
-.STRM0_DEPTH($WEIGHT_DEPTH$),
+.STRM0_DEPTH($WSTREAM_DEPTH$),
 .STRM1_DEPTH(1),
 .STRM2_DEPTH(1),
 .STRM3_DEPTH(1),
-- 
GitLab