diff --git a/src/finn/core/multi_thresholding.py b/src/finn/core/multi_thresholding.py
index 2bd6aee2c551bef77839e358c5e9cf0c6a4f1ec5..f3b64f0dd90e319a3521be853d8c51180e936a04 100755
--- a/src/finn/core/multi_thresholding.py
+++ b/src/finn/core/multi_thresholding.py
@@ -2,38 +2,48 @@ import numpy as np
 
 
 def compare(x, y):
-    if x >= y:
-        return 1.0
-    else:
-        return 0.0
+	if x >= y:
+		return 1.0
+	else:
+		return 0.0
 
 
 def execute(v, thresholds):
-    # reshape inputs to enable channel-wise reading
-    vr = v.reshape((thresholds.shape[1], -1))
-
-    # calculate the channelinterval for the for loops
-    num_channels = thresholds.shape[0]
-    channel_interval = int(vr.shape[1] / num_channels)
-
-    # initiate output tensor
-    ret = np.zeros_like(vr)
-
-    # initiate helper variable i for channel-wise thresholding
-    i = -1
-
-    # iterate over thresholds channel-wise
-    for t in thresholds:
-        print(t)
-        i += 1
-
-        # calculate the lower and upper limit in which elements belong to one channel
-        ce1_low_lim = i * channel_interval
-        ce1_up_lim = (i + 1) * channel_interval
-
-        # iterate in ascending order over the thresholds belonging to one channel
-        for c in range(thresholds.shape[1]):
-            for ce0 in range(vr.shape[0]):
-                for ce1 in range(ce1_low_lim, ce1_up_lim):
-                    ret[ce0][ce1] += compare(vr[ce0][ce1], t[c])
-    return ret.reshape(v.shape)
+
+	# the inputs are expected to be in the shape (N,C,H,W)
+	# N : Batch size
+	# C : Number of channels
+	# H : Heigth of the input images
+	# W : Width of the input images
+	#
+	# the thresholds are expected to be in the shape (C, B)
+	# C : Number of channels (must be the same value as C in input tensor)
+	# B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4)
+
+	# assert if channel sizes do not match
+	assert v.shape[1] == thresholds.shape[0]
+
+	num_batch = v.shape[0]
+	num_channel = v.shape[1]
+
+	num_act = thresholds.shape[1]
+
+	# reshape inputs to enable channel-wise reading
+	vr = v.reshape((v.shape[0], v.shape[1], -1))
+
+	num_img_elem = vr.shape[2]
+
+	# initiate output tensor
+	ret = np.zeros_like(vr)
+
+	# iterate over thresholds channel-wise
+	for t in range(num_channel):
+		channel_thresh = thresholds[t]
+		for b in range(num_batch):
+			for elem in range(num_img_elem):
+				print(vr[b][t][elem])
+				for a in range(num_act):
+					print(channel_thresh[a])
+					ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
+	print(ret)
+	return ret.reshape(v.shape)
diff --git a/tests/test_multi_thresholding.py b/tests/test_multi_thresholding.py
deleted file mode 100644
index 7e1033cf8dbe38ce7269be685cc66bd7d1d877d0..0000000000000000000000000000000000000000
--- a/tests/test_multi_thresholding.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import numpy as np
-
-import finn.core.multi_thresholding as multi_thresh
-
-
-def test_execute_multi_thresholding():
-    inputs = np.genfromtxt(
-        "../src/finn/data/multi-thresholding/input.csv", delimiter=","
-    )
-    inputs = inputs.reshape(7, 3, 2, 2)
-
-    thresholds = np.genfromtxt(
-        "../src/finn/data/multi-thresholding/thresholds.csv", delimiter=","
-    )
-    thresholds = thresholds.reshape(3, 7)
-
-    outputs = np.genfromtxt(
-        "../src/finn/data/multi-thresholding/output.csv", delimiter=","
-    )
-    outputs = outputs.reshape(7, 3, 2, 2)
-
-    results = multi_thresh.execute(inputs, thresholds)
-
-    assert np.isclose(outputs, results, atol=1e-3).all()