Skip to content
Snippets Groups Projects
Commit c5ce46c4 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Core] allow subgraph exec in execute_onnx

parent 15d8aae0
No related branches found
No related tags found
No related merge requests found
......@@ -108,7 +108,9 @@ def execute_node(node, context, graph):
context[outp] = output_list[list_ind]
def execute_onnx(model, input_dict, return_full_exec_context=False):
def execute_onnx(
model, input_dict, return_full_exec_context=False, start_node=None, end_node=None
):
"""Executes given ONNX ModelWrapper with given named inputs.
If return_full_exec_context is False, a dict of named outputs is returned
......@@ -116,7 +118,12 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
If return return_full_exec_context is True, the full set of tensors used by
the execution (including inputs, weights, activations and final outputs)
will be returned as a dict."""
will be returned as a dict.
When start_node and end_node are set to None, the whole graph is executed.
If they are set to particular ONNX nodes, only the subgraph between (and
including) those nodes is executed.
"""
if not model.check_all_tensor_shapes_specified():
raise Exception("Found unspecified tensor shapes, try infer_shapes")
......@@ -159,7 +166,17 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
# execute the model node by node
# we can simply walk down the list since the ONNX spec guarantees that it is
# topologically sorted
for node in graph.node:
subgraph = []
if start_node is None:
start_node = model.graph.node[0]
if end_node is None:
end_node = model.graph.node[-1]
# select the nodes between specified start/end nodes
start_ind = model.get_node_index(start_node)
end_ind = model.get_node_index(end_node) + 1
assert end_ind >= start_ind, "Start/end nodes must define valid subgraph"
subgraph = graph.node[start_ind:end_ind]
for node in subgraph:
if get_sanitize_quant_tensors() != 0:
# round input values to match quantization annotation
execution_context = sanitize_quant_values(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment