diff --git a/src/finn/core/execute_custom_node.py b/src/finn/core/execute_custom_node.py
index 678794a53b4cd3a0fbfbc658ac80c8e8f687ff05..1d7c0f6b28be6dde803ccb6d0370eb8c69154e6a 100644
--- a/src/finn/core/execute_custom_node.py
+++ b/src/finn/core/execute_custom_node.py
@@ -1,42 +1,23 @@
-#import onnx.helper as helper
+# import onnx.helper as helper
 
 import finn.core.MultiThreshold as multiThresh
 
-def execute_custom_node(node, context, graph) :
-    """Call custom implementation to execute a single custom node. Input/output provided via context."""
-    
-    if node.op_type == 'MultiThreshold' :
-        node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
-        
-        # extract shape size of input tensors to determine which is input and which thresholds
-        shape_dict = {}
-        for inputs in node_inputs :
-            shape_dict[inputs.name]=0
-            for dim_value in inputs.type.tensor_type.shape.dim :
-                shape_dict[inputs.name] += 1
-        
-        # store input values in right tensors according to the shape size
-        for inputs in node_inputs :
-            if shape_dict[inputs.name] == 4 :
-                v = context[inputs.name]
-            else :
-                thresholds = context[inputs.name]
-        
-        output_list = multiThresh.execute(v, thresholds) 
-        
-        for output_ind in node.output:
-            print(output_ind)
-            #outp = node.output[output_ind]
-            #if output_list[output_ind].shape != context[outp].shape:
-            #    raise Exception(
-            #        "Output shapes disagree after node execution: found %s vs expected %s"
-            #        % (str(output_list[output_ind].shape.shape), str(context[outp].shape))
-            #    )
-            #context[outp] = output_list[output_ind]
-
-
-    else :
-        raise Exception(
-                "This custom node is currently not supported."
-        )
 
+def execute_custom_node(node, context, graph):
+    """Call custom implementation to execute a single custom node.
+    Input/output provided via context."""
+
+    if node.op_type == "MultiThreshold":
+
+        v = context[node.input[0]]
+        thresholds = context[node.input[1]]
+
+        # calculate output
+        output = multiThresh.execute(v, thresholds)
+
+        # setting context according to output
+        context[node.output[0]] = output
+
+    else:
+        # exception if op_type is not supported
+        raise Exception("This custom node is currently not supported.")
diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index 8afced5c8f117e01253423f952201ce14124456e..0e4f0237d64331a02cf6b85d5d0ea07645957abe 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -29,41 +29,53 @@ import copy
 import onnx.helper as helper
 import onnxruntime as rt
 
+import finn.core.execute_custom_node as ex_cu_node
+
 
 def execute_node(node, context, graph):
     """Call onnxruntime to execute a single node. Input/output provided via context."""
 
-    # onnxruntime unfortunately does not implement run_node as defined by ONNX,
-    # it can only execute entire models -- so we create a model which solely
-    # consists of our current node.
-    node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
-    node_inputs += list(filter(lambda x: x.name in node.input, graph.value_info))
-    node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
-    node_outputs += list(filter(lambda x: x.name in node.output, graph.value_info))
-    node_graph = helper.make_graph(
-        nodes=[node], name="single-node-exec", inputs=node_inputs, outputs=node_outputs
-    )
-    node_model = helper.make_model(node_graph)
-    input_dict = dict()
-    for inp in node.input:
-        input_dict[inp] = context[inp]
+    # run node with custom function or by using onnxruntime
+
     if node.domain == "finn":
-        print("Domain is finn")
-        #execute_costum_node(node, context, graph)
+
+        ex_cu_node.execute_custom_node(node, context, graph)
 
     else:
-        print("Domain is empty")
+
+        # onnxruntime unfortunately does not implement run_node as defined by ONNX,
+        # it can only execute entire models -- so we create a model which solely
+        # consists of our current node.
+        node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
+        node_inputs += list(filter(lambda x: x.name in node.input, graph.value_info))
+        node_outputs = list(filter(lambda x: x.name in node.output, graph.output))
+        node_outputs += list(filter(lambda x: x.name in node.output, graph.value_info))
+        node_graph = helper.make_graph(
+            nodes=[node],
+            name="single-node-exec",
+            inputs=node_inputs,
+            outputs=node_outputs,
+        )
+        node_model = helper.make_model(node_graph)
+        input_dict = dict()
+        for inp in node.input:
+            input_dict[inp] = context[inp]
+
         sess = rt.InferenceSession(node_model.SerializeToString())
         output_list = sess.run(None, input_dict)
 
-    for output_ind in range(len(node.output)):
-        outp = node.output[output_ind]
-        if output_list[output_ind].shape != context[outp].shape:
-            raise Exception(
-                "Output shapes disagree after node execution: found %s vs expected %s"
-                % (str(output_list[output_ind].shape.shape), str(context[outp].shape))
-            )
-        context[outp] = output_list[output_ind]
+        for output_ind in range(len(node.output)):
+            outp = node.output[output_ind]
+            if output_list[output_ind].shape != context[outp].shape:
+                raise Exception(
+                    """Output shapes disagree after node execution:
+                    found %s vs expected %s"""
+                    % (
+                        str(output_list[output_ind].shape.shape),
+                        str(context[outp].shape),
+                    )
+                )
+            context[outp] = output_list[output_ind]
 
 
 def execute_onnx(model, input_dict, return_full_exec_context=False):
diff --git a/tests/test_custom_onnx_exec.py b/tests/test_custom_onnx_exec.py
index 828e10d1d4a00cbf3970c15ca22d2e0cd85ff9dd..f355653311421f41ca5d1c7687d3571b14e6264c 100644
--- a/tests/test_custom_onnx_exec.py
+++ b/tests/test_custom_onnx_exec.py
@@ -1,13 +1,12 @@
 import numpy as np
-
-import onnx
-from onnx import helper
-from onnx import AttributeProto, TensorProto, GraphProto
+# import onnx
+# from onnx import AttributeProto, GraphProto, TensorProto, helper
+from onnx import TensorProto, helper
 
 import finn.core.execute_custom_node as ex_cu_node
 
 
-def test_execute_custom_node() :
+def test_execute_custom_node():
     inputs = np.ndarray(
         shape=(6, 3, 2, 2),
         buffer=np.array(
@@ -117,31 +116,100 @@ def test_execute_custom_node() :
         ),
     )
 
-    v = helper.make_tensor_value_info('v', TensorProto.FLOAT, [6, 3, 2, 2])
-    thresholds = helper.make_tensor_value_info('thresholds', TensorProto.FLOAT, [3, 7])
-    out = helper.make_tensor_value_info('out', TensorProto.FLOAT, [6, 3, 2, 2])
+    v = helper.make_tensor_value_info("v", TensorProto.FLOAT, [6, 3, 2, 2])
+    thresholds = helper.make_tensor_value_info("thresholds", TensorProto.FLOAT, [3, 7])
+    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [6, 3, 2, 2])
 
     node_def = helper.make_node(
-            'MultiThreshold',
-            ['v', 'thresholds'],
-            ['out'],
-            domain='finn'
-            )
-
+        "MultiThreshold", ["v", "thresholds"], ["out"], domain="finn"
+    )
 
-    graph_def = helper.make_graph(
-        [node_def],
-        "test_model",
-        [v, thresholds],
-        [out]
-        )
-    
-    model = helper.make_model(graph_def, producer_name='onnx-example')
+    graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out])
 
     execution_context = {}
-    execution_context['v'] = inputs
-    execution_context['thresholds'] = threshold_values
+    execution_context["v"] = inputs
+    execution_context["thresholds"] = threshold_values
 
-    print(ex_cu_node.execute_custom_node(node_def, execution_context, graph_def))
+    ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
 
+    outputs = np.ndarray(
+        shape=(6, 3, 2, 2),
+        buffer=np.array(
+            [
+                4.0,
+                3.0,
+                1.0,
+                4.0,
+                5.0,
+                2.0,
+                2.0,
+                4.0,
+                3.0,
+                3.0,
+                3.0,
+                1.0,
+                5.0,
+                0.0,
+                1.0,
+                4.0,
+                1.0,
+                4.0,
+                6.0,
+                7.0,
+                7.0,
+                1.0,
+                1.0,
+                3.0,
+                3.0,
+                3.0,
+                1.0,
+                3.0,
+                4.0,
+                2.0,
+                3.0,
+                7.0,
+                3.0,
+                3.0,
+                1.0,
+                1.0,
+                7.0,
+                5.0,
+                4.0,
+                6.0,
+                2.0,
+                2.0,
+                1.0,
+                1.0,
+                2.0,
+                1.0,
+                3.0,
+                3.0,
+                2.0,
+                5.0,
+                3.0,
+                3.0,
+                4.0,
+                5.0,
+                7.0,
+                3.0,
+                1.0,
+                3.0,
+                2.0,
+                1.0,
+                4.0,
+                6.0,
+                6.0,
+                0.0,
+                1.0,
+                1.0,
+                3.0,
+                6.0,
+                1.0,
+                1.0,
+                6.0,
+                7.0,
+            ]
+        ),
+    )
 
+    assert (execution_context["out"] == outputs).all()