diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 2c6d75a1d029fbfea3d5b408cc6e8eb5f70dc7b5..60ea43c97d0298969ab4b3a280ed9bd3f62cbab8 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -1,5 +1,5 @@
+import finn.custom_op.registry as registry
 from finn.core.datatype import DataType
-from finn.core.utils import get_by_name
 from finn.transformation import Transformation
 
 
@@ -8,36 +8,39 @@ def _infer_node_datatype(model, node):
     changes were made."""
     idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
     odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
-    if node.op_type == "MultiThreshold":
+    op_type = node.op_type
+    if node.domain == "finn":
+        # handle DataType inference for CustomOp
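+        # registry.custom_op maps op_type strings to CustomOp implementations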
         try:
-            odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
-            model.set_tensor_datatype(node.output[0], DataType[odt])
-        except AttributeError:
-            # number of thresholds decides # output bits
-            # use get_smallest_possible, assuming unsigned
-            n_thres = model.get_tensor_shape(node.input[1])[1]
-            odtype = DataType.get_smallest_possible(n_thres)
-            model.set_tensor_datatype(node.output[0], odtype)
-    elif node.op_type == "Sign":
-        # always produces bipolar outputs
-        model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
-    elif node.op_type == "MatMul":
-        if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
-            # node has at least one float input, output is also float
-            model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
-        else:
-            # TODO compute minimum / maximum result to minimize bitwidth
-            # use (u)int32 accumulators for now
-            has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
-            if has_signed_inp:
-                odtype = DataType.INT32
-            else:
-                odtype = DataType.UINT32
-            model.set_tensor_datatype(node.output[0], odtype)
+            # lookup op_type in registry of CustomOps
+            inst = registry.custom_op[op_type]()
+        except KeyError:
+            # op_type has no registered CustomOp implementation
+            raise Exception("Custom op_type %s is currently not supported." % op_type)
+        inst.infer_node_datatype(node, model)
     else:
-        # unknown, assume node produces float32 outputs
-        for o in node.output:
-            model.set_tensor_datatype(o, DataType.FLOAT32)
+        if op_type == "Sign":
+            # always produces bipolar outputs
+            model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
+        elif op_type == "MatMul":
+            if any(x == DataType.FLOAT32 for x in idtypes):
+                # node has at least one float input, output is also float
+                model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
+            else:
+                # TODO compute minimum / maximum result to minimize bitwidth
+                # use (u)int32 accumulators for now
+                has_signed_inp = any(x.signed() for x in idtypes)
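+                # a signed input can yield negative products, so use a signed accumulator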
+                if has_signed_inp:
+                    odtype = DataType.INT32
+                else:
+                    odtype = DataType.UINT32
+                model.set_tensor_datatype(node.output[0], odtype)
+        else:
+            # unknown, assume node produces float32 outputs
+            for o in node.output:
+                model.set_tensor_datatype(o, DataType.FLOAT32)
     # compare old and new output dtypes to see if anything changed
     new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     graph_modified = new_odtypes != odtypes