From ce4d70c9cb0bf982e4dbbda61fb522be8a3f415b Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Mon, 10 Feb 2020 22:43:29 +0100
Subject: [PATCH] [StreamingFC] fix pragmas for array partitioning

---
 .../fpgadataflow/streamingfclayer_batch.py    | 23 +++++++++++++------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index a42397414..33a5d963d 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -628,18 +628,27 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         self.code_gen_dict["$PRAGMAS$"].append(
             "#pragma HLS INTERFACE ap_ctrl_none port=return"
         )
+        # the weight tensor is ap_uint<simd*prec> [PE][WMEM]
+        # partition for parallel access along the PE dimension (dim 1)
         self.code_gen_dict["$PRAGMAS$"].append(
-            "DO_PRAGMA(HLS ARRAY_PARTITION variable=weights complete dim=1)"
-        )
-
-        self.code_gen_dict["$PRAGMAS$"].append(
-            "DO_PRAGMA(HLS ARRAY_PARTITION variable=weights complete dim=2)"
+            (
+                "DO_PRAGMA(HLS ARRAY_PARTITION "
+                "variable=weights.m_weights complete dim=1)"
+            )
         )
+        # the threshold tensor is acc_type [PE][TMEM][N_THRES]
+        # partition for parallel access along PE and N_THRES dimensions (dims 1 and 3)
         if self.calc_tmem() != 0:
             # TODO find a better way of checking for no pregenerated thresholds
             self.code_gen_dict["$PRAGMAS$"].append(
-                "DO_PRAGMA(HLS ARRAY_PARTITION variable=threshs complete dim=1)"
+                (
+                    "DO_PRAGMA(HLS ARRAY_PARTITION variable=threshs.m_thresholds "
+                    "complete dim=1)"
+                )
             )
             self.code_gen_dict["$PRAGMAS$"].append(
-                "DO_PRAGMA(HLS ARRAY_PARTITION variable=threshs complete dim=3)"
+                (
+                    "DO_PRAGMA(HLS ARRAY_PARTITION variable=threshs.m_thresholds "
+                    "complete dim=3)"
+                )
             )
-- 
GitLab