From ec10c641339c514c71f4a394e90481926b6a49b0 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Thu, 14 Oct 2021 16:50:41 +0200
Subject: [PATCH] [Lookup] add HLS conversion transform for Gather -> Lookup

---
 .../fpgadataflow/convert_to_hls_layers.py     | 53 +++++++++++++++++++
 1 file changed, 53 insertions(+)

diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 8ac3a705b..113ccb93b 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -1540,3 +1540,56 @@ class InferGlobalAccPoolLayer(Transformation):
             model = model.transform(InferShapes())
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
+
+
+class InferLookupLayer(Transformation):
+    """Convert Gather nodes with constant op0 into Lookup HLS layers."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for node in graph.node:
+            node_ind += 1
+            if node.op_type == "Gather":
+                emb_name = node.input[0]
+                embs = model.get_initializer(emb_name)
+                axis = get_by_name(node.attribute, "axis")
+                # skip conversion if input0 is not constant
+                if embs is None:
+                    continue
+                # skip conversion if axis != 0
+                if axis is not None and axis.i != 0:
+                    continue
+                ind_name = node.input[1]
+                ind_dtype = model.get_tensor_datatype(ind_name)
+                emb_dtype = model.get_tensor_datatype(emb_name)
+                # skip conversion if inputs are not unsigned integers
+                if (not ind_dtype.is_integer()) or ind_dtype.signed():
+                    continue
+                num_embs, emb_dim = embs.shape
+                out_name = node.output[0]
+                ishape = model.get_tensor_shape(node.input[1])
+                # create and insert new Lookup node
+                new_node = helper.make_node(
+                    "Lookup",
+                    [ind_name, emb_name],
+                    [out_name],
+                    domain="finn.custom_op.fpgadataflow",
+                    backend="fpgadataflow",
+                    name="Lookup_" + node.name,
+                    NumEmbeddings=num_embs,
+                    EmbeddingDim=emb_dim,
+                    EmbeddingType=emb_dtype.name,
+                    InputType=ind_dtype.name,
+                    InputShape=list(ishape),
+                )
+                graph.node.insert(node_ind, new_node)
+                # remove old node
+                graph.node.remove(node)
+                graph_modified = True
+
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
-- 
GitLab