Skip to content
Snippets Groups Projects
Commit 6dd32dd1 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

switch to onnxruntime, create node model for execution

parent 3dd7b6f7
No related branches found
No related tags found
No related merge requests found
......@@ -27,10 +27,9 @@
import onnx
import onnx.helper as helper
import numpy as np
import caffe2.python.onnx.backend as be
from functools import reduce
from onnx import numpy_helper as np_helper
import onnxruntime as rt
model = onnx.load_model("model.onnx")
......@@ -42,18 +41,34 @@ def valueinfo_to_tensor(vi):
dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
return np.zeros(dims, dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type])
def execute_node(node, context):
"""Call Caffe2 to execute a single node. Input/output provided via context."""
def execute_node(node, context, graph):
"""Call onnxruntime to execute a single node. Input/output provided via context."""
# onnxruntime unfortunately does not implement run_node as defined by ONNX,
# it can only execute entire models -- so we create a model which solely
# consists of our current node.
node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
node_inputs += list(filter(lambda x: x.name in node.input, graph.value_info))
node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
node_outputs += (list(filter(lambda x: x.name in node.output, graph.value_info)))
node_graph = helper.make_graph(
nodes = [node],
name = "single-node-exec",
inputs = node_inputs,
outputs = node_outputs
)
node_model = helper.make_model(node_graph)
input_dict = dict()
for inp in node.input:
input_dict[inp] = context[inp]
print("Input shape for %s: %s" % (inp, context[inp].shape))
output_dict = be.run_node(node, input_dict)
for outp in node.output:
if output_dict[outp].shape != context[outp].shape:
raise Exception("Output shapes disagree after node execution: found %s vs expected %s" % (str(output_dict[outp].shape), str(context[outp].shape)))
context[outp] = output_dict[outp]
sess = rt.InferenceSession(node_model.SerializeToString())
output_list = sess.run(None, input_dict)
for output_ind in range(len(node.output)):
outp = node.output[output_ind]
if output_list[output_ind].shape != context[outp].shape:
raise Exception("Output shapes disagree after node execution: found %s vs expected %s" % (str(output_list[output_ind].shape.shape), str(context[outp].shape)))
context[outp] = output_list[output_ind]
print("Output shape for %s: %s" % (outp, context[outp].shape))
# first, we need to make sure that every variable required by the graph has
......@@ -89,7 +104,7 @@ for node in graph.node:
print("Input(s): " + str(node.input))
print("Output(s): " + str(node.output))
print("Attribute(s): " + str(node.attribute))
execute_node(node, execution_context)
execute_node(node, execution_context, graph)
print("Final output(s): ")
print(execution_context[graph.output[0].name])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment