diff --git a/custom_hls/lookup.hpp b/custom_hls/lookup.hpp
index 3001f6613ec6ed9a9e5f47d9be356d4b032f7192..037b038a09a10ff2bd066740d20f0b47489e24e4 100644
--- a/custom_hls/lookup.hpp
+++ b/custom_hls/lookup.hpp
@@ -26,14 +26,15 @@
  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
- *******************************************************************************/
+*******************************************************************************/
+#ifndef LOOKUP_HPP
+#define LOOKUP_HPP
 
 #include <ap_int.h>
 #include <hls_stream.h>
 
-#ifndef LOOKUP_HPP
-#define LOOKUP_HPP
+#include "utils.hpp"
+
 
 template <
 	unsigned NumEmbeddings,
@@ -57,4 +58,62 @@ void StreamingLookup(
 	}
 }
 
+/**
+ * Lookup implementation over a table stored in AXI-accessible memory.
+ *
+ * Each call reads at most one index from `in0`. If one was available, the
+ * EmbeddingSize consecutive words of `mem` starting at word offset
+ * (index << EmbeddingAlign) are written to `out`.
+ *
+ * An index >= `size` is out of bounds (OOB): it is redirected to offset zero
+ * and counted. The running OOB count is written to `oob_count`, and `oob_irq`
+ * is set while that count is non-zero. Whenever the value read from
+ * `oob_count` differs from the last one recorded, the recorded value is
+ * subtracted from the running count and the new value is recorded.
+ */
+template <
+	unsigned EmbeddingSize, // Number of memory words per embedding
+	unsigned EmbeddingAlign = clog2(EmbeddingSize), // Alignment of entries = number of word index bits
+	typename T_SRC,
+	typename T_DST
+>
+void StreamingLookup_ext(
+	hls::stream<T_SRC> &in0,
+	hls::stream<T_DST> &out,
+	T_DST const *const mem,
+	unsigned const size,
+	unsigned &oob_count,
+	bool &oob_irq
+) {
+#pragma HLS pipeline II=EmbeddingSize+9 style=flp
+
+	static unsigned oob_count_li;  // last value recorded from oob_count
+	static unsigned oob_count_int; // running out-of-bounds count
+#pragma HLS reset variable=oob_count_li
+#pragma HLS reset variable=oob_count_int
+
+	if(oob_count != oob_count_li) {
+		oob_count_int -= oob_count_li;
+		oob_count_li = oob_count;
+	}
+	if(!in0.empty()) {
+		T_SRC const x = in0.read();
+
+		// Map out-of-bounds inputs to an offset of zero and increment counter.
+		// Compare against the full-width `size`: narrowing it to T_SRC would
+		// wrap (e.g. to zero for size == 2^T_SRC::width) and misclassify inputs.
+		bool const oob = x >= size;
+		ap_uint<T_SRC::width+EmbeddingAlign> const ofs =
+			((oob? T_SRC(0) : x), ap_uint<EmbeddingAlign>(0));
+		oob_count_int += oob;
+
+		// Stream lookup data (burst inferred)
+		for(unsigned i = 0; i < EmbeddingSize; i++) {
+#pragma HLS pipeline II=1 style=flp
+			out.write(mem[ofs+i]);
+		}
+	}
+	oob_count = oob_count_int;
+	oob_irq = (oob_count_int != 0);
+}
 #endif
diff --git a/src/finn/custom_op/fpgadataflow/hlscustomop.py b/src/finn/custom_op/fpgadataflow/hlscustomop.py
index b202e95a28a26de3dabc098c2030bafcf840d164..c5041acd46a63880160f7726946e1c609642710d 100644
--- a/src/finn/custom_op/fpgadataflow/hlscustomop.py
+++ b/src/finn/custom_op/fpgadataflow/hlscustomop.py
@@ -138,6 +138,7 @@ class HLSCustomOp(CustomOp):
         intf_names["m_axis"] = [("out_" + sname, self.get_outstream_width_padded())]
         intf_names["aximm"] = []
         intf_names["axilite"] = []
+        intf_names["ap_none"] = []
         return intf_names
 
     def get_verilog_top_filename(self):
diff --git a/src/finn/custom_op/fpgadataflow/lookup.py b/src/finn/custom_op/fpgadataflow/lookup.py
index d90fa0f05ab2a92391f610ae1c4516a95a881ce4..613a91b6284e0789dff2446e1615690a03336d99 100644
--- a/src/finn/custom_op/fpgadataflow/lookup.py
+++ b/src/finn/custom_op/fpgadataflow/lookup.py
@@ -159,8 +159,8 @@ class Lookup(HLSCustomOp):
     def global_includes(self):
         mem_mode = self.get_nodeattr("mem_mode")
         global_incls = []
+        global_incls.append('#include "lookup.hpp"')
         if mem_mode == "const":
-            global_incls.append('#include "lookup.hpp"')
             global_incls.append('#include "embeddings.hpp"')
         self.code_gen_dict["$GLOBALS$"] = global_incls
 
@@ -258,17 +258,10 @@ class Lookup(HLSCustomOp):
                 InputType, EmbeddingType >(in0, out, embeddings);"""
             ]
         elif mem_mode == "external":
-            hls_impl = """
-    if(!in0.empty()) {
-        ap_uint<T_SRC::width+EmbeddingAlign> const base =
-            (in0.read(), ap_uint<EmbeddingAlign>(0));
-        for(unsigned j = 0; j < EmbeddingSize; j++) {
-#pragma HLS PIPELINE II=1
-            out.write(mem[base+j]);
-        }
-    }
-            """
-            self.code_gen_dict["$DOCOMPUTE$"] = [hls_impl]
+            self.code_gen_dict["$DOCOMPUTE$"] = [
+                """StreamingLookup_ext<EmbeddingSize>(in0, out, mem, size, oob_count,
+                oob_irq);"""
+            ]
 
     def blackboxfunction(self):
         mem_mode = self.get_nodeattr("mem_mode")
@@ -286,7 +279,8 @@ class Lookup(HLSCustomOp):
                 "void "
                 + self.onnx_node.name
                 + "(hls::stream<T_SRC> &in0, hls::stream<T_DST> &out, "
-                + "T_DST const *const mem)"
+                + "T_DST const *const mem, unsigned const size, "
+                + "unsigned &oob_count, bool &oob_irq)"
             ]
 
     def pragmas(self):
@@ -305,6 +299,13 @@ class Lookup(HLSCustomOp):
         elif mem_mode == "external":
             my_pragmas.append("#pragma HLS INTERFACE m_axi offset=slave port=mem")
             my_pragmas.append("#pragma HLS INTERFACE s_axilite port=mem bundle=control")
+            my_pragmas.append(
+                "#pragma HLS INTERFACE s_axilite port=size bundle=control"
+            )
+            my_pragmas.append(
+                "#pragma HLS INTERFACE s_axilite port=oob_count bundle=control"
+            )
+            my_pragmas.append("#pragma HLS INTERFACE ap_none port=oob_irq")
         else:
             raise Exception("Unrecognized mem_mode: " + mem_mode)
         self.code_gen_dict["$PRAGMAS$"] = my_pragmas
@@ -475,4 +476,5 @@ class Lookup(HLSCustomOp):
         if mem_mode == "external":
             intf_names["axilite"] = ["s_axi_control"]
             intf_names["aximm"] = [("m_axi_gmem", self.get_nodeattr("ext_mem_width"))]
+            intf_names["ap_none"] = ["oob_irq"]
         return intf_names
diff --git a/src/finn/transformation/fpgadataflow/create_stitched_ip.py b/src/finn/transformation/fpgadataflow/create_stitched_ip.py
index 892ab09fdf41947f86e2bf122e057e94585dfa8c..5b0b0cb600ca10564db00bb6d57bcd19a4f49bb6 100644
--- a/src/finn/transformation/fpgadataflow/create_stitched_ip.py
+++ b/src/finn/transformation/fpgadataflow/create_stitched_ip.py
@@ -228,6 +228,30 @@ class CreateStitchedIP(Transformation):
                 )
                 self.s_axis_idx += 1
 
+    def connect_ap_none_external(self, node):
+        """Make all ap_none pins of a node external ports of the block design.
+
+        For each pin name the node lists under "ap_none" in
+        get_verilog_top_module_intf_names(), two Tcl commands are appended to
+        self.connect_cmds: make_bd_pins_external for the pin, and a
+        set_property that renames the port <pin>_0 to <pin>.
+        """
+        inst_name = node.name
+        node_inst = getCustomOp(node)
+        input_intf_names = node_inst.get_verilog_top_module_intf_names()["ap_none"]
+        # NOTE(review): the rename relies on the new port being called <pin>_0 and
+        # on <pin> being unused - confirm for graphs where several nodes expose
+        # the same ap_none pin name.
+        for input_intf_name in input_intf_names:
+            self.connect_cmds.append(
+                "make_bd_pins_external [get_bd_pins %s/%s]"
+                % (inst_name, input_intf_name)
+            )
+            self.connect_cmds.append(
+                "set_property name %s [get_bd_ports %s_0]"
+                % (input_intf_name, input_intf_name)
+            )
+
     def insert_signature(self, checksum_count):
         signature_vlnv = "AMD:user:axi_info_top:1.0"
         signature_name = "axi_info_top0"
@@ -305,6 +329,7 @@ class CreateStitchedIP(Transformation):
                 ip_dirs += [ip_dir_value]
             self.create_cmds += node_inst.code_generation_ipi()
             self.connect_clk_rst(node)
+            self.connect_ap_none_external(node)
             self.connect_axi(node)
             for i in range(len(node.input)):
                 if not is_external_input(model, node, i):