diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index efdfaa19d9f9e5dfa41911a2184e989337b3d9c2..7c3123cd5eb29a54dc5cbfb912225ad3fdb0f219 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -108,7 +108,9 @@ def execute_node(node, context, graph):
                 context[outp] = output_list[list_ind]
 
 
-def execute_onnx(model, input_dict, return_full_exec_context=False):
+def execute_onnx(
+    model, input_dict, return_full_exec_context=False, start_node=None, end_node=None
+):
     """Executes given ONNX ModelWrapper with given named inputs.
 
     If return_full_exec_context is False, a dict of named outputs is returned
@@ -116,7 +118,12 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
 
     If return return_full_exec_context is True, the full set of tensors used by
     the execution (including inputs, weights, activations and final outputs)
-    will be returned as a dict."""
+    will be returned as a dict.
+
+    When start_node and end_node are set to None, the whole graph is executed.
+    If they are set to particular ONNX nodes, only the subgraph between (and
+    including) those nodes is executed.
+    """
 
     if not model.check_all_tensor_shapes_specified():
         raise Exception("Found unspecified tensor shapes, try infer_shapes")
@@ -159,7 +166,17 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
         # execute the model node by node
         # we can simply walk down the list since the ONNX spec guarantees that it is
         # topologically sorted
-        for node in graph.node:
+        # determine the (inclusive) node sublist to execute
+        if start_node is None:
+            start_node = model.graph.node[0]
+        if end_node is None:
+            end_node = model.graph.node[-1]
+        # select the nodes between specified start/end nodes
+        start_ind = model.get_node_index(start_node)
+        end_ind = model.get_node_index(end_node) + 1
+        assert end_ind >= start_ind, "Start/end nodes must define valid subgraph"
+        subgraph = graph.node[start_ind:end_ind]
+        for node in subgraph:
             if get_sanitize_quant_tensors() != 0:
                 # round input values to match quantization annotation
                 execution_context = sanitize_quant_values(
diff --git a/tests/core/test_basic_onnx_exec.py b/tests/core/test_basic_onnx_exec.py
index 7b0412432cc6360cb9c42d66417bd187ed142563..ddb2cbfc40c7647970f0c51ecb95340e7d1dddae 100644
--- a/tests/core/test_basic_onnx_exec.py
+++ b/tests/core/test_basic_onnx_exec.py
@@ -49,19 +49,33 @@ def test_mnist_onnx_download_extract_run():
     raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
     output_tensor = onnx.load_tensor_from_string(raw_o)
-    # run using FINN-based execution
+    # run using FINN-based execution (full graph)
     input_dict = {"Input3": np_helper.to_array(input_tensor)}
-    output_dict = oxe.execute_onnx(model, input_dict)
+    output_dict = oxe.execute_onnx(model, input_dict, return_full_exec_context=True)
     assert np.isclose(
         np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3
     ).all()
+    # test subgraph execution
+    start_node = model.graph.node[1]
+    end_node = model.graph.node[3]
+    subgraph_i_dict = {start_node.input[0]: output_dict[start_node.input[0]]}
+    subgraph_o_dict = oxe.execute_onnx(
+        model,
+        subgraph_i_dict,
+        return_full_exec_context=True,
+        start_node=start_node,
+        end_node=end_node,
+    )
+    assert np.isclose(
+        subgraph_o_dict[end_node.output[0]], output_dict[end_node.output[0]], atol=1e-3
+    ).all()
 
 
 def test_onnx_exec_internal_rounding():
     inp0 = onnx.helper.make_tensor_value_info("inp0", onnx.TensorProto.FLOAT, [2, 2])
     inp1 = onnx.helper.make_tensor_value_info("inp1", onnx.TensorProto.FLOAT, [1])
     outp = onnx.helper.make_tensor_value_info("outp", onnx.TensorProto.FLOAT, [2, 2])
-    mul_node = onnx.helper.make_node("Mul", inputs=["inp0", "inp1"], outputs=["outp"],)
+    mul_node = onnx.helper.make_node("Mul", inputs=["inp0", "inp1"], outputs=["outp"])
     graph = onnx.helper.make_graph(
         nodes=[mul_node], name="mul_graph", inputs=[inp0, inp1], outputs=[outp]
     )