Skip to content
Snippets Groups Projects
Commit 96265798 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] start adding a test for ext lookups

parent eabd2b71
No related branches found
No related tags found
No related merge requests found
......@@ -36,6 +36,7 @@ from torch import nn
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.core.onnx_exec import execute_onnx
from finn.custom_op.registry import getCustomOp
from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
from finn.transformation.fpgadataflow.convert_to_hls_layers import InferLookupLayer
from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
......@@ -130,3 +131,42 @@ def test_fpgadataflow_lookup(edt, embedding_cfg, exec_mode):
model = model.transform(PrepareRTLSim())
ret_sim = execute_onnx(model, {iname: itensor})
assert (exp_out == ret_sim[oname]).all()
@pytest.mark.vivado
@pytest.mark.slow
def test_fpgadataflow_lookup_external():
edt = DataType["INT8"]
embedding_cfg = (200000, DataType["UINT32"], 300)
ishape = (1, 600)
num_embeddings, idt, embedding_dim = embedding_cfg
eshape = (num_embeddings, embedding_dim)
exp_oshape = tuple(list(ishape) + [embedding_dim])
embeddings = gen_finn_dt_tensor(edt, eshape)
model = make_lookup_model(embeddings, ishape, idt, edt)
assert len(model.graph.node) == 1
assert model.graph.node[0].op_type == "Gather"
iname = model.graph.input[0].name
ename = model.graph.node[0].input[0]
oname = model.graph.output[0].name
assert model.get_tensor_datatype(iname) == idt
assert model.get_tensor_datatype(ename) == edt
assert model.get_tensor_datatype(oname) == edt
assert tuple(model.get_tensor_shape(ename)) == eshape
assert tuple(model.get_tensor_shape(oname)) == exp_oshape
assert (model.get_initializer(ename) == embeddings).all()
# itensor = gen_finn_dt_tensor(idt, ishape).astype(np.int64)
# itensor = np.clip(itensor, 0, num_embeddings - 1)
# ret = execute_onnx(model, {iname: itensor})
# exp_out = np.take(embeddings, itensor, axis=0)
# assert (exp_out == ret[oname]).all()
# call transformation to convert to HLS and verify conversion
model = model.transform(InferLookupLayer())
assert model.graph.node[0].op_type == "Lookup"
assert model.graph.node[0].input[0] == iname
assert model.graph.node[0].input[1] == ename
assert model.graph.node[0].output[0] == oname
getCustomOp(model.graph.node[0]).set_nodeattr("mem_mode", "external")
model = model.transform(GiveUniqueNodeNames())
model = model.transform(PrepareIP("xczu3eg-sbva484-1-e", 10))
model = model.transform(HLSSynthIP())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment