diff --git a/src/finn/custom_op/multithreshold.py b/src/finn/custom_op/multithreshold.py index 37f8e0950b5fc352c8f9fe005884724f028879a0..bc0a454cdf847d124b12c940b029f51bf2d3e778 100644 --- a/src/finn/custom_op/multithreshold.py +++ b/src/finn/custom_op/multithreshold.py @@ -33,16 +33,6 @@ from finn.core.datatype import DataType from finn.custom_op import CustomOp -def compare(x, y): - """Comparison helper function for multithresholding. - - Gets two values and returns 1.0 if x>=y otherwise 0.0.""" - if x >= y: - return 1.0 - else: - return 0.0 - - def multithreshold(v, thresholds, out_scale=None, out_bias=None): """Given a set of threshold values t={t_0, t_1 ... t_n} the successive thresholding maps any real number x to an integer in the interval [0, n], @@ -76,8 +66,6 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None): num_act = thresholds.shape[1] # reshape inputs to enable channel-wise reading vr = v.reshape((v.shape[0], v.shape[1], -1)) - # save the new shape size of the images - num_img_elem = vr.shape[2] # initiate output tensor ret = np.zeros_like(vr) # iterate over thresholds channel-wise @@ -85,12 +73,10 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None): channel_thresh = thresholds[0] if is_global_threshold else thresholds[t] # iterate over batches for b in range(num_batch): - # iterate over image elements on which the thresholds will be applied - for elem in range(num_img_elem): - # iterate over the different thresholds for one channel - for a in range(num_act): - # apply successive thresholding to every element - ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a]) + # iterate over the different thresholds for one channel + for a in range(num_act): + ret[b][t] += (vr[b][t] >= channel_thresh[a]).astype(int) + if out_scale is None: out_scale = 1.0 if out_bias is None: