From a3efdf209657979da289092d2d32bdad691830f3 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Tue, 31 Mar 2020 22:12:11 +0100 Subject: [PATCH] [StreamingFC] rewrite the .dat generation logic for wt streamers --- .../fpgadataflow/streamingfclayer_batch.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py index 23585e178..00b8287a3 100644 --- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py @@ -498,27 +498,26 @@ class StreamingFCLayer_Batch(HLSCustomOp): np.save(os.path.join(code_gen_dir, "weights.npy"), weight_tensor_flipped) """Saves weights into .dat file""" - # convert weight value sinto hexstring + # convert weight values into hexstring weight_width = self.get_weightstream_width() weight_tensor_unflipped = pack_innermost_dim_as_hex_string( - weight_tensor_unflipped, export_wdt, weight_width + weight_tensor_unflipped, export_wdt, weight_width, prefix="" ) weight_stream_len = np.prod(weight_tensor_unflipped.shape) assert ( weight_stream_len <= 1024 ), """Decoupled mem mode needs weight stream length <= 1024 for now""" - weight_pad = np.zeros((1024), int).astype(str) - weight_tensor_unflipped = weight_tensor_unflipped.flatten() - # delete "0x" in the beginning of the hexstring - for i in range(len(weight_tensor_unflipped)): - weight_tensor_unflipped[i] = weight_tensor_unflipped[i][2:] - weight_pad[: weight_tensor_unflipped.shape[0]] = weight_tensor_unflipped - weight_pad = weight_pad.copy() - f = open("{}/memblock_0.dat".format(code_gen_dir), "w+") - for val in weight_pad: - f.write(val + "\n") - f.close() + # add zeroes to pad out file to 1024 entries + weight_stream = weight_tensor_unflipped.flatten() + pad_amt = 1024 - weight_stream_len + weight_stream = np.pad( + weight_stream, (0, pad_amt), mode="constant", constant_values="0" + ) + weight_stream = weight_stream.copy() + with open("{}/memblock_0.dat".format(code_gen_dir), "w+") as f: + for val in weight_stream: + f.write(val + "\n") else: raise Exception( -- GitLab