diff --git a/src/finn/core/execute_custom_node.py b/src/finn/core/execute_custom_node.py
index e78e9d5de34e013fdbe43aae70c37955b95532b9..1923a64237c3277b30bb0bbc8276858203e9ef28 100644
--- a/src/finn/core/execute_custom_node.py
+++ b/src/finn/core/execute_custom_node.py
@@ -1,31 +1,22 @@
 # import onnx.helper as helper
 
-import finn.core.multithreshold as multiThresh
-from finn.core.utils import get_by_name
+import finn.custom_op.registry as registry
 
 
 def execute_custom_node(node, context, graph):
     """Call custom implementation to execute a single custom node.
-    Input/output provided via context."""
-
-    if node.op_type == "MultiThreshold":
-        # save inputs
-        v = context[node.input[0]]
-        thresholds = context[node.input[1]]
-        # retrieve attributes if output scaling is used
-        try:
-            out_scale = get_by_name(node.attribute, "out_scale").f
-        except AttributeError:
-            out_scale = None
-        try:
-            out_bias = get_by_name(node.attribute, "out_bias").f
-        except AttributeError:
-            out_bias = None
-        # calculate output
-        output = multiThresh.execute(v, thresholds, out_scale, out_bias)
-        # setting context according to output
-        context[node.output[0]] = output
-
-    else:
+    Input/output provided via context.
+
+    The CustomOp class registered for node.op_type in registry.custom_op is
+    instantiated and its execute_node(node, context, graph) is called.
+    Raises Exception if node.op_type has no entry in the registry."""
+    op_type = node.op_type
+    try:
+        # lookup op_type in registry of CustomOps; only the lookup is guarded
+        # so that a KeyError raised while executing the node is not
+        # misreported as an unsupported op_type
+        op_class = registry.custom_op[op_type]
+    except KeyError:
         # exception if op_type is not supported
-        raise Exception("This custom node is currently not supported.")
+        raise Exception("Custom op_type %s is currently not supported." % op_type)
+    op_class().execute_node(node, context, graph)
diff --git a/src/finn/core/multithreshold.py b/src/finn/core/multithreshold.py
deleted file mode 100755
index de38971502f1930dbe2090f3d88aad62e209478a..0000000000000000000000000000000000000000
--- a/src/finn/core/multithreshold.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import numpy as np
-
-
-def compare(x, y):
-    if x >= y:
-        return 1.0
-    else:
-        return 0.0
-
-
-def execute(v, thresholds, out_scale=None, out_bias=None):
-
-    # the inputs are expected to be in the shape (N,C,H,W)
-    # N : Batch size
-    # C : Number of channels
-    # H : Heigth of the input images
-    # W : Width of the input images
-    #
-    # the thresholds are expected to be in the shape (C, B)
-    # C : Number of channels (must be the same value as C in input tensor or 1
-    #     if all channels use the same threshold value)
-    # B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4)
-
-    # the output tensor will be scaled by out_scale and biased by out_bias
-
-    # assert threshold shape
-    is_global_threshold = thresholds.shape[0] == 1
-    assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold
-
-    # save the required shape sizes for the loops (N, C and B)
-    num_batch = v.shape[0]
-    num_channel = v.shape[1]
-
-    num_act = thresholds.shape[1]
-
-    # reshape inputs to enable channel-wise reading
-    vr = v.reshape((v.shape[0], v.shape[1], -1))
-
-    # save the new shape size of the images
-    num_img_elem = vr.shape[2]
-
-    # initiate output tensor
-    ret = np.zeros_like(vr)
-
-    # iterate over thresholds channel-wise
-    for t in range(num_channel):
-        channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
-
-        # iterate over batches
-        for b in range(num_batch):
-
-            # iterate over image elements on which the thresholds should be applied
-            for elem in range(num_img_elem):
-
-                # iterate over the different thresholds that correspond to one channel
-                for a in range(num_act):
-                    # apply successive thresholding to every element of one channel
-                    ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
-    if out_scale is None:
-        out_scale = 1.0
-    if out_bias is None:
-        out_bias = 0.0
-    return out_scale * ret.reshape(v.shape) + out_bias
diff --git a/src/finn/custom_op/__init__.py b/src/finn/custom_op/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd7024affedd15eb564b4e588a2b79343e95d371
--- /dev/null
+++ b/src/finn/custom_op/__init__.py
@@ -0,0 +1,29 @@
+from abc import ABC, abstractmethod
+
+
+class CustomOp(ABC):
+    """Abstract base class for custom op implementations.
+
+    A subclass can only be instantiated once it overrides all three
+    abstract methods below."""
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def make_shape_compatible_op(self, node):
+        """Build the stand-in op used for shape inference of node.
+
+        The base implementation is empty and returns None."""
+
+    @abstractmethod
+    def infer_node_datatype(self, node, model):
+        """Infer the output datatype(s) of node within model.
+
+        The base implementation is empty and returns None."""
+
+    @abstractmethod
+    def execute_node(self, node, context, graph):
+        """Execute node, given the caller's context and graph.
+
+        The base implementation is empty and returns None."""
diff --git a/src/finn/custom_op/multithreshold.py b/src/finn/custom_op/multithreshold.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec8f1ad4ceded8a40e182af5d79ccd803e0b8c33
--- /dev/null
+++ b/src/finn/custom_op/multithreshold.py
@@ -0,0 +1,105 @@
+import numpy as np
+import onnx.helper as helper
+
+from finn.core.datatype import DataType
+from finn.core.utils import get_by_name
+from finn.custom_op import CustomOp
+
+
+class MultiThreshold(CustomOp):
+    """CustomOp that counts, per input element, how many of its channel's
+    thresholds the element reaches, with optional output scale and bias."""
+
+    def make_shape_compatible_op(self, node):
+        """Return a Relu node wired to the first input and first output of
+        node."""
+        return helper.make_node("Relu", [node.input[0]], [node.output[0]])
+
+    def infer_node_datatype(self, node, model):
+        """Set the datatype of the first output of node on model.
+
+        The out_dtype attribute is used if it can be read; otherwise the
+        datatype is derived from the number of thresholds."""
+        try:
+            odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
+        except AttributeError:
+            # out_dtype could not be read (presumably the attribute is not
+            # set); only this lookup is guarded so that an AttributeError
+            # raised while setting the datatype is not silently swallowed
+            odt = None
+        if odt is not None:
+            model.set_tensor_datatype(node.output[0], DataType[odt])
+        else:
+            # number of thresholds decides # output bits
+            # use get_smallest_possible, assuming unsigned
+            n_thres = model.get_tensor_shape(node.input[1])[1]
+            odtype = DataType.get_smallest_possible(n_thres)
+            model.set_tensor_datatype(node.output[0], odtype)
+
+    def execute_node(self, node, context, graph):
+        """Run the node on the tensors in context.
+
+        Reads the input tensor and the thresholds from context (first and
+        second input of node) and writes the result to context under the
+        name of the first output of node. graph is not used."""
+        # save inputs
+        v = context[node.input[0]]
+        thresholds = context[node.input[1]]
+        # retrieve attributes if output scaling is used
+        try:
+            out_scale = get_by_name(node.attribute, "out_scale").f
+        except AttributeError:
+            out_scale = None
+        try:
+            out_bias = get_by_name(node.attribute, "out_bias").f
+        except AttributeError:
+            out_bias = None
+        # calculate output
+        output = self._execute(v, thresholds, out_scale, out_bias)
+        # setting context according to output
+        context[node.output[0]] = output
+
+    def _compare(self, x, y):
+        """Return 1.0 if x >= y, otherwise 0.0."""
+        return 1.0 if x >= y else 0.0
+
+    def _execute(self, v, thresholds, out_scale=None, out_bias=None):
+        """Apply the thresholds to v and return the scaled, biased result.
+
+        the inputs are expected to be in the shape (N,C,H,W)
+        N : Batch size
+        C : Number of channels
+        H : Height of the input images
+        W : Width of the input images
+
+        the thresholds are expected to be in the shape (C, B)
+        C : Number of channels (must be the same value as C in input tensor
+            or 1 if all channels use the same threshold value)
+        B : Desired activation steps => i.e. for 4-bit activation,
+            B=7 (2^(n)-1 and n=4)
+
+        the output tensor has the shape of v; each element is the number of
+        thresholds of its channel that the input element reaches (>=),
+        multiplied by out_scale (default 1.0) plus out_bias (default 0.0)"""
+        # assert threshold shape
+        is_global_threshold = thresholds.shape[0] == 1
+        assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold
+        # number of activation steps (B)
+        num_act = thresholds.shape[1]
+        # reshape inputs to enable channel-wise reading
+        vr = v.reshape((v.shape[0], v.shape[1], -1))
+        # initiate output tensor
+        ret = np.zeros_like(vr)
+        # apply successive thresholding: each pass compares the whole tensor
+        # against one threshold per channel instead of looping over every
+        # element in Python. The threshold column is reshaped to
+        # (1, C or 1, 1) so that numpy broadcasts it over batches and image
+        # elements (and over channels when a single threshold row is shared)
+        for a in range(num_act):
+            step_thresh = thresholds[:, a].reshape((1, -1, 1))
+            ret += vr >= step_thresh
+        if out_scale is None:
+            out_scale = 1.0
+        if out_bias is None:
+            out_bias = 0.0
+        return out_scale * ret.reshape(v.shape) + out_bias
diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..07bdbf1cab6848dc25b7bb556000829617620747
--- /dev/null
+++ b/src/finn/custom_op/registry.py
@@ -0,0 +1,9 @@
+# make sure new CustomOp subclasses are imported here so that they get
+# registered and plug in correctly into the infrastructure
+from finn.custom_op.multithreshold import MultiThreshold
+
+# mapping from a node's op_type string to the CustomOp subclass handling it;
+# add one entry per line for every new CustomOp
+custom_op = {
+    "MultiThreshold": MultiThreshold,
+}
diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 2c6d75a1d029fbfea3d5b408cc6e8eb5f70dc7b5..60ea43c97d0298969ab4b3a280ed9bd3f62cbab8 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -1,5 +1,5 @@
+import finn.custom_op.registry as registry
 from finn.core.datatype import DataType
-from finn.core.utils import get_by_name
 from finn.transformation import Transformation
 
 
@@ -8,36 +8,39 @@ def _infer_node_datatype(model, node):
     changes were made."""
     idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
     odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
-    if node.op_type == "MultiThreshold":
+    op_type = node.op_type
+    if node.domain == "finn":
+        # handle DataType inference for CustomOp
         try:
-            odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
-            model.set_tensor_datatype(node.output[0], DataType[odt])
-        except AttributeError:
-            # number of thresholds decides # output bits
-            # use get_smallest_possible, assuming unsigned
-            n_thres = model.get_tensor_shape(node.input[1])[1]
-            odtype = DataType.get_smallest_possible(n_thres)
-            model.set_tensor_datatype(node.output[0], odtype)
+            # lookup op_type in registry of CustomOps; only the lookup is
+            # guarded so that a KeyError raised inside the op's own
+            # infer_node_datatype is not misreported as an unsupported
+            # op_type
+            op_class = registry.custom_op[op_type]
+        except KeyError:
+            # exception if op_type is not supported
+            raise Exception("Custom op_type %s is currently not supported." % op_type)
+        op_class().infer_node_datatype(node, model)
     elif node.op_type == "Sign":
         # always produces bipolar outputs
         model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
     elif node.op_type == "MatMul":
         if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
             # node has at least one float input, output is also float
             model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
         else:
             # TODO compute minimum / maximum result to minimize bitwidth
             # use (u)int32 accumulators for now
             has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
             if has_signed_inp:
                 odtype = DataType.INT32
             else:
                 odtype = DataType.UINT32
             model.set_tensor_datatype(node.output[0], odtype)
     else:
         # unknown, assume node produces float32 outputs
         for o in node.output:
             model.set_tensor_datatype(o, DataType.FLOAT32)
     # compare old and new output dtypes to see if anything changed
     new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     graph_modified = new_odtypes != odtypes
diff --git a/src/finn/transformation/infer_shapes.py b/src/finn/transformation/infer_shapes.py
index 3e79587d4d6b9d021ddd11b3edb11b42e73c1297..2772ae38524f8339d7b66825ff05656aae4f606d 100644
--- a/src/finn/transformation/infer_shapes.py
+++ b/src/finn/transformation/infer_shapes.py
@@ -1,6 +1,6 @@
-import onnx.helper as helper
 import onnx.shape_inference as si
 
+import finn.custom_op.registry as registry
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation import Transformation
 
@@ -9,10 +9,16 @@ def _make_shape_compatible_op(node):
     """Return a shape-compatible non-FINN op for a given FINN op. Used for
     shape inference with custom ops."""
     assert node.domain == "finn"
-    if node.op_type == "MultiThreshold":
-        return helper.make_node("Relu", [node.input[0]], [node.output[0]])
-    else:
-        raise Exception("No known shape-compatible op for %s" % node.op_type)
+    op_type = node.op_type
+    try:
+        # lookup op_type in registry of CustomOps; only the lookup is guarded
+        # so that a KeyError raised while building the shape-compatible op is
+        # not misreported as an unsupported op_type
+        op_class = registry.custom_op[op_type]
+    except KeyError:
+        # exception if op_type is not supported
+        raise Exception("Custom op_type %s is currently not supported." % op_type)
+    return op_class().make_shape_compatible_op(node)
 
 
 def _hide_finn_ops(model):
diff --git a/tests/test_multi_thresholding.py b/tests/test_multi_thresholding.py
index 5fd7b4309d68ad568773dbdb39d0df979a767e73..b16d1602a65284ba366e9c331370c15096981ece 100644
--- a/tests/test_multi_thresholding.py
+++ b/tests/test_multi_thresholding.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-import finn.core.multithreshold as multi_thresh
+from finn.custom_op.multithreshold import MultiThreshold
 
 
 def test_execute_multi_thresholding():
@@ -194,10 +194,11 @@ def test_execute_multi_thresholding():
         ),
     )
 
-    results = multi_thresh.execute(inputs, thresholds)
+    multi_thresh = MultiThreshold()
+    results = multi_thresh._execute(inputs, thresholds)
 
     assert (results == outputs).all()
 
-    results_scaled = multi_thresh.execute(inputs, thresholds, 2.0, -1.0)
+    results_scaled = multi_thresh._execute(inputs, thresholds, 2.0, -1.0)
     outputs_scaled = 2.0 * outputs - 1.0
     assert (results_scaled == outputs_scaled).all()