From 1d3a00d20fe672a02a73da8e6d43c0b6aff093a8 Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Thu, 14 May 2020 17:08:15 +0100
Subject: [PATCH] [CustomOp] Replace elementwise multithreshold implementation
 with a vectorized one
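
A minimal sketch (editor's illustration, not part of the diff below) of the
idea: each broadcasted NumPy comparison yields a 0/1 mask over a whole channel
image at once, so the old per-element inner loop and the compare() helper
disappear, leaving only the loop over that channel's thresholds. The array
values here are made-up example data.

import numpy as np

v = np.array([[[-1.0, 0.5, 2.0, 3.5]]])   # (batch, channel, elems), example data
thresholds = np.array([[0.0, 1.0, 3.0]])  # (channel, num_act), example data
ret = np.zeros_like(v)
for a in range(thresholds.shape[1]):
    # one vectorized comparison replaces the old per-element compare() calls
    ret[0][0] += (v[0][0] >= thresholds[0][a]).astype(int)
print(ret)  # [[[0. 1. 2. 3.]]]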

---
 src/finn/custom_op/multithreshold.py | 22 ++++------------------
 1 file changed, 4 insertions(+), 18 deletions(-)

diff --git a/src/finn/custom_op/multithreshold.py b/src/finn/custom_op/multithreshold.py
index 37f8e0950..bc0a454cd 100644
--- a/src/finn/custom_op/multithreshold.py
+++ b/src/finn/custom_op/multithreshold.py
@@ -33,16 +33,6 @@ from finn.core.datatype import DataType
 from finn.custom_op import CustomOp
 
 
-def compare(x, y):
-    """Comparison helper function for multithresholding.
-
-    Gets two values and returns 1.0 if x>=y otherwise 0.0."""
-    if x >= y:
-        return 1.0
-    else:
-        return 0.0
-
-
 def multithreshold(v, thresholds, out_scale=None, out_bias=None):
     """Given a set of threshold values t={t_0, t_1 ... t_n} the successive
     thresholding maps any real number x to an integer in the interval [0, n],
@@ -76,8 +66,6 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
     num_act = thresholds.shape[1]
     # reshape inputs to enable channel-wise reading
     vr = v.reshape((v.shape[0], v.shape[1], -1))
-    # save the new shape size of the images
-    num_img_elem = vr.shape[2]
     # initiate output tensor
     ret = np.zeros_like(vr)
     # iterate over thresholds channel-wise
@@ -85,12 +73,10 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
         channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
         # iterate over batches
         for b in range(num_batch):
-            # iterate over image elements on which the thresholds will be applied
-            for elem in range(num_img_elem):
-                # iterate over the different thresholds for one channel
-                for a in range(num_act):
-                    # apply successive thresholding to every element
-                    ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
+            # iterate over the different thresholds for one channel
+            for a in range(num_act):
+                ret[b][t] += (vr[b][t] >= channel_thresh[a]).astype(int)
+
     if out_scale is None:
         out_scale = 1.0
     if out_bias is None:
-- 
GitLab