Skip to content
Snippets Groups Projects
Commit ec10c641 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Lookup] add HLS conversion transform for Gather -> Lookup

parent cde63ec1
No related branches found
No related tags found
No related merge requests found
......@@ -1540,3 +1540,56 @@ class InferGlobalAccPoolLayer(Transformation):
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
return (model, graph_modified)
class InferLookupLayer(Transformation):
"""Convert Gather nodes with constant op0 into Lookup HLS layers."""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for node in graph.node:
node_ind += 1
if node.op_type == "Gather":
emb_name = node.input[0]
embs = model.get_initializer(emb_name)
axis = get_by_name(node.attribute, "axis")
# skip conversion if input0 is not constant
if embs is None:
continue
# skip conversion if axis != 0
if axis is not None and axis.i != 0:
continue
ind_name = node.input[1]
ind_dtype = model.get_tensor_datatype(ind_name)
emb_dtype = model.get_tensor_datatype(emb_name)
# skip conversion if inputs are not unsigned integers
if (not ind_dtype.is_integer()) or ind_dtype.signed():
continue
num_embs, emb_dim = embs.shape
out_name = node.output[0]
ishape = model.get_tensor_shape(node.input[1])
# create and insert new Lookup node
new_node = helper.make_node(
"Lookup",
[ind_name, emb_name],
[out_name],
domain="finn.custom_op.fpgadataflow",
backend="fpgadataflow",
name="Lookup_" + node.name,
NumEmbeddings=num_embs,
EmbeddingDim=emb_dim,
EmbeddingType=emb_dtype.name,
InputType=ind_dtype.name,
InputShape=list(ishape),
)
graph.node.insert(node_ind, new_node)
# remove old node
graph.node.remove(node)
graph_modified = True
if graph_modified:
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment