diff --git a/src/finn/onnx_exec.py b/src/finn/onnx_exec.py
index 65a5644824ef5b71cd07e26a44edd22ba03d1715..ffb1cccffe50c66aa10fb4cc8a664893fa152d4c 100644
--- a/src/finn/onnx_exec.py
+++ b/src/finn/onnx_exec.py
@@ -27,10 +27,9 @@
 import onnx
 import onnx.helper as helper
 import numpy as np
-import caffe2.python.onnx.backend as be
 from functools import reduce
 from onnx import numpy_helper as np_helper
-
+import onnxruntime as rt
 
 model = onnx.load_model("model.onnx")
 
@@ -42,18 +41,34 @@
 def valueinfo_to_tensor(vi):
     dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
     return np.zeros(dims, dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type])
 
-def execute_node(node, context):
-    """Call Caffe2 to execute a single node. Input/output provided via context."""
+def execute_node(node, context, graph):
+    """Call onnxruntime to execute a single node. Input/output provided via context."""
+    # onnxruntime unfortunately does not implement run_node as defined by ONNX,
+    # it can only execute entire models -- so we create a model which solely
+    # consists of our current node.
+    node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
+    node_inputs += list(filter(lambda x: x.name in node.input, graph.value_info))
+    node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
+    node_outputs += list(filter(lambda x: x.name in node.output, graph.value_info))
+    node_graph = helper.make_graph(
+        nodes = [node],
+        name = "single-node-exec",
+        inputs = node_inputs,
+        outputs = node_outputs
+    )
+    node_model = helper.make_model(node_graph)
     input_dict = dict()
     for inp in node.input:
         input_dict[inp] = context[inp]
         print("Input shape for %s: %s" % (inp, context[inp].shape))
-    output_dict = be.run_node(node, input_dict)
-    for outp in node.output:
-        if output_dict[outp].shape != context[outp].shape:
-            raise Exception("Output shapes disagree after node execution: found %s vs expected %s" % (str(output_dict[outp].shape), str(context[outp].shape)))
-        context[outp] = output_dict[outp]
+    sess = rt.InferenceSession(node_model.SerializeToString())
+    output_list = sess.run(None, input_dict)
+    for output_ind in range(len(node.output)):
+        outp = node.output[output_ind]
+        if output_list[output_ind].shape != context[outp].shape:
+            raise Exception("Output shapes disagree after node execution: found %s vs expected %s" % (str(output_list[output_ind].shape), str(context[outp].shape)))
+        context[outp] = output_list[output_ind]
         print("Output shape for %s: %s" % (outp, context[outp].shape))
 
 # first, we need to make sure that every variable required by the graph has
@@ -89,7 +104,7 @@ for node in graph.node:
     print("Input(s): " + str(node.input))
     print("Output(s): " + str(node.output))
     print("Attribute(s): " + str(node.attribute))
-    execute_node(node, execution_context)
+    execute_node(node, execution_context, graph)
 
 print("Final output(s): ")
 print(execution_context[graph.output[0].name])
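
Note: a minimal standalone sketch of the single-node-model workaround that the new
execute_node relies on, assuming onnxruntime is installed. The Add op and the tensor
names (in0, in1, out0) are illustrative only, not taken from the patch.

import numpy as np
import onnx.helper as helper
import onnxruntime as rt
from onnx import TensorProto

# Wrap a single Add node in a throwaway one-node model, mirroring what
# execute_node does for each node of the full graph.
add_node = helper.make_node("Add", inputs=["in0", "in1"], outputs=["out0"])
node_graph = helper.make_graph(
    nodes=[add_node],
    name="single-node-exec",
    inputs=[
        helper.make_tensor_value_info("in0", TensorProto.FLOAT, [2, 2]),
        helper.make_tensor_value_info("in1", TensorProto.FLOAT, [2, 2]),
    ],
    outputs=[helper.make_tensor_value_info("out0", TensorProto.FLOAT, [2, 2])],
)
node_model = helper.make_model(node_graph)

# Newer onnxruntime releases may additionally require a providers argument,
# e.g. rt.InferenceSession(..., providers=["CPUExecutionProvider"]).
sess = rt.InferenceSession(node_model.SerializeToString())
input_dict = {
    "in0": np.ones((2, 2), dtype=np.float32),
    "in1": np.full((2, 2), 2.0, dtype=np.float32),
}
# sess.run returns the outputs in graph.output order, which is what makes the
# index-based lookup (output_list[output_ind]) in execute_node valid.
output_list = sess.run(None, input_dict)
print(output_list[0])  # -> [[3. 3.] [3. 3.]]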