From c0d9d0a10305703cd823acd28d11e4764c39ce52 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Mon, 23 Mar 2020 16:38:49 +0000
Subject: [PATCH] [StreamingFC] Change function template of streaming MVAU and
 make template parameters for decoupled mode equal to embedded (const) mode

---
 .../fpgadataflow/streamingfclayer_batch.py    | 52 +++++++------------
 1 file changed, 20 insertions(+), 32 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index 983b1b8fa..3cff5a6af 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -278,7 +278,6 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         ret = dict()
         inp_hls_str = self.get_input_datatype().get_hls_datatype_str()
         out_hls_str = self.get_output_datatype().get_hls_datatype_str()
-        wt_hls_str = self.get_weight_datatype().get_hls_datatype_str()
         inp_is_binary = self.get_input_datatype() == DataType.BINARY
         # out_is_binary = self.get_output_datatype() == DataType.BINARY
         wt_is_binary = self.get_weight_datatype() == DataType.BINARY
@@ -294,36 +293,18 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         # fill in TSrcI and TWeightI
         # TODO check these with Giulio
         # TODO handle non-bipolar binary inputs
-        mem_mode = self.get_nodeattr("mem_mode")
-        if mem_mode == "const":
-            if inp_is_bipolar and wt_is_bipolar:
-                ret["TSrcI"] = "Recast<XnorMul>"
-                ret["TWeightI"] = "Identity"
-            elif (not inp_is_bipolar) and wt_is_bipolar:
-                ret["TSrcI"] = "Slice<%s>" % inp_hls_str
-                ret["TWeightI"] = "Recast<Binary>"
-            elif inp_is_bipolar and (not wt_is_bipolar):
-                ret["TSrcI"] = "Recast<Binary>"
-                ret["TWeightI"] = "Identity"
-            elif (not inp_is_bipolar) and (not wt_is_bipolar):
-                ret["TSrcI"] = "Slice<%s>" % inp_hls_str
-                ret["TWeightI"] = "Identity"
-
-        elif mem_mode == "decoupled":
-            if inp_is_bipolar and wt_is_bipolar:
-                ret["TSrcI"] = "Recast<XnorMul>"
-                ret["TWeightI"] = "Identity"
-            elif (not inp_is_bipolar) and wt_is_bipolar:
-                ret["TSrcI"] = "Slice<%s>" % inp_hls_str
-                ret["TWeightI"] = "Recast<Binary>"
-                # ret["TWeightI"] = "Recast<Binary>"
-            elif inp_is_bipolar and (not wt_is_bipolar):
-                ret["TSrcI"] = "Recast<Binary>"
-                ret["TWeightI"] = "Slice<%s>" % wt_hls_str
-                # ret["TWeightI"] = "Slice<%s>" % wt_hls_str
-            elif (not inp_is_bipolar) and (not wt_is_bipolar):
-                ret["TSrcI"] = "Slice<%s>" % inp_hls_str
-                ret["TWeightI"] = "Slice<%s>" % wt_hls_str
+        if inp_is_bipolar and wt_is_bipolar:
+            ret["TSrcI"] = "Recast<XnorMul>"
+            ret["TWeightI"] = "Identity"
+        elif (not inp_is_bipolar) and wt_is_bipolar:
+            ret["TSrcI"] = "Slice<%s>" % inp_hls_str
+            ret["TWeightI"] = "Recast<Binary>"
+        elif inp_is_bipolar and (not wt_is_bipolar):
+            ret["TSrcI"] = "Recast<Binary>"
+            ret["TWeightI"] = "Identity"
+        elif (not inp_is_bipolar) and (not wt_is_bipolar):
+            ret["TSrcI"] = "Slice<%s>" % inp_hls_str
+            ret["TWeightI"] = "Identity"
 
         # fill in TDstI
         ret["TDstI"] = "Slice<%s>" % out_hls_str
@@ -761,12 +742,19 @@ class StreamingFCLayer_Batch(HLSCustomOp):
                 )
             ]
         elif mem_mode == "decoupled":
+            wdt = self.get_weight_datatype()
+            if wdt == DataType.BIPOLAR:
+                export_wdt = DataType.BINARY
+            else:
+                export_wdt = wdt
+            wdtype_hls_str = export_wdt.get_hls_datatype_str()
             self.code_gen_dict["$DOCOMPUTE$"] = [
-                """Matrix_Vector_Activate_Stream_Batch<MW1, MH1, SIMD1, PE1, WP1, {}, {}, {}>
+                """Matrix_Vector_Activate_Stream_Batch<MW1, MH1, SIMD1, PE1, {}, {}, {}, {} >
                 (in0, out, weights, {}, numReps, {});""".format(
                     tmpl_args["TSrcI"],
                     tmpl_args["TDstI"],
                     tmpl_args["TWeightI"],
+                    wdtype_hls_str,
                     threshs,
                     self.get_nodeattr("resType"),
                 )
-- 
GitLab