diff --git a/src/finn/onnx_exec.py b/src/finn/onnx_exec.py
index 65a5644824ef5b71cd07e26a44edd22ba03d1715..ffb1cccffe50c66aa10fb4cc8a664893fa152d4c 100644
--- a/src/finn/onnx_exec.py
+++ b/src/finn/onnx_exec.py
@@ -27,10 +27,9 @@
 import onnx
 import onnx.helper as helper
 import numpy as np
-import caffe2.python.onnx.backend as be
 from functools import reduce
 from onnx import numpy_helper as np_helper
-
+import onnxruntime as rt
 
 
 model = onnx.load_model("model.onnx")
@@ -42,18 +41,38 @@ def valueinfo_to_tensor(vi):
   dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
   return np.zeros(dims, dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type])
 
-def execute_node(node, context):
-  """Call Caffe2 to execute a single node. Input/output provided via context."""
+def execute_node(node, context, graph):
+  """Call onnxruntime to execute a single node. Input/output provided via context."""
 
+  # onnxruntime unfortunately does not implement run_node as defined by ONNX;
+  # it can only execute entire models -- so we create a model which solely
+  # consists of our current node.
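+  # intermediate tensors are looked up in graph.value_info, which is expected
+  # to be populated (e.g. by ONNX shape inference); top-level tensors live in
+  # graph.input/graph.output, so we gather from both when wrapping the node.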
+  node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
+  node_inputs += list(filter(lambda x: x.name in node.input, graph.value_info))
+  node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
+  node_outputs += list(filter(lambda x: x.name in node.output, graph.value_info))
+  node_graph = helper.make_graph(
+    nodes = [node],
+    name = "single-node-exec",
+    inputs = node_inputs,
+    outputs = node_outputs
+  )
+  node_model = helper.make_model(node_graph)
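+  # gather this node's input tensors from the execution context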
   input_dict = dict()
   for inp in node.input:
     input_dict[inp] = context[inp]
     print("Input shape for %s: %s" % (inp, context[inp].shape))
-  output_dict = be.run_node(node, input_dict)
-  for outp in node.output:
-    if output_dict[outp].shape != context[outp].shape:
-      raise Exception("Output shapes disagree after node execution: found %s vs expected %s" % (str(output_dict[outp].shape), str(context[outp].shape)))
-    context[outp] = output_dict[outp]
+  sess = rt.InferenceSession(node_model.SerializeToString())
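+  # run the single-node model; requesting the outputs by name (instead of
+  # passing None) guarantees the result list is ordered like node.output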
+  output_list = sess.run(list(node.output), input_dict)
+  for output_ind, outp in enumerate(node.output):
+    if output_list[output_ind].shape != context[outp].shape:
+      raise Exception("Output shapes disagree after node execution: found %s vs expected %s" % (str(output_list[output_ind].shape), str(context[outp].shape)))
+    context[outp] = output_list[output_ind]
     print("Output shape for %s: %s" % (outp, context[outp].shape))
 
 # first, we need to make sure that every variable required by the graph has
@@ -89,7 +108,7 @@ for node in graph.node:
   print("Input(s): " + str(node.input))
   print("Output(s): " + str(node.output))
   print("Attribute(s): " + str(node.attribute))
-  execute_node(node, execution_context)
+  execute_node(node, execution_context, graph)
 
 print("Final output(s): ")
 print(execution_context[graph.output[0].name])