diff --git a/tests/test_basic_onnx_exec.py b/tests/test_basic_onnx_exec.py index 6769abdca8951ee4db9e67667eddbc54dad5b97e..2e13ddc5082e011fb76072e5024ff93a336b8a00 100644 --- a/tests/test_basic_onnx_exec.py +++ b/tests/test_basic_onnx_exec.py @@ -5,10 +5,11 @@ import shutil import numpy as np import onnx import onnx.numpy_helper as np_helper -import onnx.shape_inference as si import wget import finn.core.onnx_exec as oxe +import finn.transformation.infer_shapes as si +from finn.core.modelwrapper import ModelWrapper mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist" mnist_onnx_filename = "mnist.tar.gz" @@ -25,10 +26,8 @@ def test_mnist_onnx_download_extract_run(): with open(mnist_onnx_local_dir + "/mnist/model.onnx", "rb") as f: assert hashlib.md5(f.read()).hexdigest() == "d7cd24a0a76cd492f31065301d468c3d" # load the onnx model - model = onnx.load(mnist_onnx_local_dir + "/mnist/model.onnx") - # call ONNX shape inference to make sure we have value_info fields for all - # the intermediate tensors in the graph - model = si.infer_shapes(model) + model = ModelWrapper(mnist_onnx_local_dir + "/mnist/model.onnx") + model = model.transform_single(si.infer_shapes) # load one of the test vectors input_tensor = onnx.TensorProto() output_tensor = onnx.TensorProto() @@ -38,7 +37,7 @@ def test_mnist_onnx_download_extract_run(): output_tensor.ParseFromString(f.read()) # run using FINN-based execution input_dict = {"Input3": np_helper.to_array(input_tensor)} - output_dict = oxe.execute_onnx(model, input_dict) + output_dict = oxe.execute_onnx(model.model, input_dict) assert np.isclose( np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3 ).all() diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py index 7c7efc494a1cf98cb51dd0ef0c97870e7d440a30..482fd8d873c5a4e3af34321418af9147e2f87445 100644 --- a/tests/test_batchnorm_to_affine.py +++ b/tests/test_batchnorm_to_affine.py @@ -7,7 +7,6 @@ import brevitas.onnx as bo import numpy as np import onnx import onnx.numpy_helper as nph -import onnx.shape_inference as si import torch import wget from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op @@ -15,6 +14,7 @@ from torch.nn import BatchNorm1d, Dropout, Module, ModuleList import finn.core.onnx_exec as oxe import finn.transformation.batchnorm_to_affine as tx +import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper FC_OUT_FEATURES = [1024, 1024, 1024] @@ -96,7 +96,7 @@ def test_batchnorm_to_affine(): lfc.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) - model.model = si.infer_shapes(model.model) + model = model.transform_single(si.infer_shapes) new_model = model.transform_single(tx.batchnorm_to_affine) try: os.remove("/tmp/" + mnist_onnx_filename) diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py index 2905869847b3f4290ac8f35794a164817c0d76a0..c10813d1522af693164d4d4181932f297dfd9a8f 100644 --- a/tests/test_brevitas_export.py +++ b/tests/test_brevitas_export.py @@ -7,13 +7,14 @@ import brevitas.onnx as bo import numpy as np import onnx import onnx.numpy_helper as nph -import onnx.shape_inference as si import torch import wget from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op from torch.nn import BatchNorm1d, Dropout, Module, ModuleList import finn.core.onnx_exec as oxe +import finn.transformation.infer_shapes as si +from finn.core.modelwrapper import ModelWrapper FC_OUT_FEATURES = [1024, 1024, 1024] INTERMEDIATE_FC_PER_OUT_CH_SCALING = True @@ -90,7 +91,7 @@ class LFC(Module): def test_brevitas_to_onnx_export(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) - model = onnx.load(export_onnx_path) + model = ModelWrapper(export_onnx_path) # TODO the following way of testing is highly sensitive to small changes # in PyTorch ONNX export: the order, names, count... of nodes could # easily change between different versions, and break this test. @@ -163,10 +164,8 @@ def test_brevitas_to_onnx_export_and_exec(): checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") lfc.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) - model = onnx.load(export_onnx_path) - # call ONNX shape inference to make sure we have value_info fields for all - # the intermediate tensors in the graph - model = si.infer_shapes(model) + model = ModelWrapper(export_onnx_path) + model = model.transform_single(si.infer_shapes) try: os.remove("/tmp/" + mnist_onnx_filename) except OSError: @@ -179,7 +178,7 @@ def test_brevitas_to_onnx_export_and_exec(): input_tensor.ParseFromString(f.read()) # run using FINN-based execution input_dict = {"0": nph.to_array(input_tensor)} - output_dict = oxe.execute_onnx(model, input_dict) + output_dict = oxe.execute_onnx(model.model, input_dict) produced = output_dict[list(output_dict.keys())[0]] # run using PyTorch/Brevitas input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float()