diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py index efdfaa19d9f9e5dfa41911a2184e989337b3d9c2..7c3123cd5eb29a54dc5cbfb912225ad3fdb0f219 100644 --- a/src/finn/core/onnx_exec.py +++ b/src/finn/core/onnx_exec.py @@ -108,7 +108,9 @@ def execute_node(node, context, graph): context[outp] = output_list[list_ind] -def execute_onnx(model, input_dict, return_full_exec_context=False): +def execute_onnx( + model, input_dict, return_full_exec_context=False, start_node=None, end_node=None +): """Executes given ONNX ModelWrapper with given named inputs. If return_full_exec_context is False, a dict of named outputs is returned @@ -116,7 +118,12 @@ def execute_onnx(model, input_dict, return_full_exec_context=False): If return return_full_exec_context is True, the full set of tensors used by the execution (including inputs, weights, activations and final outputs) - will be returned as a dict.""" + will be returned as a dict. + + When start_node and end_node are set to None, the whole graph is executed. + If they are set to particular ONNX nodes, only the subgraph between (and + including) those nodes is executed. + """ if not model.check_all_tensor_shapes_specified(): raise Exception("Found unspecified tensor shapes, try infer_shapes") @@ -159,7 +166,17 @@ def execute_onnx(model, input_dict, return_full_exec_context=False): # execute the model node by node # we can simply walk down the list since the ONNX spec guarantees that it is # topologically sorted - for node in graph.node: + subgraph = [] + if start_node is None: + start_node = model.graph.node[0] + if end_node is None: + end_node = model.graph.node[-1] + # select the nodes between specified start/end nodes + start_ind = model.get_node_index(start_node) + end_ind = model.get_node_index(end_node) + 1 + assert end_ind >= start_ind, "Start/end nodes must define valid subgraph" + subgraph = graph.node[start_ind:end_ind] + for node in subgraph: if get_sanitize_quant_tensors() != 0: # round input values to match quantization annotation execution_context = sanitize_quant_values( diff --git a/tests/core/test_basic_onnx_exec.py b/tests/core/test_basic_onnx_exec.py index 7b0412432cc6360cb9c42d66417bd187ed142563..ddb2cbfc40c7647970f0c51ecb95340e7d1dddae 100644 --- a/tests/core/test_basic_onnx_exec.py +++ b/tests/core/test_basic_onnx_exec.py @@ -49,19 +49,33 @@ def test_mnist_onnx_download_extract_run(): raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb") input_tensor = onnx.load_tensor_from_string(raw_i) output_tensor = onnx.load_tensor_from_string(raw_o) - # run using FINN-based execution + # run using FINN-based execution (full graph) input_dict = {"Input3": np_helper.to_array(input_tensor)} - output_dict = oxe.execute_onnx(model, input_dict) + output_dict = oxe.execute_onnx(model, input_dict, return_full_exec_context=True) assert np.isclose( np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3 ).all() + # test subgraph execution + start_node = model.graph.node[1] + end_node = model.graph.node[3] + subgraph_i_dict = {start_node.input[0]: output_dict[start_node.input[0]]} + subgraph_o_dict = oxe.execute_onnx( + model, + subgraph_i_dict, + return_full_exec_context=True, + start_node=start_node, + end_node=end_node, + ) + assert np.isclose( + subgraph_o_dict[end_node.output[0]], output_dict[end_node.output[0]], atol=1e-3 + ).all() def test_onnx_exec_internal_rounding(): inp0 = onnx.helper.make_tensor_value_info("inp0", onnx.TensorProto.FLOAT, [2, 2]) inp1 = onnx.helper.make_tensor_value_info("inp1", onnx.TensorProto.FLOAT, [1]) outp = onnx.helper.make_tensor_value_info("outp", onnx.TensorProto.FLOAT, [2, 2]) - mul_node = onnx.helper.make_node("Mul", inputs=["inp0", "inp1"], outputs=["outp"],) + mul_node = onnx.helper.make_node("Mul", inputs=["inp0", "inp1"], outputs=["outp"]) graph = onnx.helper.make_graph( nodes=[mul_node], name="mul_graph", inputs=[inp0, inp1], outputs=[outp] )