From ec10c641339c514c71f4a394e90481926b6a49b0 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Thu, 14 Oct 2021 16:50:41 +0200 Subject: [PATCH] [Lookup] add HLS conversion transform for Gather -> Lookup --- .../fpgadataflow/convert_to_hls_layers.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py index 8ac3a705b..113ccb93b 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py @@ -1540,3 +1540,56 @@ class InferGlobalAccPoolLayer(Transformation): model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) return (model, graph_modified) + + +class InferLookupLayer(Transformation): + """Convert Gather nodes with constant op0 into Lookup HLS layers.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for node in graph.node: + node_ind += 1 + if node.op_type == "Gather": + emb_name = node.input[0] + embs = model.get_initializer(emb_name) + axis = get_by_name(node.attribute, "axis") + # skip conversion if input0 is not constant + if embs is None: + continue + # skip conversion if axis != 0 + if axis is not None and axis.i != 0: + continue + ind_name = node.input[1] + ind_dtype = model.get_tensor_datatype(ind_name) + emb_dtype = model.get_tensor_datatype(emb_name) + # skip conversion if inputs are not unsigned integers + if (not ind_dtype.is_integer()) or ind_dtype.signed(): + continue + num_embs, emb_dim = embs.shape + out_name = node.output[0] + ishape = model.get_tensor_shape(node.input[1]) + # create and insert new Lookup node + new_node = helper.make_node( + "Lookup", + [ind_name, emb_name], + [out_name], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + name="Lookup_" + node.name, + NumEmbeddings=num_embs, + EmbeddingDim=emb_dim, + EmbeddingType=emb_dtype.name, + InputType=ind_dtype.name, + InputShape=list(ishape), + ) + graph.node.insert(node_ind, new_node) + # remove old node + graph.node.remove(node) + graph_modified = True + + if graph_modified: + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + return (model, graph_modified) -- GitLab