Commit 79393378 authored by Yaman Umuroglu

cleanup onnx_exec, add a basic test

parent fc3ec9c7
@@ -4,3 +4,4 @@ onnx
onnxruntime
pre-commit
sphinx
wget
@@ -27,12 +27,10 @@
import numpy as np
import onnx
import onnx.helper as helper
import onnx.shape_inference as si
import onnxruntime as rt
from onnx import numpy_helper as np_helper
model = onnx.load_model("model.onnx")
graph = model.graph
def valueinfo_to_tensor(vi):
"""Creates an all-zeroes numpy tensor from a ValueInfoProto."""
@@ -60,7 +58,6 @@ def execute_node(node, context, graph):
input_dict = dict()
for inp in node.input:
input_dict[inp] = context[inp]
print("Input shape for %s: %s" % (inp, context[inp].shape))
sess = rt.InferenceSession(node_model.SerializeToString())
output_list = sess.run(None, input_dict)
for output_ind in range(len(node.output)):
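Note: node_model is constructed outside this hunk; execute_node wraps a single node in its own ModelProto so onnxruntime can run it in isolation. A hedged sketch of how that wrapper could be built with onnx.helper (make_single_node_model is a hypothetical name, not from this commit):

def make_single_node_model(node, graph):
    # gather typed ValueInfoProto entries for the node's inputs and outputs
    vi_by_name = {vi.name: vi for vi in
                  list(graph.input) + list(graph.output) + list(graph.value_info)}
    node_graph = helper.make_graph(
        [node],
        "single-node-exec",
        [vi_by_name[i] for i in node.input],
        [vi_by_name[o] for o in node.output],
    )
    return helper.make_model(node_graph)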
@@ -71,43 +68,59 @@ def execute_node(node, context, graph):
% (str(output_list[output_ind].shape), str(context[outp].shape))
)
context[outp] = output_list[output_ind]
print("Output shape for %s: %s" % (outp, context[outp].shape))
# first, we need to make sure that every variable required by the graph has
# some buffer associated with it. this includes graph inputs (which includes
# the input data as well as the trained parameters) and the graph ValueInfo
# (intermediate tensors between layers)
# we'll keep all our buffers in this dict here:
execution_context = dict()
# make empty tensors for all the graph inputs and outputs
for vi in graph.input:
new_tensor = valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor
for vi in graph.output:
new_tensor = valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor
# make empty tensors for all intermediate buffers
# TODO are we guaranteed to have the .value_info filled?
# do we need to call ONNX shape inference first?
for vi in graph.value_info:
new_tensor = valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor
# fill in the constants provided by the initializers (TensorProto to npy)
for t in graph.initializer:
execution_context[t.name] = np_helper.to_array(t)
# now call each node in the graph nodes list
# we can simply walk down the list since the ONNX spec guarantees that it is
# topologically sorted
all_used_ops = set()
for node in graph.node:
print("Node name: %s Type: %s" % (node.name, node.op_type))
all_used_ops.add(node.op_type)
print("Input(s): " + str(node.input))
print("Output(s): " + str(node.output))
print("Attribute(s): " + str(node.attribute))
execute_node(node, execution_context, graph)
def execute_onnx(model, input_dict):
"""Execute given ONNX model with given named inputs to return named outputs."""
print("Final output(s): ")
print(execution_context[graph.output[0].name])
# call ONNX shape inference to make sure we have value_info fields for all
# the intermediate tensors in the graph
model = si.infer_shapes(model)
graph = model.graph
# first, we need to make sure that every variable required by the graph has
# some buffer associated with it. this includes graph inputs (which includes
# the input data as well as the trained parameters) and the graph ValueInfo
# (intermediate tensors between layers)
# we'll keep all our buffers in this dict here:
execution_context = dict()
# make empty tensors for all the graph inputs and outputs
for vi in graph.input:
new_tensor = valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor
for vi in graph.output:
new_tensor = valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor
# make empty tensors for all intermediate buffers
for vi in graph.value_info:
new_tensor = valueinfo_to_tensor(vi)
execution_context[vi.name] = new_tensor
# fill in the constants provided by the initializers (TensorProto to npy)
for t in graph.initializer:
execution_context[t.name] = np_helper.to_array(t)
# fill in any inputs provided to this function
for inp_name in input_dict.keys():
if inp_name in execution_context:
if execution_context[inp_name].shape == input_dict[inp_name].shape:
execution_context[inp_name] = input_dict[inp_name]
else:
raise Exception(
"Shape mismatch for provided input %s: found %s expected %s "
% (
inp_name,
str(input_dict[inp_name].shape),
str(execution_context[inp_name].shape),
)
)
else:
raise Exception("Provided input not found in graph context: %s" % inp_name)
# now call each node in the graph nodes list
# we can simply walk down the list since the ONNX spec guarantees that it is
# topologically sorted
for node in graph.node:
execute_node(node, execution_context, graph)
# provide outputs as dict
output_dict = dict()
for out_tensor in graph.output:
out_name = out_tensor.name
output_dict[out_name] = execution_context[out_name]
return output_dict
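Note: a short usage sketch for the new execute_onnx entry point. The file name is illustrative, as is the assumption that the first graph input is the data input (in older opsets graph.input also lists trained parameters):

model = onnx.load("model.onnx")  # illustrative path
inp = model.graph.input[0]
# build a random input matching the graph's declared input shape
shape = [d.dim_value for d in inp.type.tensor_type.shape.dim]
x = np.random.rand(*shape).astype(np.float32)
out_dict = execute_onnx(model, {inp.name: x})
print(out_dict.keys())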
import hashlib
import os
import shutil
import numpy as np
import onnx
import onnx.numpy_helper as np_helper
import wget
import finn.core.onnx_exec as oxe
mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
mnist_onnx_filename = "mnist.tar.gz"
mnist_onnx_local_dir = "/tmp/mnist_onnx"
def test_mnist_onnx_download_extract_run():
dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
with open(mnist_onnx_local_dir + "/mnist/model.onnx", "rb") as f:
assert hashlib.md5(f.read()).hexdigest() == "d7cd24a0a76cd492f31065301d468c3d"
# load the onnx model
model = onnx.load(mnist_onnx_local_dir + "/mnist/model.onnx")
# load one of the test vectors
input_tensor = onnx.TensorProto()
output_tensor = onnx.TensorProto()
with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f:
input_tensor.ParseFromString(f.read())
with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/output_0.pb", "rb") as f:
output_tensor.ParseFromString(f.read())
# run using FINN-based execution
input_dict = {"Input3": np_helper.to_array(input_tensor)}
output_dict = oxe.execute_onnx(model, input_dict)
assert np.isclose(
np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"]
).all()
# remove the downloaded model and extracted files
os.remove(dl_ret)
shutil.rmtree(mnist_onnx_local_dir)
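Note: the test downloads the opset-8 MNIST model from the ONNX model zoo, checks its md5, executes it through execute_onnx, and compares against the bundled reference output with np.isclose. It can be run on its own with, for example, pytest -k test_mnist_onnx_download_extract_run (invocation illustrative).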