diff --git a/src/finn/custom_op/fpgadataflow/lookup.py b/src/finn/custom_op/fpgadataflow/lookup.py index cd50f0bb605c185c00c218e3ec8bc911d1b85309..d34209ce6c39286b33e27d5a9ce714328596ac30 100644 --- a/src/finn/custom_op/fpgadataflow/lookup.py +++ b/src/finn/custom_op/fpgadataflow/lookup.py @@ -29,7 +29,7 @@ import numpy as np import os import warnings -from math import ceil +from math import ceil, log2 from finn.core.datatype import DataType from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp @@ -58,6 +58,13 @@ class Lookup(HLSCustomOp): "InputType": ("s", True, ""), # Input shape "InputShape": ("ints", False, [1]), + # Memory mode + # const : parameters baked into bitfile (BRAM) + # external : lookup performed in external memory over AXI MM + "mem_mode": ("s", False, "const", ["const", "external"]), + # Width for AXI-MM interface + # only relevant when mem_mode="external" + "ext_mem_width": ("i", False, 32), } my_attrs.update(super().get_nodeattr_types()) return my_attrs @@ -72,7 +79,8 @@ class Lookup(HLSCustomOp): def get_normal_output_shape(self): ishape = self.get_normal_input_shape() - oshape = list(ishape) + [self.get_nodeattr("EmbeddingDim")] + emb_dim = self.get_nodeattr("EmbeddingDim") + oshape = list(ishape) + [emb_dim] return tuple(oshape) def get_folded_input_shape(self): @@ -81,7 +89,23 @@ class Lookup(HLSCustomOp): return tuple(folded_ishape) def get_folded_output_shape(self): - return self.get_normal_output_shape() + ishape = self.get_normal_input_shape() + mem_mode = self.get_nodeattr("mem_mode") + emb_dim = self.get_nodeattr("EmbeddingDim") + if mem_mode == "const": + oshape = list(ishape) + [emb_dim] + elif mem_mode == "external": + ext_mem_width = self.get_nodeattr("ext_mem_width") + bits_per_emb_elem = self.get_output_datatype().bitwidth() + assert ext_mem_width % bits_per_emb_elem == 0 + emb_elems_per_ext_mem_width = ext_mem_width // bits_per_emb_elem + oshape = list(ishape) + [ + emb_dim // emb_elems_per_ext_mem_width, + emb_elems_per_ext_mem_width, + ] + else: + raise Exception("Unrecognized mem_mode:" + mem_mode) + return tuple(oshape) def make_shape_compatible_op(self, model): exp_ishape = tuple(self.get_normal_input_shape()) @@ -123,17 +147,20 @@ class Lookup(HLSCustomOp): return ibits def get_outstream_width(self): + folded_oshape = self.get_folded_output_shape() obits = self.get_output_datatype().bitwidth() - ofm_ch = self.get_nodeattr("EmbeddingDim") - return obits * ofm_ch + return obits * folded_oshape[-1] def get_number_output_values(self): folded_oshape = self.get_folded_output_shape() return np.prod(folded_oshape[:-1]) def global_includes(self): - global_incls = ['#include "lookup.hpp"'] - global_incls.append('#include "embeddings.hpp"') + mem_mode = self.get_nodeattr("mem_mode") + global_incls = [] + if mem_mode == "const": + global_incls.append('#include "lookup.hpp"') + global_incls.append('#include "embeddings.hpp"') self.code_gen_dict["$GLOBALS$"] = global_incls def defines(self, var): @@ -142,14 +169,26 @@ class Lookup(HLSCustomOp): elem_hls_type = dtype.get_hls_datatype_str() emb_type = DataType[self.get_nodeattr("EmbeddingType")] emb_hls_type = emb_type.get_hls_datatype_str() + emb_dim = self.get_nodeattr("EmbeddingDim") + mem_mode = self.get_nodeattr("mem_mode") my_defines = [] - my_defines.append( - "#define NumEmbeddings %d" % self.get_nodeattr("NumEmbeddings") - ) - my_defines.append("#define EmbeddingDim %d" % self.get_nodeattr("EmbeddingDim")) my_defines.append("#define NumInputs %d" % n_inputs) - my_defines.append("#define InputType %s" % elem_hls_type) - my_defines.append("#define EmbeddingType %s" % emb_hls_type) + if mem_mode == "external": + ext_mem_width = self.get_nodeattr("ext_mem_width") + ext_mem_emb_size = self.get_folded_output_shape()[-2] + ext_mem_emb_align = ceil(log2(ext_mem_emb_size)) + my_defines.append("#define MemBits %d" % ext_mem_width) + my_defines.append("#define EmbeddingSize %d" % ext_mem_emb_size) + my_defines.append("#define EmbeddingAlign %d" % ext_mem_emb_align) + my_defines.append("#define T_SRC %s" % elem_hls_type) + my_defines.append("#define T_DST ap_uint<MemBits>") + elif mem_mode == "const": + my_defines.append( + "#define NumEmbeddings %d" % self.get_nodeattr("NumEmbeddings") + ) + my_defines.append("#define EmbeddingDim %d" % emb_dim) + my_defines.append("#define InputType %s" % elem_hls_type) + my_defines.append("#define EmbeddingType %s" % emb_hls_type) self.code_gen_dict["$DEFINES$"] = my_defines def read_npy_data(self): @@ -211,22 +250,46 @@ class Lookup(HLSCustomOp): ) def docompute(self): - self.code_gen_dict["$DOCOMPUTE$"] = [ - """StreamingLookup<NumEmbeddings, EmbeddingDim, NumInputs, - InputType, EmbeddingType >(in0, out, embeddings);""" - ] + mem_mode = self.get_nodeattr("mem_mode") + if mem_mode == "const": + self.code_gen_dict["$DOCOMPUTE$"] = [ + """StreamingLookup<NumEmbeddings, EmbeddingDim, NumInputs, + InputType, EmbeddingType >(in0, out, embeddings);""" + ] + elif mem_mode == "external": + hls_impl = """ + for(unsigned i = 0; i < NumInputs; i++) { + ap_uint<T_SRC::width+EmbeddingAlign> const base = + (in0.read(), ap_uint<EmbeddingAlign>(0)); + for(unsigned j = 0; j < EmbeddingSize; j++) { +#pragma HLS PIPELINE II=1 + out.write(mem[base+j]); + } + } + """ + self.code_gen_dict["$DOCOMPUTE$"] = [hls_impl] def blackboxfunction(self): + mem_mode = self.get_nodeattr("mem_mode") ibits = self.get_instream_width() packed_input_hls_type = "ap_uint<%d>" % ibits obits = self.get_outstream_width() packed_output_hls_type = "ap_uint<%d>" % obits - self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ - "void %s(hls::stream<%s > &in0, hls::stream<%s > &out)" - % (self.onnx_node.name, packed_input_hls_type, packed_output_hls_type) - ] + if mem_mode == "const": + self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ + "void %s(hls::stream<%s > &in0, hls::stream<%s > &out)" + % (self.onnx_node.name, packed_input_hls_type, packed_output_hls_type) + ] + elif mem_mode == "external": + self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ + "void " + + self.onnx_node.name + + "(hls::stream<T_SRC> &in0, hls::stream<T_DST> &out, " + + "T_DST const *const mem)" + ] def pragmas(self): + mem_mode = self.get_nodeattr("mem_mode") my_pragmas = [ "#pragma HLS INTERFACE axis port=in0 name=in0_" + self.hls_sname() ] @@ -234,30 +297,38 @@ class Lookup(HLSCustomOp): "#pragma HLS INTERFACE axis port=out name=out_" + self.hls_sname() ) my_pragmas.append("#pragma HLS INTERFACE ap_ctrl_none port=return") - my_pragmas.append( - "#pragma HLS BIND_STORAGE variable=embeddings type=ROM_2P impl=BRAM" - ) + if mem_mode == "const": + my_pragmas.append( + "#pragma HLS BIND_STORAGE variable=embeddings type=ROM_2P impl=BRAM" + ) + elif mem_mode == "external": + my_pragmas.append("#pragma HLS INTERFACE m_axi offset=slave port=mem") + my_pragmas.append("#pragma HLS INTERFACE s_axilite port=mem bundle=control") + else: + raise Exception("Unrecognized mem_mode: " + mem_mode) self.code_gen_dict["$PRAGMAS$"] = my_pragmas def generate_params(self, model, path): - code_gen_dir = path - embeddings = model.get_initializer(self.onnx_node.input[1]) - weight_filename = "{}/embeddings.hpp".format(code_gen_dir) - edt = DataType[self.get_nodeattr("EmbeddingType")] - # obits = self.get_outstream_width() - # packed_output_hls_type = "ap_uint<%d>" % obits - assert np.vectorize(edt.allowed)( - embeddings - ).all(), "Embeddings can't be expressed with type %s" % str(edt) - # reverse innertmost dim in embeddings to remain compatible with - # how we normally encode the data in FINN - embeddings_rev = np.flip(embeddings, -1) - embeddings_hls_code = numpy_to_hls_code( - embeddings_rev, edt, "embeddings", True, False - ) - f_thresh = open(weight_filename, "w") - f_thresh.write(embeddings_hls_code) - f_thresh.close() + mem_mode = self.get_nodeattr("mem_mode") + if mem_mode == "const": + code_gen_dir = path + embeddings = model.get_initializer(self.onnx_node.input[1]) + weight_filename = "{}/embeddings.hpp".format(code_gen_dir) + edt = DataType[self.get_nodeattr("EmbeddingType")] + # obits = self.get_outstream_width() + # packed_output_hls_type = "ap_uint<%d>" % obits + assert np.vectorize(edt.allowed)( + embeddings + ).all(), "Embeddings can't be expressed with type %s" % str(edt) + # reverse innertmost dim in embeddings to remain compatible with + # how we normally encode the data in FINN + embeddings_rev = np.flip(embeddings, -1) + embeddings_hls_code = numpy_to_hls_code( + embeddings_rev, edt, "embeddings", True, False + ) + f_thresh = open(weight_filename, "w") + f_thresh.write(embeddings_hls_code) + f_thresh.close() def execute_node(self, context, graph): mode = self.get_nodeattr("exec_mode") @@ -266,6 +337,10 @@ class Lookup(HLSCustomOp): exp_oshape = tuple(self.get_normal_output_shape()) folded_ishape = tuple(self.get_folded_input_shape()) folded_oshape = tuple(self.get_folded_output_shape()) + mem_mode = self.get_nodeattr("mem_mode") + assert ( + mem_mode == "const" + ), "Only mem_mode=const is supported for simulation of Lookup layer" if mode == "cppsim": code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") @@ -335,10 +410,16 @@ class Lookup(HLSCustomOp): ), """Output shape doesn't match expected shape.""" def bram_estimation(self): - # current calculation assumes embeddings always stored in BRAM_18Ks - width_factor = ceil(self.get_outstream_width() / 16) - depth_factor = ceil(self.get_nodeattr("NumEmbeddings") / 1024) - return width_factor * depth_factor + mem_mode = self.get_nodeattr("mem_mode") + if mem_mode == "const": + # current calculation assumes embeddings always stored in BRAM_18Ks + # when mem_mode is const + width_factor = ceil(self.get_outstream_width() / 16) + depth_factor = ceil(self.get_nodeattr("NumEmbeddings") / 1024) + return width_factor * depth_factor + else: + # TODO can we estimate BRAMs for the DMA engine? + return 0 def bram_efficiency_estimation(self): bram16_est = self.bram_estimation() @@ -347,3 +428,12 @@ class Lookup(HLSCustomOp): ebits = self.get_outstream_width() * self.get_nodeattr("NumEmbeddings") bram16_est_capacity = bram16_est * 18 * 1024 return ebits / bram16_est_capacity + + def get_ap_int_max_w(self): + parent_max = super().get_ap_int_max_w() + mem_mode = self.get_nodeattr("mem_mode") + ext_mem_width = self.get_nodeattr("ext_mem_width") + if mem_mode == "external": + return max(ext_mem_width, parent_max) + else: + return parent_max