From 5fdd42506b26517a3fed8e83bb9a3bbf0f82bfc3 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sun, 3 Nov 2019 21:11:40 +0000
Subject: [PATCH] [Test] refactor tests to use onnx_exec.compare_execution

---
 tests/test_batchnorm_to_affine.py     | 7 +------
 tests/test_collapse_repeated_op.py    | 4 +---
 tests/test_move_add_past_mul.py       | 8 ++------
 tests/test_move_scalar_past_matmul.py | 8 ++------
 tests/test_sign_to_thres.py           | 8 +++-----
 5 files changed, 9 insertions(+), 26 deletions(-)

diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py
index 6c72aab2b..bb66f98f4 100644
--- a/tests/test_batchnorm_to_affine.py
+++ b/tests/test_batchnorm_to_affine.py
@@ -2,7 +2,6 @@ import os
 from pkgutil import get_data

 import brevitas.onnx as bo
-import numpy as np
 import onnx
 import onnx.numpy_helper as nph
 import torch
@@ -34,10 +33,6 @@ def test_batchnorm_to_affine():
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
-    out_old = model.graph.output[0].name
-    out_new = new_model.graph.output[0].name
     input_dict = {"0": nph.to_array(input_tensor)}
-    output_original = oxe.execute_onnx(model, input_dict)[out_old]
-    output_transformed = oxe.execute_onnx(new_model, input_dict)[out_new]
-    assert np.isclose(output_transformed, output_original, atol=1e-3).all()
+    assert oxe.compare_execution(model, new_model, input_dict)
     os.remove(export_onnx_path)
diff --git a/tests/test_collapse_repeated_op.py b/tests/test_collapse_repeated_op.py
index 224df9c3b..d8cdde3c6 100644
--- a/tests/test_collapse_repeated_op.py
+++ b/tests/test_collapse_repeated_op.py
@@ -38,6 +38,4 @@ def test_collapse_repeated_op():
     new_model = model.transform_repeated(tx.collapse_repeated_add)
     new_model = new_model.transform_repeated(tx.collapse_repeated_mul)
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
-    out_orig = ox.execute_onnx(model, inp_dict)["top_out"]
-    out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"]
-    assert np.isclose(out_orig, out_transformed).all()
+    assert ox.compare_execution(model, new_model, inp_dict)
diff --git a/tests/test_move_add_past_mul.py b/tests/test_move_add_past_mul.py
index b19e1ce32..a23827d0e 100644
--- a/tests/test_move_add_past_mul.py
+++ b/tests/test_move_add_past_mul.py
@@ -31,9 +31,7 @@ def test_move_add_past_mul_single():
     model.set_initializer("mul_param", np.asarray([2, 4], dtype=np.float32))
     new_model = model.transform_repeated(tx.move_add_past_mul)
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
-    out_orig = ox.execute_onnx(model, inp_dict)["top_out"]
-    out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"]
-    assert np.isclose(out_orig, out_transformed).all()
+    assert ox.compare_execution(model, new_model, inp_dict)


 def test_move_add_past_mul_multi():
@@ -65,6 +63,4 @@
     model.set_initializer("mul_param_1", np.asarray([2, -4], dtype=np.float32))
     new_model = model.transform_repeated(tx.move_add_past_mul)
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
-    out_orig = ox.execute_onnx(model, inp_dict)["top_out"]
-    out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"]
-    assert np.isclose(out_orig, out_transformed).all()
+    assert ox.compare_execution(model, new_model, inp_dict)
diff --git a/tests/test_move_scalar_past_matmul.py b/tests/test_move_scalar_past_matmul.py
index a9cd35d42..7bbdd7dd8 100644
--- a/tests/test_move_scalar_past_matmul.py
+++ b/tests/test_move_scalar_past_matmul.py
@@ -33,9 +33,7 @@ def test_move_scalar_mul_past_matmul():
     )
     new_model = model.transform_repeated(tx.move_scalar_mul_past_matmul)
     inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)}
-    out_orig = ox.execute_onnx(model, inp_dict)["top_out"]
-    out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"]
-    assert np.isclose(out_orig, out_transformed).all()
+    assert ox.compare_execution(model, new_model, inp_dict)
     assert new_model.graph.node[0].op_type == "MatMul"
     assert new_model.graph.node[1].op_type == "Mul"
     assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
@@ -66,9 +64,7 @@ def test_move_scalar_add_past_matmul():
     )
     new_model = model.transform_repeated(tx.move_scalar_add_past_matmul)
     inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)}
-    out_orig = ox.execute_onnx(model, inp_dict)["top_out"]
-    out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"]
-    assert np.isclose(out_orig, out_transformed).all()
+    assert ox.compare_execution(model, new_model, inp_dict)
     assert new_model.graph.node[0].op_type == "MatMul"
     assert new_model.graph.node[1].op_type == "Add"
     assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
diff --git a/tests/test_sign_to_thres.py b/tests/test_sign_to_thres.py
index 4d7f447f8..665638283 100644
--- a/tests/test_sign_to_thres.py
+++ b/tests/test_sign_to_thres.py
@@ -26,8 +26,6 @@ def test_sign_to_thres():
     model = model.transform_single(si.infer_shapes)
     input_dict = {}
     input_dict["v"] = np.random.randn(*[6, 3, 2, 2]).astype(np.float32)
-    expected = oxe.execute_onnx(model, input_dict)["out1"]
-    model = model.transform_single(sl.convert_sign_to_thres)
-    assert model.graph.node[0].op_type == "MultiThreshold"
-    produced = oxe.execute_onnx(model, input_dict)["out1"]
-    assert np.isclose(expected, produced, atol=1e-3).all()
+    new_model = model.transform_single(sl.convert_sign_to_thres)
+    assert new_model.graph.node[0].op_type == "MultiThreshold"
+    assert oxe.compare_execution(model, new_model, input_dict)
--
GitLab
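Note: the compare_execution helper these tests now call is not part of this patch. A minimal sketch of what such a helper might look like, reconstructed only from the inlined execute-and-compare logic removed above, is shown below. The function name (compare_execution_sketch), the default atol, the keyword argument, the import path for the onnx_exec module, and the output-name lookup are assumptions for illustration, not the actual FINN implementation.

# Minimal sketch (assumption, not the FINN source) of an execute-and-compare
# helper matching the pattern the removed test code implemented inline.
import numpy as np
import finn.core.onnx_exec as oxe  # assumed module path for execute_onnx


def compare_execution_sketch(model_a, model_b, input_dict, atol=1e-3):
    # resolve the (single) top-level output tensor name of each model
    out_a_name = model_a.graph.output[0].name
    out_b_name = model_b.graph.output[0].name
    # run both models on the same input dictionary
    out_a = oxe.execute_onnx(model_a, input_dict)[out_a_name]
    out_b = oxe.execute_onnx(model_b, input_dict)[out_b_name]
    # elementwise comparison with an absolute tolerance, as the old tests did
    return np.isclose(out_a, out_b, atol=atol).all()

Centralizing this pattern keeps each test focused on the transformation under test and makes the comparison tolerance a single point of maintenance instead of being repeated in every test file.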