diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index ffe58609692adc96237cc509bc2b5928e703bde3..5a57fe0c409f9497316f4081eea0043ca1b77819 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -27,7 +27,6 @@
 import copy
 
 import onnx.helper as helper
-import onnx.shape_inference as si
 import onnxruntime as rt
 from onnx import numpy_helper as np_helper
 
@@ -72,9 +71,6 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
     the execution (including inputs, weights, activations and final outputs)
     will be returned as a dict."""
 
-    # call ONNX shape inference to make sure we have value_info fields for all
-    # the intermediate tensors in the graph
-    model = si.infer_shapes(model)
     graph = model.graph
     # first, we need to make sure that every variable required by the graph has
     # some buffer associated with it. this includes graph inputs (which includes
@@ -132,7 +128,6 @@ def execute_onnx_and_make_model(model, input_dict):
     """Execute given ONNX model with given named inputs and return a new model
     where an initializer is provided for each tensor."""
 
-    model = si.infer_shapes(model)
     # retrieve the full execution context
     execution_context = execute_onnx(model, input_dict, True)
     new_model = copy.deepcopy(model)
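
With shape inference removed from execute_onnx, the caller is now responsible for running it before execution. A minimal caller-side sketch (not part of the patch), assuming a placeholder model path and a 1x1x28x28 float32 input, with the first graph input assumed to be the data input:

    import numpy as np
    import onnx
    import onnx.shape_inference as si

    import finn.core.onnx_exec as oxe

    model = onnx.load("model.onnx")          # placeholder path
    # shape inference is now the caller's job: populate value_info
    # for the intermediate tensors before executing the graph
    model = si.infer_shapes(model)
    in_name = model.graph.input[0].name      # assumes the first graph input is the data input
    x = np.zeros((1, 1, 28, 28), dtype=np.float32)  # assumed input shape
    out_dict = oxe.execute_onnx(model, {in_name: x})
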
diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py
index 0fe3e8a131fbade8f6058193f68fe8c28fe28eec..dfdad3ba28cc7922f8776d05edad10c21a8bb370 100644
--- a/src/finn/transformation/batchnorm_to_affine.py
+++ b/src/finn/transformation/batchnorm_to_affine.py
@@ -11,7 +11,6 @@ import finn.transformation.general as tg
 def batchnorm_to_affine(model):
     """Replaces any test-time BatchNorm layers with Mul-Add layers."""
     new_model = copy.deepcopy(model)
-    new_model = si.infer_shapes(new_model)
     graph = new_model.graph
     nodes_to_remove = []
     node_ind = 0
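
The same caller-side responsibility applies to the transformation. A sketch (not from the patch) of invoking batchnorm_to_affine after this change; the import alias and export path are assumptions for illustration:

    import onnx
    import onnx.shape_inference as si

    import finn.transformation.batchnorm_to_affine as tx  # assumed import alias

    model = onnx.load("exported_lfc.onnx")  # placeholder export path
    model = si.infer_shapes(model)          # must now be run before the transform
    new_model = tx.batchnorm_to_affine(model)
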
diff --git a/tests/test_basic_onnx_exec.py b/tests/test_basic_onnx_exec.py
index 9fa596fbbe0766a0344ce609ca9665349836fcc0..6769abdca8951ee4db9e67667eddbc54dad5b97e 100644
--- a/tests/test_basic_onnx_exec.py
+++ b/tests/test_basic_onnx_exec.py
@@ -5,6 +5,7 @@ import shutil
 import numpy as np
 import onnx
 import onnx.numpy_helper as np_helper
+import onnx.shape_inference as si
 import wget
 
 import finn.core.onnx_exec as oxe
@@ -25,6 +26,9 @@ def test_mnist_onnx_download_extract_run():
         assert hashlib.md5(f.read()).hexdigest() == "d7cd24a0a76cd492f31065301d468c3d"
     # load the onnx model
     model = onnx.load(mnist_onnx_local_dir + "/mnist/model.onnx")
+    # call ONNX shape inference to make sure we have value_info fields for all
+    # the intermediate tensors in the graph
+    model = si.infer_shapes(model)
     # load one of the test vectors
     input_tensor = onnx.TensorProto()
     output_tensor = onnx.TensorProto()
diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py
index 79f6357bbc2127468b57e86b76f6eb28e089514e..34abb9a1973c1d8c8f20610cb7d69129f06032e0 100644
--- a/tests/test_batchnorm_to_affine.py
+++ b/tests/test_batchnorm_to_affine.py
@@ -7,6 +7,7 @@ import brevitas.onnx as bo
 import numpy as np
 import onnx
 import onnx.numpy_helper as nph
+import onnx.shape_inference as si
 import torch
 import wget
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
@@ -94,6 +95,7 @@ def test_batchnorm_to_affine():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = onnx.load(export_onnx_path)
+    model = si.infer_shapes(model)
     new_model = tx.batchnorm_to_affine(model)
     try:
         os.remove("/tmp/" + mnist_onnx_filename)
diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index 11d8b1cf33a6b48d35768f4a5a45430e81cdfc71..2905869847b3f4290ac8f35794a164817c0d76a0 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -7,6 +7,7 @@ import brevitas.onnx as bo
 import numpy as np
 import onnx
 import onnx.numpy_helper as nph
+import onnx.shape_inference as si
 import torch
 import wget
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
@@ -163,6 +164,9 @@ def test_brevitas_to_onnx_export_and_exec():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = onnx.load(export_onnx_path)
+    # call ONNX shape inference to make sure we have value_info fields for all
+    # the intermediate tensors in the graph
+    model = si.infer_shapes(model)
     try:
         os.remove("/tmp/" + mnist_onnx_filename)
     except OSError: