From dff518d317bb4cb18759d3ce5ae44c31e50ecf92 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Fri, 6 Mar 2020 13:59:34 +0000 Subject: [PATCH] [Test] compare against expected output in conv lowering test --- tests/transformation/test_conv_lowering.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py index 098238531..85dd0f721 100644 --- a/tests/transformation/test_conv_lowering.py +++ b/tests/transformation/test_conv_lowering.py @@ -27,9 +27,9 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os - - +import pkg_resources as pk import brevitas.onnx as bo +import numpy as np from finn.core.modelwrapper import ModelWrapper @@ -37,6 +37,7 @@ from finn.transformation.fold_constants import FoldConstants from finn.transformation.infer_shapes import InferShapes from finn.util.test import get_test_model_trained from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul +import finn.core.onnx_exec as oxe export_onnx_path = "test_output_cnv.onnx" @@ -47,6 +48,16 @@ def test_conv_lowering_cnv_w1a1(): model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) model = model.transform(FoldConstants()) + fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz") + input_tensor = np.load(fn)["arr_0"].astype(np.float32) + assert input_tensor.shape == (1, 3, 32, 32) + # execute imported model to get expected answer + input_dict = {"0": input_tensor} + output_dict_e = oxe.execute_onnx(model, input_dict) + expected = output_dict_e[list(output_dict_e.keys())[0]] + # execute transformed model and compare model = model.transform(LowerConvsToMatMul()) - model.save("test.onnx") + output_dict_p = oxe.execute_onnx(model, input_dict) + produced = output_dict_p[list(output_dict_p.keys())[0]] + assert np.isclose(produced, expected).all() os.remove(export_onnx_path) -- GitLab