diff --git a/src/finn/custom_op/multithreshold.py b/src/finn/custom_op/multithreshold.py
index 37f8e0950b5fc352c8f9fe005884724f028879a0..bc0a454cdf847d124b12c940b029f51bf2d3e778 100644
--- a/src/finn/custom_op/multithreshold.py
+++ b/src/finn/custom_op/multithreshold.py
@@ -33,16 +33,6 @@ from finn.core.datatype import DataType
 from finn.custom_op import CustomOp
 
 
-def compare(x, y):
-    """Comparison helper function for multithresholding.
-
-    Gets two values and returns 1.0 if x>=y otherwise 0.0."""
-    if x >= y:
-        return 1.0
-    else:
-        return 0.0
-
-
 def multithreshold(v, thresholds, out_scale=None, out_bias=None):
     """Given a set of threshold values t={t_0, t_1 ... t_n} the successive
     thresholding maps any real number x to an integer in the interval [0, n],
@@ -76,8 +66,6 @@
     num_act = thresholds.shape[1]
     # reshape inputs to enable channel-wise reading
     vr = v.reshape((v.shape[0], v.shape[1], -1))
-    # save the new shape size of the images
-    num_img_elem = vr.shape[2]
     # initiate output tensor
     ret = np.zeros_like(vr)
     # iterate over thresholds channel-wise
@@ -85,12 +73,10 @@
         channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
         # iterate over batches
         for b in range(num_batch):
-            # iterate over image elements on which the thresholds will be applied
-            for elem in range(num_img_elem):
-                # iterate over the different thresholds for one channel
-                for a in range(num_act):
-                    # apply successive thresholding to every element
-                    ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
+            # iterate over the different thresholds for one channel
+            for a in range(num_act):
+                ret[b][t] += (vr[b][t] >= channel_thresh[a]).astype(int)
+
     if out_scale is None:
         out_scale = 1.0
     if out_bias is None:
diff --git a/tests/custom_op/test_multi_thresholding.py b/tests/custom_op/test_multithreshold.py
similarity index 61%
rename from tests/custom_op/test_multi_thresholding.py
rename to tests/custom_op/test_multithreshold.py
index 4f2b08675fdabb1bda49972c51892da92e1a0cdc..7e6ad4fe08517290dd22a2c74b2847d007b74b1f 100644
--- a/tests/custom_op/test_multi_thresholding.py
+++ b/tests/custom_op/test_multithreshold.py
@@ -27,11 +27,76 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import numpy as np
-
+import time
 from finn.custom_op.multithreshold import multithreshold
 
 
-def test_execute_multi_thresholding():
+def compare(x, y):
+    """Comparison helper function for multithresholding.
+
+    Gets two values and returns 1.0 if x>=y otherwise 0.0."""
+    if x >= y:
+        return 1.0
+    else:
+        return 0.0
+
+# naive implementation of thresholding for performance comparison
+def multithreshold_elementwise(v, thresholds, out_scale=None, out_bias=None):
+    """Given a set of threshold values t={t_0, t_1 ... t_n} the successive
+    thresholding maps any real number x to an integer in the interval [0, n],
+    where the returned integer is the number of thresholds x is greater than
+    or equal to.
+
+    The output tensor will be scaled by out_scale and biased by out_bias."""
+    # the inputs are expected to be in the shape (N,C,H,W) or (N, C)
+    # the MultiThreshold node supports a data_layout attribute that can be set
+    # to 'NHWC' to support (N,H,W,C) data layout mode for in-out as well
+    # N : Batch size
+    # C : Number of channels
+    # H : Heigth of the input images
+    # W : Width of the input images
+    #
+    # the thresholds are expected to be in the shape (C, B)
+    # C : Number of channels (must be the same value as C in input tensor
+    #     or 1 if all channels use the same threshold value)
+    # B : Desired activation steps => i.e. for 4-bit activation,
+    #     B=7 (2^(n)-1 and n=4)
+    # the output tensor will be scaled by out_scale and biased by out_bias
+    # assert threshold shape
+    is_global_threshold = thresholds.shape[0] == 1
+    assert (
+        v.shape[1] == thresholds.shape[0]
+    ) or is_global_threshold, """"Threshold
+    shape incorrect"""
+    # save the required shape sizes for the loops (N, C and B)
+    num_batch = v.shape[0]
+    num_channel = v.shape[1]
+    num_act = thresholds.shape[1]
+    # reshape inputs to enable channel-wise reading
+    vr = v.reshape((v.shape[0], v.shape[1], -1))
+    # save the new shape size of the images
+    num_img_elem = vr.shape[2]
+    # initiate output tensor
+    ret = np.zeros_like(vr)
+    # iterate over thresholds channel-wise
+    for t in range(num_channel):
+        channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
+        # iterate over batches
+        for b in range(num_batch):
+            # iterate over image elements on which the thresholds will be applied
+            for elem in range(num_img_elem):
+                # iterate over the different thresholds for one channel
+                for a in range(num_act):
+                    # apply successive thresholding to every element
+                    ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
+    if out_scale is None:
+        out_scale = 1.0
+    if out_bias is None:
+        out_bias = 0.0
+    return out_scale * ret.reshape(v.shape) + out_bias
+
+
+def test_multithreshold():
 
     inputs = np.ndarray(
         shape=(6, 3, 2, 2),
@@ -223,9 +288,35 @@
     )
 
     results = multithreshold(inputs, thresholds)
-
     assert (results == outputs).all()
 
     results_scaled = multithreshold(inputs, thresholds, 2.0, -1.0)
     outputs_scaled = 2.0 * outputs - 1.0
     assert (results_scaled == outputs_scaled).all()
+
+    # performance and random test
+    np.random.seed(0)
+    inputs = np.random.random((1, 256, 64, 64))
+    thresholds = (np.array([[1, 2, 3, 4, 5, 6]]) - 0.5) / 6
+
+    before = time.time()
+    vec_results = multithreshold(inputs, thresholds)
+    after = time.time()
+    vector_runtime = after - before
+
+    before = time.time()
+    nonvec_results = multithreshold_elementwise(inputs, thresholds)
+    after = time.time()
+    non_vector_runtime = after - before
+
+    assert (vec_results == nonvec_results).all()
+
+    return vector_runtime, non_vector_runtime
+
+
+if __name__ == "__main__":
+    vector_runtime, non_vector_runtime = test_multithreshold()
+
+    print("Runtime non-vectorized: ", non_vector_runtime, "s")
+    print("Runtime vectorized: ", vector_runtime, "s")
+    print("Speed-up: ", non_vector_runtime / vector_runtime)
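
Reviewer note (not part of the patch): the core of the change is replacing the per-element Python loop with a broadcast NumPy comparison, so each threshold is compared against a whole slice at once and the resulting boolean mask is accumulated. Below is a minimal standalone sketch of that idea; the function names and shapes are made up for illustration and are not part of the FINN code base.

import numpy as np


def count_thresholds_loop(x, thr):
    # reference: count thresholds met, one element at a time
    out = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        out[idx] = sum(1.0 for t in thr if x[idx] >= t)
    return out


def count_thresholds_vectorized(x, thr):
    # same result: one broadcast comparison per threshold, accumulated over the whole array
    out = np.zeros_like(x)
    for t in thr:
        out += (x >= t).astype(x.dtype)
    return out


x = np.random.default_rng(0).random((4, 8))
thr = (np.arange(1, 7) - 0.5) / 6
assert np.array_equal(count_thresholds_loop(x, thr), count_thresholds_vectorized(x, thr))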