Skip to content
Snippets Groups Projects
Commit 33aecc82 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] restructure lookup test

parent 7d231da0
No related branches found
No related tags found
No related merge requests found
......@@ -26,9 +26,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
import numpy as np
import onnx # noqa
import os
import torch
from brevitas.export import FINNManager
from torch import nn
......@@ -48,8 +48,6 @@ from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import gen_finn_dt_tensor
tmpdir = os.environ["FINN_BUILD_DIR"]
def make_lookup_model(embeddings, ishape, idt, edt):
num_embeddings, embedding_dim = embeddings.shape
......@@ -81,18 +79,23 @@ def make_lookup_model(embeddings, ishape, idt, edt):
return model
def test_lookup_export_convert():
export_path = tmpdir + "/test_lookup_export.onnx"
# embedding DataType
@pytest.mark.parametrize("edt", [DataType["FIXED<8,2>"]])
# other embedding config
@pytest.mark.parametrize(
"embedding_cfg", [(130, DataType["UINT8"], 25), (5145, DataType["UINT16"], 20)]
)
# execution mode
@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
@pytest.mark.vivado
@pytest.mark.slow
def test_fpgadataflow_lookup(edt, embedding_cfg, exec_mode):
ishape = (1, 10)
idt = DataType["UINT8"]
edt = DataType["FIXED<8,2>"]
num_embeddings = 2 ** idt.bitwidth()
embedding_dim = 2
num_embeddings, idt, embedding_dim = embedding_cfg
eshape = (num_embeddings, embedding_dim)
exp_oshape = tuple(list(ishape) + [embedding_dim])
embeddings = gen_finn_dt_tensor(edt, eshape)
model = make_lookup_model(embeddings, ishape, idt, edt)
model.save(export_path)
assert len(model.graph.node) == 1
assert model.graph.node[0].op_type == "Gather"
iname = model.graph.input[0].name
......@@ -105,6 +108,7 @@ def test_lookup_export_convert():
assert tuple(model.get_tensor_shape(oname)) == exp_oshape
assert (model.get_initializer(ename) == embeddings).all()
itensor = gen_finn_dt_tensor(idt, ishape).astype(np.int64)
itensor = np.clip(itensor, 0, num_embeddings - 1)
ret = execute_onnx(model, {iname: itensor})
exp_out = np.take(embeddings, itensor, axis=0)
assert (exp_out == ret[oname]).all()
......@@ -114,17 +118,15 @@ def test_lookup_export_convert():
assert model.graph.node[0].input[0] == iname
assert model.graph.node[0].input[1] == ename
assert model.graph.node[0].output[0] == oname
# prepare and execute cppsim
model = model.transform(PrepareCppSim())
model = model.transform(CompileCppSim())
model = model.transform(SetExecMode("cppsim"))
ret_cppsim = execute_onnx(model, {iname: itensor})
assert (exp_out == ret_cppsim[oname]).all()
# prepare and execute rtlsim
model = model.transform(GiveUniqueNodeNames())
model = model.transform(PrepareIP("xc7z020clg400-1", 10))
model = model.transform(HLSSynthIP())
model = model.transform(SetExecMode("rtlsim"))
model = model.transform(PrepareRTLSim())
ret_rtlsim = execute_onnx(model, {iname: itensor})
assert (exp_out == ret_rtlsim[oname]).all()
if exec_mode == "cppsim":
model = model.transform(PrepareCppSim())
model = model.transform(CompileCppSim())
model = model.transform(SetExecMode("cppsim"))
elif exec_mode == "rtlsim":
model = model.transform(GiveUniqueNodeNames())
model = model.transform(PrepareIP("xc7z020clg400-1", 10))
model = model.transform(HLSSynthIP())
model = model.transform(SetExecMode("rtlsim"))
model = model.transform(PrepareRTLSim())
ret_sim = execute_onnx(model, {iname: itensor})
assert (exp_out == ret_sim[oname]).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment