From dff518d317bb4cb18759d3ce5ae44c31e50ecf92 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 6 Mar 2020 13:59:34 +0000
Subject: [PATCH] [Test] compare against expected output in conv lowering test

---
 tests/transformation/test_conv_lowering.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py
index 098238531..85dd0f721 100644
--- a/tests/transformation/test_conv_lowering.py
+++ b/tests/transformation/test_conv_lowering.py
@@ -27,9 +27,9 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import os
-
-
+import pkg_resources as pk
 import brevitas.onnx as bo
+import numpy as np
 
 
 from finn.core.modelwrapper import ModelWrapper
@@ -37,6 +37,7 @@ from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.infer_shapes import InferShapes
 from finn.util.test import get_test_model_trained
 from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
+import finn.core.onnx_exec as oxe
 
 export_onnx_path = "test_output_cnv.onnx"
 
@@ -47,6 +48,16 @@ def test_conv_lowering_cnv_w1a1():
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
+    fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
+    input_tensor = np.load(fn)["arr_0"].astype(np.float32)
+    assert input_tensor.shape == (1, 3, 32, 32)
+    # execute imported model to get expected answer
+    input_dict = {"0": input_tensor}
+    output_dict_e = oxe.execute_onnx(model, input_dict)
+    expected = output_dict_e[list(output_dict_e.keys())[0]]
+    # execute transformed model and compare
     model = model.transform(LowerConvsToMatMul())
-    model.save("test.onnx")
+    output_dict_p = oxe.execute_onnx(model, input_dict)
+    produced = output_dict_p[list(output_dict_p.keys())[0]]
+    assert np.isclose(produced, expected).all()
     os.remove(export_onnx_path)
-- 
GitLab