Commit dde783bc authored by Yaman Umuroglu

[Core] prefer calling shape inference externally

this will eventually be performed automatically in ModelWrapper
when needed
parent 4594c2fc
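
With this commit, callers of finn.core.onnx_exec are expected to run ONNX shape inference themselves before execution, as the updated tests below now do. A minimal sketch of the new calling convention (the model path, the input tensor name "0", and the input shape are assumptions for illustration, not part of this commit):

    import numpy as np
    import onnx
    import onnx.shape_inference as si
    import finn.core.onnx_exec as oxe

    # populate value_info for all intermediate tensors up front;
    # execute_onnx no longer does this internally
    model = onnx.load("model.onnx")  # hypothetical path
    model = si.infer_shapes(model)
    # input name "0" and shape (1, 1, 28, 28) are assumptions for illustration
    input_dict = {"0": np.zeros((1, 1, 28, 28), dtype=np.float32)}
    output_dict = oxe.execute_onnx(model, input_dict)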
@@ -27,7 +27,6 @@
 import copy
 import onnx.helper as helper
-import onnx.shape_inference as si
 import onnxruntime as rt
 from onnx import numpy_helper as np_helper
@@ -72,9 +71,6 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
     the execution (including inputs, weights, activations and final outputs)
     will be returned as a dict."""
-    # call ONNX shape inference to make sure we have value_info fields for all
-    # the intermediate tensors in the graph
-    model = si.infer_shapes(model)
     graph = model.graph
     # first, we need to make sure that every variable required by the graph has
     # some buffer associated with it. this includes graph inputs (which includes
@@ -132,7 +128,6 @@ def execute_onnx_and_make_model(model, input_dict):
     """Execute given ONNX model with given named inputs and return a new model
     where an initializer is provided for each tensor."""
-    model = si.infer_shapes(model)
     # retrieve the full execution context
     execution_context = execute_onnx(model, input_dict, True)
     new_model = copy.deepcopy(model)
@@ -11,7 +11,6 @@ import finn.transformation.general as tg
 def batchnorm_to_affine(model):
     """Replaces any test-time BatchNorm layers with Mul-Add layers."""
     new_model = copy.deepcopy(model)
-    new_model = si.infer_shapes(new_model)
     graph = new_model.graph
     nodes_to_remove = []
     node_ind = 0
@@ -5,6 +5,7 @@ import shutil
 import numpy as np
 import onnx
 import onnx.numpy_helper as np_helper
+import onnx.shape_inference as si
 import wget
 import finn.core.onnx_exec as oxe
@@ -25,6 +26,9 @@ def test_mnist_onnx_download_extract_run():
         assert hashlib.md5(f.read()).hexdigest() == "d7cd24a0a76cd492f31065301d468c3d"
     # load the onnx model
     model = onnx.load(mnist_onnx_local_dir + "/mnist/model.onnx")
+    # call ONNX shape inference to make sure we have value_info fields for all
+    # the intermediate tensors in the graph
+    model = si.infer_shapes(model)
     # load one of the test vectors
     input_tensor = onnx.TensorProto()
     output_tensor = onnx.TensorProto()
@@ -7,6 +7,7 @@ import brevitas.onnx as bo
 import numpy as np
 import onnx
 import onnx.numpy_helper as nph
+import onnx.shape_inference as si
 import torch
 import wget
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
@@ -94,6 +95,7 @@ def test_batchnorm_to_affine():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = onnx.load(export_onnx_path)
+    model = si.infer_shapes(model)
     new_model = tx.batchnorm_to_affine(model)
     try:
         os.remove("/tmp/" + mnist_onnx_filename)
@@ -7,6 +7,7 @@ import brevitas.onnx as bo
 import numpy as np
 import onnx
 import onnx.numpy_helper as nph
+import onnx.shape_inference as si
 import torch
 import wget
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
@@ -163,6 +164,9 @@ def test_brevitas_to_onnx_export_and_exec():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = onnx.load(export_onnx_path)
+    # call ONNX shape inference to make sure we have value_info fields for all
+    # the intermediate tensors in the graph
+    model = si.infer_shapes(model)
     try:
         os.remove("/tmp/" + mnist_onnx_filename)
     except OSError:
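The commit message notes that shape inference will eventually run automatically in ModelWrapper when needed. That class is not part of this commit, so the following is only a hypothetical sketch of what such lazy inference could look like; the class name, attributes, and property are all assumptions:

    import onnx.shape_inference as si

    class ModelWrapper:
        """Hypothetical wrapper that runs shape inference lazily on first access."""

        def __init__(self, onnx_model):
            self._model = onnx_model
            self._shapes_inferred = False

        @property
        def model(self):
            # infer shapes once, on demand, so callers no longer do it by hand
            if not self._shapes_inferred:
                self._model = si.infer_shapes(self._model)
                self._shapes_inferred = True
            return self._model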