diff --git a/src/finn/custom_op/multithreshold.py b/src/finn/custom_op/multithreshold.py
index 37f8e0950b5fc352c8f9fe005884724f028879a0..bc0a454cdf847d124b12c940b029f51bf2d3e778 100644
--- a/src/finn/custom_op/multithreshold.py
+++ b/src/finn/custom_op/multithreshold.py
@@ -33,16 +33,6 @@ from finn.core.datatype import DataType
 from finn.custom_op import CustomOp
 
 
-def compare(x, y):
-    """Comparison helper function for multithresholding.
-
-    Gets two values and returns 1.0 if x>=y otherwise 0.0."""
-    if x >= y:
-        return 1.0
-    else:
-        return 0.0
-
-
 def multithreshold(v, thresholds, out_scale=None, out_bias=None):
     """Given a set of threshold values t={t_0, t_1 ... t_n} the successive
     thresholding maps any real number x to an integer in the interval [0, n],
@@ -76,8 +66,6 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
     num_act = thresholds.shape[1]
     # reshape inputs to enable channel-wise reading
     vr = v.reshape((v.shape[0], v.shape[1], -1))
-    # save the new shape size of the images
-    num_img_elem = vr.shape[2]
     # initiate output tensor
     ret = np.zeros_like(vr)
     # iterate over thresholds channel-wise
@@ -85,12 +73,10 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
         channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
         # iterate over batches
         for b in range(num_batch):
-            # iterate over image elements on which the thresholds will be applied
-            for elem in range(num_img_elem):
-                # iterate over the different thresholds for one channel
-                for a in range(num_act):
-                    # apply successive thresholding to every element
-                    ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
+            # apply each threshold to the whole channel image at once
+            for a in range(num_act):
+                ret[b][t] += (vr[b][t] >= channel_thresh[a]).astype(int)
+
     if out_scale is None:
         out_scale = 1.0
     if out_bias is None:
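
The change above replaces the removed per-element compare() loop with one NumPy broadcast comparison per threshold, so the Python-level work per batch and channel drops from num_img_elem * num_act function calls to num_act vectorized operations. Below is a minimal standalone sketch of that idea, not taken from the patch itself; the shapes and values are made up for illustration.

    import numpy as np

    v = np.array([[[0.1, 0.4, 0.7, 0.9]]])        # one batch, one channel, 4 elements
    channel_thresh = np.array([0.25, 0.5, 0.75])  # thresholds for this channel

    ret = np.zeros_like(v)
    for a in range(len(channel_thresh)):
        # one boolean mask over the whole channel image per threshold,
        # instead of a Python loop over every element
        ret[0][0] += (v[0][0] >= channel_thresh[a]).astype(int)

    print(ret)  # [[[0. 1. 2. 3.]]] -> number of thresholds each element meets or exceeds

Because out_scale and out_bias are applied only once at the end, the vectorized and elementwise versions produce identical results, which is what the rewritten test below asserts.
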
diff --git a/tests/custom_op/test_multi_thresholding.py b/tests/custom_op/test_multithreshold.py
similarity index 61%
rename from tests/custom_op/test_multi_thresholding.py
rename to tests/custom_op/test_multithreshold.py
index 4f2b08675fdabb1bda49972c51892da92e1a0cdc..7e6ad4fe08517290dd22a2c74b2847d007b74b1f 100644
--- a/tests/custom_op/test_multi_thresholding.py
+++ b/tests/custom_op/test_multithreshold.py
@@ -27,11 +27,76 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import numpy as np
-
+import time
 from finn.custom_op.multithreshold import multithreshold
 
 
-def test_execute_multi_thresholding():
+def compare(x, y):
+    """Comparison helper function for multithresholding.
+
+    Takes two values and returns 1.0 if x >= y, otherwise 0.0."""
+    if x >= y:
+        return 1.0
+    else:
+        return 0.0
+
+# naive elementwise implementation of thresholding, used as a reference below
+def multithreshold_elementwise(v, thresholds, out_scale=None, out_bias=None):
+    """Given a set of threshold values t={t_0, t_1 ... t_n} the successive
+    thresholding maps any real number x to an integer in the interval [0, n],
+    where the returned integer is the number of thresholds x is greater than
+    or equal to.
+
+    The output tensor will be scaled by out_scale and biased by out_bias."""
+    # the inputs are expected to be in the shape (N, C, H, W) or (N, C)
+    # the MultiThreshold node supports a data_layout attribute that can be
+    # set to 'NHWC' to also accept (N, H, W, C) inputs and outputs
+    # N : Batch size
+    # C : Number of channels
+    # H : Height of the input images
+    # W : Width of the input images
+    #
+    # the thresholds are expected to be in the shape (C, B)
+    # C : Number of channels (must be the same value as C in input tensor
+    #     or 1 if all channels use the same threshold value)
+    # B : Desired number of activation steps, e.g. for 4-bit activation
+    #     B = 7 (2^n - 1 with n = 4)
+    # the output tensor will be scaled by out_scale and biased by out_bias
+    # assert threshold shape
+    is_global_threshold = thresholds.shape[0] == 1
+    assert (
+        v.shape[1] == thresholds.shape[0]
+    ) or is_global_threshold, """Threshold
+    shape incorrect"""
+    # save the required shape sizes for the loops (N, C and B)
+    num_batch = v.shape[0]
+    num_channel = v.shape[1]
+    num_act = thresholds.shape[1]
+    # reshape inputs to enable channel-wise reading
+    vr = v.reshape((v.shape[0], v.shape[1], -1))
+    # save the new shape size of the images
+    num_img_elem = vr.shape[2]
+    # initiate output tensor
+    ret = np.zeros_like(vr)
+    # iterate over thresholds channel-wise
+    for t in range(num_channel):
+        channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
+        # iterate over batches
+        for b in range(num_batch):
+            # iterate over image elements on which the thresholds will be applied
+            for elem in range(num_img_elem):
+                # iterate over the different thresholds for one channel
+                for a in range(num_act):
+                    # apply successive thresholding to every element
+                    ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
+    if out_scale is None:
+        out_scale = 1.0
+    if out_bias is None:
+        out_bias = 0.0
+    return out_scale * ret.reshape(v.shape) + out_bias
+
+
+def test_multithreshold():
 
     inputs = np.ndarray(
         shape=(6, 3, 2, 2),
@@ -223,9 +288,35 @@ def test_execute_multi_thresholding():
     )
 
     results = multithreshold(inputs, thresholds)
-
     assert (results == outputs).all()
 
     results_scaled = multithreshold(inputs, thresholds, 2.0, -1.0)
     outputs_scaled = 2.0 * outputs - 1.0
     assert (results_scaled == outputs_scaled).all()
+
+    # correctness check on random inputs plus a timing comparison
+    np.random.seed(0)
+    inputs = np.random.random((1, 256, 64, 64))
+    thresholds = (np.array([[1, 2, 3, 4, 5, 6]]) - 0.5) / 6
+
+    before = time.time()
+    vec_results = multithreshold(inputs, thresholds)
+    after = time.time()
+    vector_runtime = after - before
+
+    before = time.time()
+    nonvec_results = multithreshold_elementwise(inputs, thresholds)
+    after = time.time()
+    non_vector_runtime = after - before
+
+    assert (vec_results == nonvec_results).all()
+
+    return vector_runtime, non_vector_runtime
+
+
+if __name__ == "__main__":
+    vector_runtime, non_vector_runtime = test_multithreshold()
+
+    print("Runtime non-vectorized: ", non_vector_runtime, "s")
+    print("Runtime vectorized: ", vector_runtime, "s")
+    print("Speed-up: ", non_vector_runtime / vector_runtime)