diff --git a/tests/test_basic_onnx_exec.py b/tests/test_basic_onnx_exec.py
index 6769abdca8951ee4db9e67667eddbc54dad5b97e..2e13ddc5082e011fb76072e5024ff93a336b8a00 100644
--- a/tests/test_basic_onnx_exec.py
+++ b/tests/test_basic_onnx_exec.py
@@ -5,10 +5,11 @@ import shutil
 import numpy as np
 import onnx
 import onnx.numpy_helper as np_helper
-import onnx.shape_inference as si
 import wget
 
 import finn.core.onnx_exec as oxe
+import finn.transformation.infer_shapes as si
+from finn.core.modelwrapper import ModelWrapper
 
 mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
 mnist_onnx_filename = "mnist.tar.gz"
@@ -25,10 +26,8 @@ def test_mnist_onnx_download_extract_run():
     with open(mnist_onnx_local_dir + "/mnist/model.onnx", "rb") as f:
         assert hashlib.md5(f.read()).hexdigest() == "d7cd24a0a76cd492f31065301d468c3d"
     # load the onnx model
-    model = onnx.load(mnist_onnx_local_dir + "/mnist/model.onnx")
-    # call ONNX shape inference to make sure we have value_info fields for all
-    # the intermediate tensors in the graph
-    model = si.infer_shapes(model)
+    model = ModelWrapper(mnist_onnx_local_dir + "/mnist/model.onnx")
+    model = model.transform_single(si.infer_shapes)
     # load one of the test vectors
     input_tensor = onnx.TensorProto()
     output_tensor = onnx.TensorProto()
@@ -38,7 +37,7 @@ def test_mnist_onnx_download_extract_run():
         output_tensor.ParseFromString(f.read())
     # run using FINN-based execution
     input_dict = {"Input3": np_helper.to_array(input_tensor)}
-    output_dict = oxe.execute_onnx(model, input_dict)
+    output_dict = oxe.execute_onnx(model.model, input_dict)
     assert np.isclose(
         np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3
     ).all()
diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py
index 7c7efc494a1cf98cb51dd0ef0c97870e7d440a30..482fd8d873c5a4e3af34321418af9147e2f87445 100644
--- a/tests/test_batchnorm_to_affine.py
+++ b/tests/test_batchnorm_to_affine.py
@@ -7,7 +7,6 @@ import brevitas.onnx as bo
 import numpy as np
 import onnx
 import onnx.numpy_helper as nph
-import onnx.shape_inference as si
 import torch
 import wget
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
@@ -15,6 +14,7 @@ from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
 
 import finn.core.onnx_exec as oxe
 import finn.transformation.batchnorm_to_affine as tx
+import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
 
 FC_OUT_FEATURES = [1024, 1024, 1024]
@@ -96,7 +96,7 @@ def test_batchnorm_to_affine():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model.model = si.infer_shapes(model.model)
+    model = model.transform_single(si.infer_shapes)
     new_model = model.transform_single(tx.batchnorm_to_affine)
     try:
         os.remove("/tmp/" + mnist_onnx_filename)
diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index 2905869847b3f4290ac8f35794a164817c0d76a0..c10813d1522af693164d4d4181932f297dfd9a8f 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -7,13 +7,14 @@ import brevitas.onnx as bo
 import numpy as np
 import onnx
 import onnx.numpy_helper as nph
-import onnx.shape_inference as si
 import torch
 import wget
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
 from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
 
 import finn.core.onnx_exec as oxe
+import finn.transformation.infer_shapes as si
+from finn.core.modelwrapper import ModelWrapper
 
 FC_OUT_FEATURES = [1024, 1024, 1024]
 INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
@@ -90,7 +91,7 @@ class LFC(Module):
 def test_brevitas_to_onnx_export():
     lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
-    model = onnx.load(export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
     # TODO the following way of testing is highly sensitive to small changes
     # in PyTorch ONNX export: the order, names, count... of nodes could
     # easily change between different versions, and break this test.
@@ -163,10 +164,8 @@ def test_brevitas_to_onnx_export_and_exec():
     checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
-    model = onnx.load(export_onnx_path)
-    # call ONNX shape inference to make sure we have value_info fields for all
-    # the intermediate tensors in the graph
-    model = si.infer_shapes(model)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform_single(si.infer_shapes)
     try:
         os.remove("/tmp/" + mnist_onnx_filename)
     except OSError:
@@ -179,7 +178,7 @@ def test_brevitas_to_onnx_export_and_exec():
         input_tensor.ParseFromString(f.read())
     # run using FINN-based execution
     input_dict = {"0": nph.to_array(input_tensor)}
-    output_dict = oxe.execute_onnx(model, input_dict)
+    output_dict = oxe.execute_onnx(model.model, input_dict)
     produced = output_dict[list(output_dict.keys())[0]]
     # run using PyTorch/Brevitas
     input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float()