From 687952352237b2276dacddff812cbd8ae6a7d0d5 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Thu, 14 Oct 2021 16:16:07 +0200 Subject: [PATCH] [Test] high-level lookup behavior check --- tests/fpgadataflow/test_lookup.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/tests/fpgadataflow/test_lookup.py b/tests/fpgadataflow/test_lookup.py index b6d437d62..14e046633 100644 --- a/tests/fpgadataflow/test_lookup.py +++ b/tests/fpgadataflow/test_lookup.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import numpy as np import onnx # noqa import os import torch @@ -34,6 +35,7 @@ from torch import nn from finn.core.datatype import DataType from finn.core.modelwrapper import ModelWrapper +from finn.core.onnx_exec import execute_onnx from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.infer_shapes import InferShapes from finn.util.basic import gen_finn_dt_tensor @@ -41,7 +43,9 @@ from finn.util.basic import gen_finn_dt_tensor tmpdir = os.environ["FINN_BUILD_DIR"] -def make_lookup_model(num_embeddings, embedding_dim, ishape, idt, edt): +def make_lookup_model(embeddings, ishape, idt, edt): + num_embeddings, embedding_dim = embeddings.shape + class LookupModel(nn.Module): def __init__(self, num_embeddings, embedding_dim): super().__init__() @@ -61,8 +65,8 @@ def make_lookup_model(num_embeddings, embedding_dim, ishape, idt, edt): ename = model.graph.node[0].input[0] model.set_tensor_datatype(iname, idt) eshape = model.get_tensor_shape(ename) - new_embs = gen_finn_dt_tensor(edt, eshape) - model.set_initializer(ename, new_embs) + assert tuple(eshape) == embeddings.shape + model.set_initializer(ename, embeddings) model.set_tensor_datatype(ename, edt) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) @@ -72,12 +76,14 @@ def make_lookup_model(num_embeddings, embedding_dim, ishape, idt, edt): def test_lookup_export(): export_path = tmpdir + "/test_lookup_export.onnx" ishape = (1, 10) - num_embeddings = 8 - embedding_dim = 2 - exp_oshape = tuple(list(ishape) + [embedding_dim]) idt = DataType["UINT8"] edt = DataType["FIXED<8,2>"] - model = make_lookup_model(num_embeddings, embedding_dim, ishape, idt, edt) + num_embeddings = 2 ** idt.bitwidth() + embedding_dim = 2 + eshape = (num_embeddings, embedding_dim) + exp_oshape = tuple(list(ishape) + [embedding_dim]) + embeddings = gen_finn_dt_tensor(edt, eshape) + model = make_lookup_model(embeddings, ishape, idt, edt) model.save(export_path) assert len(model.graph.node) == 1 assert model.graph.node[0].op_type == "Gather" @@ -87,5 +93,10 @@ def test_lookup_export(): assert model.get_tensor_datatype(iname) == idt assert model.get_tensor_datatype(ename) == edt assert model.get_tensor_datatype(oname) == edt - assert tuple(model.get_tensor_shape(ename)) == (num_embeddings, embedding_dim) + assert tuple(model.get_tensor_shape(ename)) == eshape assert tuple(model.get_tensor_shape(oname)) == exp_oshape + assert (model.get_initializer(ename) == embeddings).all() + itensor = gen_finn_dt_tensor(idt, ishape).astype(np.int64) + ret = execute_onnx(model, {iname: itensor}) + exp_out = np.take(embeddings, itensor, axis=0) + assert (exp_out == ret[oname]).all() -- GitLab