From a3efdf209657979da289092d2d32bdad691830f3 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 31 Mar 2020 22:12:11 +0100
Subject: [PATCH] [StreamingFC] rewrite the .dat generation logic for wt
 streamers

---
 .../fpgadataflow/streamingfclayer_batch.py    | 25 +++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index 23585e178..00b8287a3 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -498,27 +498,26 @@ class StreamingFCLayer_Batch(HLSCustomOp):
             np.save(os.path.join(code_gen_dir, "weights.npy"), weight_tensor_flipped)
 
             """Saves weights into .dat file"""
-            # convert weight value sinto hexstring
+            # convert weight values into hexstring
             weight_width = self.get_weightstream_width()
             weight_tensor_unflipped = pack_innermost_dim_as_hex_string(
-                weight_tensor_unflipped, export_wdt, weight_width
+                weight_tensor_unflipped, export_wdt, weight_width, prefix=""
             )
             weight_stream_len = np.prod(weight_tensor_unflipped.shape)
             assert (
                 weight_stream_len <= 1024
             ), """Decoupled mem mode needs
             weight stream length <= 1024 for now"""
-            weight_pad = np.zeros((1024), int).astype(str)
-            weight_tensor_unflipped = weight_tensor_unflipped.flatten()
-            # delete "0x" in the beginning of the hexstring
-            for i in range(len(weight_tensor_unflipped)):
-                weight_tensor_unflipped[i] = weight_tensor_unflipped[i][2:]
-            weight_pad[: weight_tensor_unflipped.shape[0]] = weight_tensor_unflipped
-            weight_pad = weight_pad.copy()
-            f = open("{}/memblock_0.dat".format(code_gen_dir), "w+")
-            for val in weight_pad:
-                f.write(val + "\n")
-            f.close()
+            # add zeroes to pad out file to 1024 entries
+            weight_stream = weight_tensor_unflipped.flatten()
+            pad_amt = 1024 - weight_stream_len
+            weight_stream = np.pad(
+                weight_stream, (0, pad_amt), mode="constant", constant_values="0"
+            )
+            weight_stream = weight_stream.copy()
+            with open("{}/memblock_0.dat".format(code_gen_dir), "w+") as f:
+                for val in weight_stream:
+                    f.write(val + "\n")
 
         else:
             raise Exception(
-- 
GitLab