diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index ffe58609692adc96237cc509bc2b5928e703bde3..5a57fe0c409f9497316f4081eea0043ca1b77819 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -27,7 +27,6 @@
 import copy
 
 import onnx.helper as helper
-import onnx.shape_inference as si
 import onnxruntime as rt
 from onnx import numpy_helper as np_helper
 
@@ -72,9 +71,6 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
     the execution (including inputs, weights, activations and final outputs)
     will be returned as a dict."""
 
-    # call ONNX shape inference to make sure we have value_info fields for all
-    # the intermediate tensors in the graph
-    model = si.infer_shapes(model)
     graph = model.graph
     # first, we need to make sure that every variable required by the graph has
     # some buffer associated with it. this includes graph inputs (which includes
@@ -132,7 +128,6 @@ def execute_onnx_and_make_model(model, input_dict):
     """Execute given ONNX model with given named inputs and return a new model
     where an initializer is provided for each tensor."""
 
-    model = si.infer_shapes(model)
     # retrieve the full execution context
    execution_context = execute_onnx(model, input_dict, True)
     new_model = copy.deepcopy(model)
diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py
index 0fe3e8a131fbade8f6058193f68fe8c28fe28eec..dfdad3ba28cc7922f8776d05edad10c21a8bb370 100644
--- a/src/finn/transformation/batchnorm_to_affine.py
+++ b/src/finn/transformation/batchnorm_to_affine.py
@@ -11,7 +11,6 @@ import finn.transformation.general as tg
 def batchnorm_to_affine(model):
     """Replaces any test-time BatchNorm layers with Mul-Add layers."""
     new_model = copy.deepcopy(model)
-    new_model = si.infer_shapes(new_model)
     graph = new_model.graph
     nodes_to_remove = []
     node_ind = 0
diff --git a/tests/test_basic_onnx_exec.py b/tests/test_basic_onnx_exec.py
index 9fa596fbbe0766a0344ce609ca9665349836fcc0..6769abdca8951ee4db9e67667eddbc54dad5b97e 100644
--- a/tests/test_basic_onnx_exec.py
+++ b/tests/test_basic_onnx_exec.py
@@ -5,6 +5,7 @@ import shutil
 import numpy as np
 import onnx
 import onnx.numpy_helper as np_helper
+import onnx.shape_inference as si
 import wget
 
 import finn.core.onnx_exec as oxe
@@ -25,6 +26,9 @@ def test_mnist_onnx_download_extract_run():
         assert hashlib.md5(f.read()).hexdigest() == "d7cd24a0a76cd492f31065301d468c3d"
     # load the onnx model
     model = onnx.load(mnist_onnx_local_dir + "/mnist/model.onnx")
+    # call ONNX shape inference to make sure we have value_info fields for all
+    # the intermediate tensors in the graph
+    model = si.infer_shapes(model)
     # load one of the test vectors
     input_tensor = onnx.TensorProto()
     output_tensor = onnx.TensorProto()
diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py
index 79f6357bbc2127468b57e86b76f6eb28e089514e..34abb9a1973c1d8c8f20610cb7d69129f06032e0 100644
--- a/tests/test_batchnorm_to_affine.py
+++ b/tests/test_batchnorm_to_affine.py
@@ -7,6 +7,7 @@ import brevitas.onnx as bo
 import numpy as np
 import onnx
 import onnx.numpy_helper as nph
+import onnx.shape_inference as si
 import torch
 import wget
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
@@ -94,6 +95,7 @@ def test_batchnorm_to_affine():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = onnx.load(export_onnx_path)
+    model = si.infer_shapes(model)
     new_model = tx.batchnorm_to_affine(model)
     try:
         os.remove("/tmp/" + mnist_onnx_filename)
diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index 11d8b1cf33a6b48d35768f4a5a45430e81cdfc71..2905869847b3f4290ac8f35794a164817c0d76a0 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -7,6 +7,7 @@ import brevitas.onnx as bo
 import numpy as np
 import onnx
 import onnx.numpy_helper as nph
+import onnx.shape_inference as si
 import torch
 import wget
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
@@ -163,6 +164,9 @@ def test_brevitas_to_onnx_export_and_exec():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = onnx.load(export_onnx_path)
+    # call ONNX shape inference to make sure we have value_info fields for all
+    # the intermediate tensors in the graph
+    model = si.infer_shapes(model)
     try:
         os.remove("/tmp/" + mnist_onnx_filename)
     except OSError:
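With the infer_shapes calls removed from execute_onnx and execute_onnx_and_make_model, callers are expected to run ONNX shape inference themselves before execution, as the updated tests do. A minimal caller-side sketch of that pattern follows; the model path, graph input name "0", and input shape are placeholders for illustration, not taken from the patch:

import numpy as np
import onnx
import onnx.shape_inference as si

import finn.core.onnx_exec as oxe

# load an exported model (path is a placeholder)
model = onnx.load("/tmp/model.onnx")
# populate value_info for all intermediate tensors up front;
# execute_onnx no longer does this internally
model = si.infer_shapes(model)
# map graph input names to numpy arrays (name and shape are assumptions)
input_dict = {"0": np.zeros((1, 1, 28, 28), dtype=np.float32)}
output_dict = oxe.execute_onnx(model, input_dict)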