Skip to content
Snippets Groups Projects
Commit 1d3a00d2 authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[CustomOp] Replace elementwise with vector implementation of multithreshold function

parent daf33b47
No related branches found
No related tags found
No related merge requests found
......@@ -33,16 +33,6 @@ from finn.core.datatype import DataType
from finn.custom_op import CustomOp
def compare(x, y):
"""Comparison helper function for multithresholding.
Gets two values and returns 1.0 if x>=y otherwise 0.0."""
if x >= y:
return 1.0
else:
return 0.0
def multithreshold(v, thresholds, out_scale=None, out_bias=None):
"""Given a set of threshold values t={t_0, t_1 ... t_n} the successive
thresholding maps any real number x to an integer in the interval [0, n],
......@@ -76,8 +66,6 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
num_act = thresholds.shape[1]
# reshape inputs to enable channel-wise reading
vr = v.reshape((v.shape[0], v.shape[1], -1))
# save the new shape size of the images
num_img_elem = vr.shape[2]
# initiate output tensor
ret = np.zeros_like(vr)
# iterate over thresholds channel-wise
......@@ -85,12 +73,10 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
# iterate over batches
for b in range(num_batch):
# iterate over image elements on which the thresholds will be applied
for elem in range(num_img_elem):
# iterate over the different thresholds for one channel
for a in range(num_act):
# apply successive thresholding to every element
ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
# iterate over the different thresholds for one channel
for a in range(num_act):
ret[b][t] += (vr[b][t] >= channel_thresh[a]).astype(int)
if out_scale is None:
out_scale = 1.0
if out_bias is None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment