diff --git a/src/finn/core/execute_custom_node.py b/src/finn/core/execute_custom_node.py
index 62d8b7fc1ba0ce0d6c83c2b56ef1d75613be974b..e78e9d5de34e013fdbe43aae70c37955b95532b9 100644
--- a/src/finn/core/execute_custom_node.py
+++ b/src/finn/core/execute_custom_node.py
@@ -1,6 +1,7 @@
 # import onnx.helper as helper
 
 import finn.core.multithreshold as multiThresh
+from finn.core.utils import get_by_name
 
 
 def execute_custom_node(node, context, graph):
@@ -11,8 +12,17 @@ def execute_custom_node(node, context, graph):
         # save inputs
         v = context[node.input[0]]
         thresholds = context[node.input[1]]
+        # retrieve attributes if output scaling is used
+        try:
+            out_scale = get_by_name(node.attribute, "out_scale").f
+        except AttributeError:
+            out_scale = None
+        try:
+            out_bias = get_by_name(node.attribute, "out_bias").f
+        except AttributeError:
+            out_bias = None
         # calculate output
-        output = multiThresh.execute(v, thresholds)
+        output = multiThresh.execute(v, thresholds, out_scale, out_bias)
         # setting context according to output
         context[node.output[0]] = output
 
diff --git a/src/finn/core/multithreshold.py b/src/finn/core/multithreshold.py
index 009259c577879a8aa09ac44ace704af55ca2593d..de38971502f1930dbe2090f3d88aad62e209478a 100755
--- a/src/finn/core/multithreshold.py
+++ b/src/finn/core/multithreshold.py
@@ -8,7 +8,7 @@ def compare(x, y):
         return 0.0
 
 
-def execute(v, thresholds):
+def execute(v, thresholds, out_scale=None, out_bias=None):
 
     # the inputs are expected to be in the shape (N,C,H,W)
     # N : Batch size
@@ -21,6 +21,8 @@ def execute(v, thresholds):
     #     if all channels use the same threshold value)
     # B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4)
 
+    # the output tensor will be scaled by out_scale and biased by out_bias
+
     # assert threshold shape
     is_global_threshold = thresholds.shape[0] == 1
     assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold
@@ -54,5 +56,8 @@ def execute(v, thresholds):
                 for a in range(num_act):
                     # apply successive thresholding to every element of one channel
                     ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
-
-    return ret.reshape(v.shape)
+    if out_scale is None:
+        out_scale = 1.0
+    if out_bias is None:
+        out_bias = 0.0
+    return out_scale * ret.reshape(v.shape) + out_bias
diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index a311012fc6631e76e75a37b8dc4d1b99d21ce7c7..2c6d75a1d029fbfea3d5b408cc6e8eb5f70dc7b5 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -1,4 +1,5 @@
 from finn.core.datatype import DataType
+from finn.core.utils import get_by_name
 from finn.transformation import Transformation
 
 
@@ -8,10 +9,15 @@ def _infer_node_datatype(model, node):
     idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
     odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     if node.op_type == "MultiThreshold":
-        # number of thresholds decides # output buts, use get_smallest_possible
-        n_thres = model.get_tensor_shape(node.input[1])[1]
-        odtype = DataType.get_smallest_possible(n_thres)
-        model.set_tensor_datatype(node.output[0], odtype)
+        try:
+            odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
+            model.set_tensor_datatype(node.output[0], DataType[odt])
+        except AttributeError:
+            # number of thresholds decides # output bits
+            # use get_smallest_possible, assuming unsigned
+            n_thres = model.get_tensor_shape(node.input[1])[1]
+            odtype = DataType.get_smallest_possible(n_thres)
+            model.set_tensor_datatype(node.output[0], odtype)
     elif node.op_type == "Sign":
         # always produces bipolar outputs
         model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index fb9e530d641063187dd75de30c8ca49936565bae..95e84e9907717aa8106a3a16319202bc29610f65 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -16,50 +16,30 @@ class ConvertSignToThres(Transformation):
         for n in graph.node:
             node_ind += 1
             if n.op_type == "Sign":
+                sign_in_name = n.input[0]
                 sign_out_name = n.output[0]
                 # find consumer
                 consumer = model.find_consumer(sign_out_name)
                 assert consumer is not None
-                # change op type and create threshold
-                n.op_type = "MultiThreshold"
+                # create thresholds
                 thres_param_name = model.make_new_valueinfo_name()
                 thres_param = np.asarray([[0]], dtype=np.float32)
-                n.input.append(thres_param_name)
-                n.domain = "finn"
                 model.set_initializer(thres_param_name, thres_param)
-                # convert 0,1 -> -1,+1 with 2*x-1
-                out_shape = model.get_tensor_shape(sign_out_name)
-                # make a mul node
-                # note how set_initializer or set_tensor_shape is called before
-                # calling make_new_valueinfo_name again
-                mul_param_name = model.make_new_valueinfo_name()
-                model.set_initializer(
-                    mul_param_name, np.asarray([[2]], dtype=np.float32)
+                # create a new node
+                mt_node = oh.make_node(
+                    "MultiThreshold",
+                    [sign_in_name, thres_param_name],
+                    [sign_out_name],
+                    domain="finn",
+                    out_scale=2.0,
+                    out_bias=-1.0,
+                    out_dtype="BIPOLAR",
                 )
-                mul_out_name = model.make_new_valueinfo_name()
-                model.set_tensor_shape(mul_out_name, out_shape)
-                mul_node = oh.make_node(
-                    "Mul", [sign_out_name, mul_param_name], [mul_out_name]
-                )
-                # make an add node
-                add_param_name = model.make_new_valueinfo_name()
-                model.set_initializer(
-                    add_param_name, np.asarray([[-1]], dtype=np.float32)
-                )
-                add_out_name = model.make_new_valueinfo_name()
-                model.set_tensor_shape(add_out_name, out_shape)
-                add_node = oh.make_node(
-                    "Add", [mul_out_name, add_param_name], [add_out_name]
-                )
-                # add new nodes to graph at correct position
-                graph.node.insert(node_ind, mul_node)
-                graph.node.insert(node_ind + 1, add_node)
-                # rewrite consumer's input
-                consumer.input[0] = add_out_name
+                # remove old node, add new node to graph at correct position
+                graph.node.insert(node_ind, mt_node)
+                graph.node.remove(n)
                 # add quantization annotations
-                model.set_tensor_datatype(sign_out_name, DataType.BINARY)
-                model.set_tensor_datatype(mul_out_name, DataType.UINT2)
-                model.set_tensor_datatype(add_out_name, DataType.BIPOLAR)
+                model.set_tensor_datatype(sign_out_name, DataType.BIPOLAR)
                 graph_modified = True
         return (model, graph_modified)
 
diff --git a/tests/test_custom_onnx_exec.py b/tests/test_custom_onnx_exec.py
index 0d07b9888d1330753446d589619149c1f8f316cf..e1ff552572e8a6d4d55e204cd21f17e4984ce30d 100644
--- a/tests/test_custom_onnx_exec.py
+++ b/tests/test_custom_onnx_exec.py
@@ -211,3 +211,18 @@ def test_execute_custom_node_multithreshold():
     )
 
     assert (execution_context["out"] == outputs).all()
+
+    # test the optional output scaling features on MultiThreshold
+    node_def = helper.make_node(
+        "MultiThreshold",
+        ["v", "thresholds"],
+        ["out"],
+        domain="finn",
+        out_scale=2.0,
+        out_bias=-1.0,
+    )
+
+    graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out])
+    ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
+    outputs_scaled = 2.0 * outputs - 1.0
+    assert (execution_context["out"] == outputs_scaled).all()
diff --git a/tests/test_multi_thresholding.py b/tests/test_multi_thresholding.py
index 49305e5572f35eb4a2e5f7678c73038777eb8b92..5fd7b4309d68ad568773dbdb39d0df979a767e73 100644
--- a/tests/test_multi_thresholding.py
+++ b/tests/test_multi_thresholding.py
@@ -197,3 +197,7 @@ def test_execute_multi_thresholding():
     results = multi_thresh.execute(inputs, thresholds)
 
     assert (results == outputs).all()
+
+    results_scaled = multi_thresh.execute(inputs, thresholds, 2.0, -1.0)
+    outputs_scaled = 2.0 * outputs - 1.0
+    assert (results_scaled == outputs_scaled).all()