Skip to content
Snippets Groups Projects
Commit c81c6a70 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

Merge branch 'feature/subgraph_exec' into dev

parents 15d8aae0 c81a7b55
No related branches found
No related tags found
No related merge requests found
......@@ -108,7 +108,9 @@ def execute_node(node, context, graph):
context[outp] = output_list[list_ind]
def execute_onnx(model, input_dict, return_full_exec_context=False):
def execute_onnx(
model, input_dict, return_full_exec_context=False, start_node=None, end_node=None
):
"""Executes given ONNX ModelWrapper with given named inputs.
If return_full_exec_context is False, a dict of named outputs is returned
......@@ -116,7 +118,12 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
If return return_full_exec_context is True, the full set of tensors used by
the execution (including inputs, weights, activations and final outputs)
will be returned as a dict."""
will be returned as a dict.
When start_node and end_node are set to None, the whole graph is executed.
If they are set to particular ONNX nodes, only the subgraph between (and
including) those nodes is executed.
"""
if not model.check_all_tensor_shapes_specified():
raise Exception("Found unspecified tensor shapes, try infer_shapes")
......@@ -159,7 +166,17 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
# execute the model node by node
# we can simply walk down the list since the ONNX spec guarantees that it is
# topologically sorted
for node in graph.node:
subgraph = []
if start_node is None:
start_node = model.graph.node[0]
if end_node is None:
end_node = model.graph.node[-1]
# select the nodes between specified start/end nodes
start_ind = model.get_node_index(start_node)
end_ind = model.get_node_index(end_node) + 1
assert end_ind >= start_ind, "Start/end nodes must define valid subgraph"
subgraph = graph.node[start_ind:end_ind]
for node in subgraph:
if get_sanitize_quant_tensors() != 0:
# round input values to match quantization annotation
execution_context = sanitize_quant_values(
......
......@@ -49,19 +49,33 @@ def test_mnist_onnx_download_extract_run():
raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
output_tensor = onnx.load_tensor_from_string(raw_o)
# run using FINN-based execution
# run using FINN-based execution (full graph)
input_dict = {"Input3": np_helper.to_array(input_tensor)}
output_dict = oxe.execute_onnx(model, input_dict)
output_dict = oxe.execute_onnx(model, input_dict, return_full_exec_context=True)
assert np.isclose(
np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3
).all()
# test subgraph execution
start_node = model.graph.node[1]
end_node = model.graph.node[3]
subgraph_i_dict = {start_node.input[0]: output_dict[start_node.input[0]]}
subgraph_o_dict = oxe.execute_onnx(
model,
subgraph_i_dict,
return_full_exec_context=True,
start_node=start_node,
end_node=end_node,
)
assert np.isclose(
subgraph_o_dict[end_node.output[0]], output_dict[end_node.output[0]], atol=1e-3
).all()
def test_onnx_exec_internal_rounding():
inp0 = onnx.helper.make_tensor_value_info("inp0", onnx.TensorProto.FLOAT, [2, 2])
inp1 = onnx.helper.make_tensor_value_info("inp1", onnx.TensorProto.FLOAT, [1])
outp = onnx.helper.make_tensor_value_info("outp", onnx.TensorProto.FLOAT, [2, 2])
mul_node = onnx.helper.make_node("Mul", inputs=["inp0", "inp1"], outputs=["outp"],)
mul_node = onnx.helper.make_node("Mul", inputs=["inp0", "inp1"], outputs=["outp"])
graph = onnx.helper.make_graph(
nodes=[mul_node], name="mul_graph", inputs=[inp0, inp1], outputs=[outp]
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment