diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py index 6c72aab2b4d25b6062dc52e7a67cf217bb43625e..bb66f98f490069ffc12c4503a5a897c37bcf93b8 100644 --- a/tests/test_batchnorm_to_affine.py +++ b/tests/test_batchnorm_to_affine.py @@ -2,7 +2,6 @@ import os from pkgutil import get_data import brevitas.onnx as bo -import numpy as np import onnx import onnx.numpy_helper as nph import torch @@ -34,10 +33,6 @@ def test_batchnorm_to_affine(): # load one of the test vectors raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") input_tensor = onnx.load_tensor_from_string(raw_i) - out_old = model.graph.output[0].name - out_new = new_model.graph.output[0].name input_dict = {"0": nph.to_array(input_tensor)} - output_original = oxe.execute_onnx(model, input_dict)[out_old] - output_transformed = oxe.execute_onnx(new_model, input_dict)[out_new] - assert np.isclose(output_transformed, output_original, atol=1e-3).all() + assert oxe.compare_execution(model, new_model, input_dict) os.remove(export_onnx_path) diff --git a/tests/test_collapse_repeated_op.py b/tests/test_collapse_repeated_op.py index 224df9c3b37de85278f48480460ed186934e3487..d8cdde3c654873b368653347863af05317bc5bcb 100644 --- a/tests/test_collapse_repeated_op.py +++ b/tests/test_collapse_repeated_op.py @@ -38,6 +38,4 @@ def test_collapse_repeated_op(): new_model = model.transform_repeated(tx.collapse_repeated_add) new_model = new_model.transform_repeated(tx.collapse_repeated_mul) inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} - out_orig = ox.execute_onnx(model, inp_dict)["top_out"] - out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"] - assert np.isclose(out_orig, out_transformed).all() + assert ox.compare_execution(model, new_model, inp_dict) diff --git a/tests/test_move_add_past_mul.py b/tests/test_move_add_past_mul.py index b19e1ce326b0d14da86e4324c62db1df688eb886..a23827d0e91a543cf5595c816628169a251bf4ce 100644 --- a/tests/test_move_add_past_mul.py +++ b/tests/test_move_add_past_mul.py @@ -31,9 +31,7 @@ def test_move_add_past_mul_single(): model.set_initializer("mul_param", np.asarray([2, 4], dtype=np.float32)) new_model = model.transform_repeated(tx.move_add_past_mul) inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} - out_orig = ox.execute_onnx(model, inp_dict)["top_out"] - out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"] - assert np.isclose(out_orig, out_transformed).all() + assert ox.compare_execution(model, new_model, inp_dict) def test_move_add_past_mul_multi(): @@ -65,6 +63,4 @@ def test_move_add_past_mul_multi(): model.set_initializer("mul_param_1", np.asarray([2, -4], dtype=np.float32)) new_model = model.transform_repeated(tx.move_add_past_mul) inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} - out_orig = ox.execute_onnx(model, inp_dict)["top_out"] - out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"] - assert np.isclose(out_orig, out_transformed).all() + assert ox.compare_execution(model, new_model, inp_dict) diff --git a/tests/test_move_scalar_past_matmul.py b/tests/test_move_scalar_past_matmul.py index a9cd35d425a1fbf3109b36a7fc6b24bd23706a47..7bbdd7dd8506437371147faa98310b8caa115318 100644 --- a/tests/test_move_scalar_past_matmul.py +++ b/tests/test_move_scalar_past_matmul.py @@ -33,9 +33,7 @@ def test_move_scalar_mul_past_matmul(): ) new_model = model.transform_repeated(tx.move_scalar_mul_past_matmul) inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)} - out_orig = ox.execute_onnx(model, inp_dict)["top_out"] - out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"] - assert np.isclose(out_orig, out_transformed).all() + assert ox.compare_execution(model, new_model, inp_dict) assert new_model.graph.node[0].op_type == "MatMul" assert new_model.graph.node[1].op_type == "Mul" assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0] @@ -66,9 +64,7 @@ def test_move_scalar_add_past_matmul(): ) new_model = model.transform_repeated(tx.move_scalar_add_past_matmul) inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)} - out_orig = ox.execute_onnx(model, inp_dict)["top_out"] - out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"] - assert np.isclose(out_orig, out_transformed).all() + assert ox.compare_execution(model, new_model, inp_dict) assert new_model.graph.node[0].op_type == "MatMul" assert new_model.graph.node[1].op_type == "Add" assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0] diff --git a/tests/test_sign_to_thres.py b/tests/test_sign_to_thres.py index 4d7f447f827fa9ddcf0698e89a7257021c94d802..6656382837565b0d6e3f723973036fe9981dfc58 100644 --- a/tests/test_sign_to_thres.py +++ b/tests/test_sign_to_thres.py @@ -26,8 +26,6 @@ def test_sign_to_thres(): model = model.transform_single(si.infer_shapes) input_dict = {} input_dict["v"] = np.random.randn(*[6, 3, 2, 2]).astype(np.float32) - expected = oxe.execute_onnx(model, input_dict)["out1"] - model = model.transform_single(sl.convert_sign_to_thres) - assert model.graph.node[0].op_type == "MultiThreshold" - produced = oxe.execute_onnx(model, input_dict)["out1"] - assert np.isclose(expected, produced, atol=1e-3).all() + new_model = model.transform_single(sl.convert_sign_to_thres) + assert new_model.graph.node[0].op_type == "MultiThreshold" + assert oxe.compare_execution(model, new_model, input_dict)