diff --git a/src/finn/transformation/convert_qonnx_to_finn.py b/src/finn/transformation/convert_qonnx_to_finn.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b35af04cf4c94c67cde9ca6cb205c8a0a6b62f0
--- /dev/null
+++ b/src/finn/transformation/convert_qonnx_to_finn.py
@@ -0,0 +1,253 @@
+import numpy as np
+from onnx import TensorProto, helper
+
+from finn.custom_op.registry import getCustomOp
+from finn.transformation.base import Transformation
+
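+# Successor op types after which a Quant node can be treated as a standalone
+# activation quantizer; None means the Quant output is a graph output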
+allowed_identity_successors = [
+    "MatMul",
+    "Conv",
+    "MaxPool",
+    "Reshape",
+    None,
+]
+
+
+class ConvertQuantActToMultiThreshold(Transformation):
+    """Converts Quant nodes in the activation path to MultiThreshold nodes.
+
+    Each Quant node is replaced by the functionally equivalent chain
+    MultiThreshold -> Add (bias) -> Mul (scale).
+    """
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Quant":
+                running_node_index = node_ind
+                # Check that the node is in the activation path: neither input
+                # nor output may have an initializer, otherwise this Quant node
+                # quantizes a constant (e.g. a weight) instead of an activation
+                inp = model.get_initializer(n.input[0])
+                out = model.get_initializer(n.output[0])
+                if not (inp is None and out is None):
+                    continue
+                successor = model.find_direct_successors(n)
+                # successor is None when the Quant output is a graph output
+                successor_op_type = None
+                if successor is not None:
+                    successor = successor[0]
+                    successor_op_type = successor.op_type
+                if model.is_fork_node(n):
+                    raise RuntimeError(
+                        "Forking Quant nodes are not currently supported by FINN."
+                    )
+
+                # ToDo: Check for activation functions behind (or in front of)
+                #  the Quant node, such as ReLU
+
+                # Check that this is an identity operation
+                if successor_op_type in allowed_identity_successors:
+                    # Compute thresholds, bias and scale for the new nodes
+                    dtype = model.get_tensor_datatype(n.output[0]).name
+                    if "SCALED" in dtype:
+                        dtype = dtype.replace("SCALED", "")
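+                    # stripping the SCALED prefix leaves the integer dtype that
+                    # the MultiThreshold output carries before bias and scale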
+                    # Treating Quant node as Quant identity for now
+                    q_inst = getCustomOp(n)
+                    # Get parameters
+                    quant_scale = model.get_initializer(n.input[1])
+                    bit_width = model.get_initializer(n.input[3])
+                    narrow = q_inst.get_nodeattr("narrow")
+                    # ToDo: zero_pt and signed should be taken into account, or
+                    #  at least be checked for the supported range and values
+                    # zero_pt = model.get_initializer(n.input[2])
+                    # signed = q_inst.get_nodeattr("signed")
+
+                    # Calculate thresholds, see: https://github.com/Xilinx/brevitas/
+                    # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
+                    # export/onnx/finn/handler/act.py#L76
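+                    # narrow range quantization uses one level less, e.g.
+                    # [-127, 127] instead of [-128, 127] for signed 8 bits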
+                    if narrow:
+                        num_distinct_values = 2 ** bit_width - 1
+                    else:
+                        num_distinct_values = 2 ** bit_width
+
+                    num_thresholds = int(num_distinct_values - 1)
+                    flat_scale = quant_scale.flatten()
+                    num_scale_channels = flat_scale.shape[0]
+                    step = np.abs(flat_scale)
+                    half_step = step / 2.0
+                    thresholds = np.empty((num_scale_channels, num_thresholds))
+                    # compute the value of the smallest threshold; the remaining
+                    # thresholds are equidistant steps above it
+                    min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
+                    if not narrow:
+                        min_threshold -= step
+                    for c in range(num_scale_channels):
+                        for t in range(num_thresholds):
+                            thresholds[c, t] = min_threshold[c] + step[c] * t
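+                    # e.g. (hypothetical values) bit_width=3, narrow=False,
+                    # scale=0.5: num_thresholds=7, step=0.5, min_threshold=-1.75,
+                    # thresholds=[-1.75, -1.25, ..., 1.25]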
+
+                    # ToDo: The index 1 needs to be changed to -1 for the channels last
+                    #  format
+                    num_output_channels = model.get_tensor_shape(n.output[0])[1]
+                    final_shape = (num_output_channels, num_thresholds)
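+                    # a per-tensor scale produces a single threshold row, which
+                    # is broadcast so all output channels share the same thresholds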
+                    if thresholds.shape != final_shape:
+                        thresholds = np.broadcast_to(thresholds, final_shape)
+
+                    # Calculate bias, see: https://github.com/Xilinx/brevitas/blob/
+                    # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
+                    # onnx/finn/handler/act.py#L64
+                    if bit_width == 1:
+                        bias = np.array([-0.5])
+                    else:
+                        if narrow:
+                            min_non_scaled_val = -(2 ** (bit_width - 1) - 1)
+                        else:
+                            min_non_scaled_val = -(2 ** (bit_width - 1))
+                        bias = np.array([min_non_scaled_val])
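+                    # the bias shifts the unsigned MultiThreshold output into the
+                    # signed range, e.g. (hypothetical) bias=-4 for bit_width=3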
+
+                    # Calculate scale, see: https://github.com/Xilinx/brevitas/
+                    # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
+                    # export/onnx/finn/handler/act.py#L111
+                    if bit_width != 1:
+                        scale = quant_scale
+                    else:
+                        # ToDo: This needs testing or rewriting when the BinaryQuant op
+                        #  comes around
+                        assert (
+                            quant_scale.flatten().shape[0] == 1
+                        ), "Unsupported BIPOLAR per channel scale"
+                        assert (
+                            quant_scale.flatten().item() == 1.0
+                        ), "Unsupported BIPOLAR scale != 1"
+                        scale = quant_scale * 2
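+                        # with one threshold MultiThreshold emits {0, 1}; the
+                        # -0.5 bias and the doubled scale map this to {-1, +1},
+                        # i.e. the BIPOLAR encoding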
+
+                    # Modify graph
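+                    # The Quant node is replaced by the functionally equivalent
+                    # chain MultiThreshold -> Add (bias) -> Mul (scale)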
+                    # Insert threshold tensor
+                    thresh_tensor = helper.make_tensor_value_info(
+                        model.make_new_valueinfo_name(),
+                        TensorProto.FLOAT,
+                        final_shape,
+                    )
+                    graph.value_info.append(thresh_tensor)
+                    model.set_initializer(thresh_tensor.name, thresholds)
+
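+                    # MultiThreshold (finn.custom_op.general) counts the number
+                    # of crossed thresholds per element, y = sum_i(x >= T_i),
+                    # yielding integers in [0, num_thresholds]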
+                    # Insert MultiThreshold node
+                    act_mt_tensor = helper.make_tensor_value_info(
+                        model.make_new_valueinfo_name(),
+                        TensorProto.FLOAT,
+                        model.get_tensor_shape(n.output[0]),
+                    )
+                    graph.value_info.append(act_mt_tensor)
+                    outp_trans_node = helper.make_node(
+                        "MultiThreshold",
+                        [n.input[0], thresh_tensor.name],
+                        [act_mt_tensor.name],
+                        out_dtype=dtype,
+                        domain="finn.custom_op.general",
+                    )
+                    graph.node.insert(running_node_index, outp_trans_node)
+                    running_node_index += 1
+
+                    # Insert Add node
+                    if bias.shape == (1,):
+                        bias = bias[0]
+                        add_shape = tuple()
+                    else:
+                        add_shape = bias.shape
+                    add_tensor = helper.make_tensor_value_info(
+                        model.make_new_valueinfo_name(),
+                        TensorProto.FLOAT,
+                        add_shape,
+                    )
+                    graph.value_info.append(add_tensor)
+                    model.set_initializer(add_tensor.name, bias)
+
+                    output_shape = model.get_tensor_shape(n.output[0])
+                    act_add_tensor = helper.make_tensor_value_info(
+                        model.make_new_valueinfo_name(),
+                        TensorProto.FLOAT,
+                        output_shape,
+                    )
+                    graph.value_info.append(act_add_tensor)
+
+                    add_node = helper.make_node(
+                        "Add",
+                        [act_mt_tensor.name, add_tensor.name],
+                        [act_add_tensor.name],
+                    )
+                    graph.node.insert(running_node_index, add_node)
+                    running_node_index += 1
+
+                    # Insert Mul node
+                    if scale.shape == (1,):
+                        scale = scale[0]
+                        mul_shape = tuple()
+                    else:
+                        mul_shape = scale.shape
+                    mul_tensor = helper.make_tensor_value_info(
+                        model.make_new_valueinfo_name(),
+                        TensorProto.FLOAT,
+                        mul_shape,
+                    )
+                    graph.value_info.append(mul_tensor)
+                    model.set_initializer(mul_tensor.name, scale)
+
+                    # The Mul node writes to the original Quant output tensor,
+                    # so downstream consumers and graph outputs need no rewiring
+                    mul_node = helper.make_node(
+                        "Mul",
+                        [act_add_tensor.name, mul_tensor.name],
+                        [n.output[0]],
+                    )
+                    graph.node.insert(running_node_index, mul_node)
+                    running_node_index += 1
+
+                    # Now remove the Quant node
+                    graph.node.remove(n)
+
+                    # Return directly after each conversion; the transform()
+                    # harness re-applies this transformation until apply()
+                    # reports no further graph changes
+                    graph_modified = True
+                    return (model, graph_modified)
+                else:
+                    raise RuntimeError(
+                        "Quant nodes with successor nodes of type "
+                        f"{successor_op_type} are currently not supported by "
+                        "FINN and cannot be converted to MultiThreshold nodes."
+                    )
+
+        return (model, graph_modified)
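+
+
+# Usage sketch (illustrative model file name; ModelWrapper.transform() applies
+# this transformation repeatedly until apply() reports no further changes):
+#
+#   from finn.core.modelwrapper import ModelWrapper
+#   model = ModelWrapper("model_with_quant_act.onnx")
+#   model = model.transform(ConvertQuantActToMultiThreshold())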