diff --git a/src/finn/custom_op/fpgadataflow/lookup.py b/src/finn/custom_op/fpgadataflow/lookup.py
index cd50f0bb605c185c00c218e3ec8bc911d1b85309..d34209ce6c39286b33e27d5a9ce714328596ac30 100644
--- a/src/finn/custom_op/fpgadataflow/lookup.py
+++ b/src/finn/custom_op/fpgadataflow/lookup.py
@@ -29,7 +29,7 @@
 import numpy as np
 import os
 import warnings
-from math import ceil
+from math import ceil, log2
 
 from finn.core.datatype import DataType
 from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp
@@ -58,6 +58,13 @@ class Lookup(HLSCustomOp):
             "InputType": ("s", True, ""),
             # Input shape
             "InputShape": ("ints", False, [1]),
+            # Memory mode
+            # const : parameters baked into bitfile (BRAM)
+            # external : lookup performed in external memory over AXI MM
+            "mem_mode": ("s", False, "const", ["const", "external"]),
+            # Width for AXI-MM interface
+            # only relevant when mem_mode="external"
+            "ext_mem_width": ("i", False, 32),
         }
         my_attrs.update(super().get_nodeattr_types())
         return my_attrs
@@ -72,7 +79,8 @@ class Lookup(HLSCustomOp):
 
     def get_normal_output_shape(self):
         ishape = self.get_normal_input_shape()
-        oshape = list(ishape) + [self.get_nodeattr("EmbeddingDim")]
+        emb_dim = self.get_nodeattr("EmbeddingDim")
+        oshape = list(ishape) + [emb_dim]
         return tuple(oshape)
 
     def get_folded_input_shape(self):
@@ -81,7 +89,23 @@ class Lookup(HLSCustomOp):
         return tuple(folded_ishape)
 
     def get_folded_output_shape(self):
-        return self.get_normal_output_shape()
+        ishape = self.get_normal_input_shape()
+        mem_mode = self.get_nodeattr("mem_mode")
+        emb_dim = self.get_nodeattr("EmbeddingDim")
+        if mem_mode == "const":
+            oshape = list(ishape) + [emb_dim]
+        elif mem_mode == "external":
+            ext_mem_width = self.get_nodeattr("ext_mem_width")
+            bits_per_emb_elem = self.get_output_datatype().bitwidth()
+            assert ext_mem_width % bits_per_emb_elem == 0
+            emb_elems_per_ext_mem_width = ext_mem_width // bits_per_emb_elem
+            oshape = list(ishape) + [
+                emb_dim // emb_elems_per_ext_mem_width,
+                emb_elems_per_ext_mem_width,
+            ]
+        else:
+            raise Exception("Unrecognized mem_mode:" + mem_mode)
+        return tuple(oshape)
 
     def make_shape_compatible_op(self, model):
         exp_ishape = tuple(self.get_normal_input_shape())
@@ -123,17 +147,20 @@ class Lookup(HLSCustomOp):
         return ibits
 
     def get_outstream_width(self):
+        folded_oshape = self.get_folded_output_shape()
         obits = self.get_output_datatype().bitwidth()
-        ofm_ch = self.get_nodeattr("EmbeddingDim")
-        return obits * ofm_ch
+        return obits * folded_oshape[-1]
 
     def get_number_output_values(self):
         folded_oshape = self.get_folded_output_shape()
         return np.prod(folded_oshape[:-1])
 
     def global_includes(self):
-        global_incls = ['#include "lookup.hpp"']
-        global_incls.append('#include "embeddings.hpp"')
+        mem_mode = self.get_nodeattr("mem_mode")
+        global_incls = []
+        if mem_mode == "const":
+            global_incls.append('#include "lookup.hpp"')
+            global_incls.append('#include "embeddings.hpp"')
         self.code_gen_dict["$GLOBALS$"] = global_incls
 
     def defines(self, var):
@@ -142,14 +169,26 @@ class Lookup(HLSCustomOp):
         elem_hls_type = dtype.get_hls_datatype_str()
         emb_type = DataType[self.get_nodeattr("EmbeddingType")]
         emb_hls_type = emb_type.get_hls_datatype_str()
+        emb_dim = self.get_nodeattr("EmbeddingDim")
+        mem_mode = self.get_nodeattr("mem_mode")
         my_defines = []
-        my_defines.append(
-            "#define NumEmbeddings %d" % self.get_nodeattr("NumEmbeddings")
-        )
-        my_defines.append("#define EmbeddingDim %d" % self.get_nodeattr("EmbeddingDim"))
         my_defines.append("#define NumInputs %d" % n_inputs)
-        my_defines.append("#define InputType %s" % elem_hls_type)
-        my_defines.append("#define EmbeddingType %s" % emb_hls_type)
+        if mem_mode == "external":
+            ext_mem_width = self.get_nodeattr("ext_mem_width")
+            ext_mem_emb_size = self.get_folded_output_shape()[-2]
+            ext_mem_emb_align = ceil(log2(ext_mem_emb_size))
+            my_defines.append("#define MemBits %d" % ext_mem_width)
+            my_defines.append("#define EmbeddingSize %d" % ext_mem_emb_size)
+            my_defines.append("#define EmbeddingAlign %d" % ext_mem_emb_align)
+            my_defines.append("#define T_SRC %s" % elem_hls_type)
+            my_defines.append("#define T_DST ap_uint<MemBits>")
+        elif mem_mode == "const":
+            my_defines.append(
+                "#define NumEmbeddings %d" % self.get_nodeattr("NumEmbeddings")
+            )
+            my_defines.append("#define EmbeddingDim %d" % emb_dim)
+            my_defines.append("#define InputType %s" % elem_hls_type)
+            my_defines.append("#define EmbeddingType %s" % emb_hls_type)
         self.code_gen_dict["$DEFINES$"] = my_defines
 
     def read_npy_data(self):
@@ -211,22 +250,46 @@ class Lookup(HLSCustomOp):
         )
 
     def docompute(self):
-        self.code_gen_dict["$DOCOMPUTE$"] = [
-            """StreamingLookup<NumEmbeddings,  EmbeddingDim, NumInputs,
-            InputType, EmbeddingType >(in0, out, embeddings);"""
-        ]
+        mem_mode = self.get_nodeattr("mem_mode")
+        if mem_mode == "const":
+            self.code_gen_dict["$DOCOMPUTE$"] = [
+                """StreamingLookup<NumEmbeddings,  EmbeddingDim, NumInputs,
+                InputType, EmbeddingType >(in0, out, embeddings);"""
+            ]
+        elif mem_mode == "external":
+            hls_impl = """
+    for(unsigned  i = 0; i < NumInputs; i++) {
+        ap_uint<T_SRC::width+EmbeddingAlign> const  base =
+            (in0.read(), ap_uint<EmbeddingAlign>(0));
+        for(unsigned  j = 0; j < EmbeddingSize; j++) {
+#pragma HLS PIPELINE II=1
+            out.write(mem[base+j]);
+        }
+    }
+            """
+            self.code_gen_dict["$DOCOMPUTE$"] = [hls_impl]
 
     def blackboxfunction(self):
+        mem_mode = self.get_nodeattr("mem_mode")
         ibits = self.get_instream_width()
         packed_input_hls_type = "ap_uint<%d>" % ibits
         obits = self.get_outstream_width()
         packed_output_hls_type = "ap_uint<%d>" % obits
-        self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
-            "void %s(hls::stream<%s > &in0, hls::stream<%s > &out)"
-            % (self.onnx_node.name, packed_input_hls_type, packed_output_hls_type)
-        ]
+        if mem_mode == "const":
+            self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
+                "void %s(hls::stream<%s > &in0, hls::stream<%s > &out)"
+                % (self.onnx_node.name, packed_input_hls_type, packed_output_hls_type)
+            ]
+        elif mem_mode == "external":
+            self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
+                "void "
+                + self.onnx_node.name
+                + "(hls::stream<T_SRC> &in0, hls::stream<T_DST> &out, "
+                + "T_DST const *const  mem)"
+            ]
 
     def pragmas(self):
+        mem_mode = self.get_nodeattr("mem_mode")
         my_pragmas = [
             "#pragma HLS INTERFACE axis port=in0 name=in0_" + self.hls_sname()
         ]
@@ -234,30 +297,38 @@ class Lookup(HLSCustomOp):
             "#pragma HLS INTERFACE axis port=out name=out_" + self.hls_sname()
         )
         my_pragmas.append("#pragma HLS INTERFACE ap_ctrl_none port=return")
-        my_pragmas.append(
-            "#pragma HLS BIND_STORAGE variable=embeddings type=ROM_2P impl=BRAM"
-        )
+        if mem_mode == "const":
+            my_pragmas.append(
+                "#pragma HLS BIND_STORAGE variable=embeddings type=ROM_2P impl=BRAM"
+            )
+        elif mem_mode == "external":
+            my_pragmas.append("#pragma HLS INTERFACE m_axi offset=slave port=mem")
+            my_pragmas.append("#pragma HLS INTERFACE s_axilite port=mem bundle=control")
+        else:
+            raise Exception("Unrecognized mem_mode: " + mem_mode)
         self.code_gen_dict["$PRAGMAS$"] = my_pragmas
 
     def generate_params(self, model, path):
-        code_gen_dir = path
-        embeddings = model.get_initializer(self.onnx_node.input[1])
-        weight_filename = "{}/embeddings.hpp".format(code_gen_dir)
-        edt = DataType[self.get_nodeattr("EmbeddingType")]
-        # obits = self.get_outstream_width()
-        # packed_output_hls_type = "ap_uint<%d>" % obits
-        assert np.vectorize(edt.allowed)(
-            embeddings
-        ).all(), "Embeddings can't be expressed with type %s" % str(edt)
-        # reverse innertmost dim in embeddings to remain compatible with
-        # how we normally encode the data in FINN
-        embeddings_rev = np.flip(embeddings, -1)
-        embeddings_hls_code = numpy_to_hls_code(
-            embeddings_rev, edt, "embeddings", True, False
-        )
-        f_thresh = open(weight_filename, "w")
-        f_thresh.write(embeddings_hls_code)
-        f_thresh.close()
+        mem_mode = self.get_nodeattr("mem_mode")
+        if mem_mode == "const":
+            code_gen_dir = path
+            embeddings = model.get_initializer(self.onnx_node.input[1])
+            weight_filename = "{}/embeddings.hpp".format(code_gen_dir)
+            edt = DataType[self.get_nodeattr("EmbeddingType")]
+            # obits = self.get_outstream_width()
+            # packed_output_hls_type = "ap_uint<%d>" % obits
+            assert np.vectorize(edt.allowed)(
+                embeddings
+            ).all(), "Embeddings can't be expressed with type %s" % str(edt)
+            # reverse innertmost dim in embeddings to remain compatible with
+            # how we normally encode the data in FINN
+            embeddings_rev = np.flip(embeddings, -1)
+            embeddings_hls_code = numpy_to_hls_code(
+                embeddings_rev, edt, "embeddings", True, False
+            )
+            f_thresh = open(weight_filename, "w")
+            f_thresh.write(embeddings_hls_code)
+            f_thresh.close()
 
     def execute_node(self, context, graph):
         mode = self.get_nodeattr("exec_mode")
@@ -266,6 +337,10 @@ class Lookup(HLSCustomOp):
         exp_oshape = tuple(self.get_normal_output_shape())
         folded_ishape = tuple(self.get_folded_input_shape())
         folded_oshape = tuple(self.get_folded_output_shape())
+        mem_mode = self.get_nodeattr("mem_mode")
+        assert (
+            mem_mode == "const"
+        ), "Only mem_mode=const is supported for simulation of Lookup layer"
 
         if mode == "cppsim":
             code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
@@ -335,10 +410,16 @@ class Lookup(HLSCustomOp):
         ), """Output shape doesn't match expected shape."""
 
     def bram_estimation(self):
-        # current calculation assumes embeddings always stored in BRAM_18Ks
-        width_factor = ceil(self.get_outstream_width() / 16)
-        depth_factor = ceil(self.get_nodeattr("NumEmbeddings") / 1024)
-        return width_factor * depth_factor
+        mem_mode = self.get_nodeattr("mem_mode")
+        if mem_mode == "const":
+            # current calculation assumes embeddings always stored in BRAM_18Ks
+            # when mem_mode is const
+            width_factor = ceil(self.get_outstream_width() / 16)
+            depth_factor = ceil(self.get_nodeattr("NumEmbeddings") / 1024)
+            return width_factor * depth_factor
+        else:
+            # TODO can we estimate BRAMs for the DMA engine?
+            return 0
 
     def bram_efficiency_estimation(self):
         bram16_est = self.bram_estimation()
@@ -347,3 +428,12 @@ class Lookup(HLSCustomOp):
         ebits = self.get_outstream_width() * self.get_nodeattr("NumEmbeddings")
         bram16_est_capacity = bram16_est * 18 * 1024
         return ebits / bram16_est_capacity
+
+    def get_ap_int_max_w(self):
+        parent_max = super().get_ap_int_max_w()
+        mem_mode = self.get_nodeattr("mem_mode")
+        ext_mem_width = self.get_nodeattr("ext_mem_width")
+        if mem_mode == "external":
+            return max(ext_mem_width, parent_max)
+        else:
+            return parent_max