From d2b363aa9c7ee2e562af6ccfbd4521c63c4918a6 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Thu, 24 Mar 2022 21:21:52 +0100 Subject: [PATCH] [Lookup] support genrating .dat file for external mem lookup --- src/finn/custom_op/fpgadataflow/lookup.py | 35 ++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/src/finn/custom_op/fpgadataflow/lookup.py b/src/finn/custom_op/fpgadataflow/lookup.py index 1a2945ea9..1a4c917fd 100644 --- a/src/finn/custom_op/fpgadataflow/lookup.py +++ b/src/finn/custom_op/fpgadataflow/lookup.py @@ -33,9 +33,11 @@ from math import ceil, log2 from finn.core.datatype import DataType from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp +from finn.util.basic import roundup_to_integer_multiple from finn.util.data_packing import ( npy_to_rtlsim_input, numpy_to_hls_code, + pack_innermost_dim_as_hex_string, rtlsim_output_to_npy, ) @@ -310,9 +312,9 @@ class Lookup(HLSCustomOp): def generate_params(self, model, path): mem_mode = self.get_nodeattr("mem_mode") + embeddings = model.get_initializer(self.onnx_node.input[1]) if mem_mode == "const": code_gen_dir = path - embeddings = model.get_initializer(self.onnx_node.input[1]) weight_filename = "{}/embeddings.hpp".format(code_gen_dir) edt = DataType[self.get_nodeattr("EmbeddingType")] # obits = self.get_outstream_width() @@ -329,6 +331,37 @@ class Lookup(HLSCustomOp): f_thresh = open(weight_filename, "w") f_thresh.write(embeddings_hls_code) f_thresh.close() + elif mem_mode == "external": + edt = DataType[self.get_nodeattr("EmbeddingType")] + ext_mem_width = self.get_nodeattr("ext_mem_width") + assert edt.bitwidth() == 8, ( + "Lookup with mem_mode=external " + + "only works with 8-bit embeddings but found " + + str(edt) + ) + emb_dim = self.get_nodeattr("EmbeddingDim") + # need to zero-pad embeddings in external mode for burst alignment + # compute how much padding we need + emb_elems_per_ext_mem_width = self.get_folded_output_shape()[-1] + ext_mem_emb_size = self.get_folded_output_shape()[-2] + ext_mem_emb_align = ceil(log2(ext_mem_emb_size)) + align_factor = 2**ext_mem_emb_align + aligned_emb_dim = roundup_to_integer_multiple(emb_dim, align_factor) + pad_amount = aligned_emb_dim - emb_dim + embeddings_padded = np.pad(embeddings, [(0, 0), (0, pad_amount)]) + # reshape for packing the innermost dim + embeddings_padded = embeddings_padded.reshape( + -1, emb_elems_per_ext_mem_width + ) + weight_filename = "%s/%s.dat" % (path, self.onnx_node.name) + ret = pack_innermost_dim_as_hex_string( + embeddings_padded, edt, ext_mem_width, True, prefix="" + ) + with open(weight_filename, "w") as f: + for current_line in ret: + f.write(current_line + "\n") + else: + raise Exception("Unrecognized mem_mode: " + mem_mode) def execute_node(self, context, graph): mode = self.get_nodeattr("exec_mode") -- GitLab