diff --git a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
index 5b218f2c38592afff3b790395154454e563028bb..ae49d3cd21f805339deab8658aaa6a324a72ea98 100644
--- a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
+++ b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
@@ -52,8 +52,6 @@ class ConvertQONNXtoFINN(Transformation):
         model = model.transform(FoldQuantWeights())
         # Convert activations
         model = model.transform(ConvertQuantActToMultiThreshold())
-        # Infer types again
-        model = model.transform(InferDataTypes())
 
         # Unset FINN datatypes from MultiThreshold node output tensors to avoid warnings
         mt_nodes = model.get_nodes_by_op_type("MultiThreshold")
diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index d5c00c73dab68a479a1f0f7cce7e7395d4fb6bd4..26c65a4cad600029d452326896b041e71f9423e7 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -123,9 +123,6 @@ class QuantActBaseHandler(ABC):
         graph = model.graph
         n = self._q_node
         running_node_index = self._q_index
-        successor = model.find_direct_successors(n)
-        if successor is not None:
-            successor = successor[0]
 
         # Calculate insertion parameters
         parameter_dict = self.calculate_node_parameters()
@@ -155,69 +152,105 @@ class QuantActBaseHandler(ABC):
         graph.node.insert(running_node_index, outp_trans_node)
         running_node_index += 1
 
-        # Insert Add node
-        if adder_bias.shape == (1,):
-            adder_bias = adder_bias[0]
-            add_shape = tuple()
+        # Get the MultiThreshold node instance to work with
+        mt_inst = getCustomOp(graph.node[running_node_index - 1])
+
+        # Set scale and bias
+        # If these values are scalar then they can be set as attributes
+        # of the MultiThreshold node, if not they get inserted as adder and mul nodes
+        # behind the MultiThreshold nodes.
+        scale_compatible = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
+        bias_compatible = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
+        if scale_compatible and bias_compatible:
+            # Get Quant parameters
+            mul_scale = np.atleast_1d(mul_scale)
+            # ONNX only accepts 64bit floats as attributes
+            mul_scale = mul_scale.astype(dtype=np.float64)
+            adder_bias = np.atleast_1d(adder_bias)
+            adder_bias = adder_bias.astype(dtype=np.float64)
+
+            # Set Bias and scale
+            mt_inst.set_nodeattr("out_scale", mul_scale[0])
+            # FINN applies scale first then bias,
+            # which is the other way around in Brevitas,
+            # we thus need to adjust the bias in the MultiThreshold node
+            mt_inst.set_nodeattr("out_bias", adder_bias[0] * mul_scale[0])
         else:
-            add_shape = adder_bias.shape
-        add_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            add_shape,
-        )
-        graph.value_info.append(add_tensor)
-        model.set_initializer(add_tensor.name, adder_bias)
-
-        output_shape = model.get_tensor_shape(n.output[0])
-        act_add_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            output_shape,
-        )
-        graph.value_info.append(act_add_tensor)
-        if successor is not None:
-            successor.input[0] = act_add_tensor.name
-
-        add_node = helper.make_node(
-            "Add",
-            [n.output[0], add_tensor.name],
-            [act_add_tensor.name],
-        )
-        graph.node.insert(running_node_index, add_node)
-        running_node_index += 1
-
-        # Insert Mul node
-        if mul_scale.shape == (1,):
-            mul_scale = mul_scale[0]
-            mul_shape = tuple()
-        else:
-            mul_shape = mul_scale.shape
-        mul_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            mul_shape,
-        )
-        graph.value_info.append(mul_tensor)
-        model.set_initializer(mul_tensor.name, mul_scale)
-
-        output_shape = model.get_tensor_shape(n.output[0])
-        act_mul_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            output_shape,
-        )
-        graph.value_info.append(act_mul_tensor)
-        if successor is not None:
-            successor.input[0] = act_mul_tensor.name
-
-        mul_node = helper.make_node(
-            "Mul",
-            [act_add_tensor.name, mul_tensor.name],
-            [act_mul_tensor.name],
-        )
-        graph.node.insert(running_node_index, mul_node)
-        running_node_index += 1
+            if bias_compatible:
+                adder_bias = np.atleast_1d(adder_bias)
+                # ONNX only accepts 64bit floats as attributes
+                adder_bias = adder_bias.astype(dtype=np.float64)[0]
+                add_shape = tuple()
+            else:
+                add_shape = adder_bias.shape
+
+            in_tensor = n.output[0]
+            successor_node = model.find_direct_successors(n)
+            if successor_node is not None:
+                successor_node = successor_node[0]
+            # Insert Add node
+            add_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                add_shape,
+            )
+            graph.value_info.append(add_tensor)
+            model.set_initializer(add_tensor.name, adder_bias)
+
+            output_shape = model.get_tensor_shape(n.output[0])
+            act_add_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                output_shape,
+            )
+            graph.value_info.append(act_add_tensor)
+            if successor_node is not None:
+                successor_node.input[0] = act_add_tensor.name
+
+            add_node = helper.make_node(
+                "Add",
+                [in_tensor, add_tensor.name],
+                [act_add_tensor.name],
+            )
+            graph.node.insert(running_node_index, add_node)
+            running_node_index += 1
+
+            # Re-point the input node for the next node to insert
+            in_tensor = act_add_tensor.name
+
+            # Set scale
+            # Insert Mul node
+            if mul_scale.shape == (1,) or len(mul_scale.shape) == 0:
+                mul_scale = np.atleast_1d(mul_scale)
+                mul_scale = mul_scale.astype(dtype=np.float64)[0]
+                mul_shape = tuple()
+            else:
+                mul_shape = mul_scale.shape
+            mul_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                mul_shape,
+            )
+            graph.value_info.append(mul_tensor)
+            model.set_initializer(mul_tensor.name, mul_scale)
+
+            output_shape = model.get_tensor_shape(n.output[0])
+            act_mul_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                output_shape,
+            )
+            graph.value_info.append(act_mul_tensor)
+            if successor_node is not None:
+                successor_node.input[0] = act_mul_tensor.name
+
+            mul_node = helper.make_node(
+                "Mul",
+                [in_tensor, mul_tensor.name],
+                [act_mul_tensor.name],
+            )
+            graph.node.insert(running_node_index, mul_node)
+            running_node_index += 1
 
         # Remove activation node
         self._remove_activation_node()