Unverified commit c8ccc29e authored by Yaman Umuroglu, committed by GitHub

Merge pull request #103 from quetric/feature/vectorized_multithreshold_fxn

Feature/vectorized multithreshold fxn
parents daf33b47 278226b5
@@ -33,16 +33,6 @@ from finn.core.datatype import DataType
from finn.custom_op import CustomOp
def compare(x, y):
    """Comparison helper function for multithresholding.
    Gets two values and returns 1.0 if x>=y otherwise 0.0."""
    if x >= y:
        return 1.0
    else:
        return 0.0
def multithreshold(v, thresholds, out_scale=None, out_bias=None):
"""Given a set of threshold values t={t_0, t_1 ... t_n} the successive
thresholding maps any real number x to an integer in the interval [0, n],
@@ -76,8 +66,6 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
    num_act = thresholds.shape[1]
    # reshape inputs to enable channel-wise reading
    vr = v.reshape((v.shape[0], v.shape[1], -1))
    # save the new shape size of the images
    num_img_elem = vr.shape[2]
    # initiate output tensor
    ret = np.zeros_like(vr)
    # iterate over thresholds channel-wise
@@ -85,12 +73,10 @@
        channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
        # iterate over batches
        for b in range(num_batch):
            # iterate over image elements on which the thresholds will be applied
            for elem in range(num_img_elem):
                # iterate over the different thresholds for one channel
                for a in range(num_act):
                    # apply successive thresholding to every element
                    ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
            # iterate over the different thresholds for one channel
            for a in range(num_act):
                ret[b][t] += (vr[b][t] >= channel_thresh[a]).astype(int)
    if out_scale is None:
        out_scale = 1.0
    if out_bias is None:
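The change above replaces the per-element compare() calls with a broadcast comparison: for each channel, one NumPy comparison against the whole flattened image is accumulated per threshold. As an aside, the equivalence can be checked with a minimal standalone NumPy sketch; the array values below are illustrative and not taken from the diff:

import numpy as np

# one channel of flattened image elements and its thresholds (illustrative values)
vr_bt = np.array([0.1, 0.45, 0.7, 0.95])       # plays the role of vr[b][t]
channel_thresh = np.array([0.25, 0.5, 0.75])   # thresholds for this channel

# elementwise version: count the thresholds each element meets, one value at a time
ret_elementwise = np.array(
    [sum(1.0 for th in channel_thresh if x >= th) for x in vr_bt]
)

# vectorized version: one broadcast comparison per threshold, accumulated
ret_vectorized = np.zeros_like(vr_bt)
for a in range(len(channel_thresh)):
    ret_vectorized += (vr_bt >= channel_thresh[a]).astype(int)

assert (ret_elementwise == ret_vectorized).all()  # both yield [0., 1., 2., 3.]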
@@ -27,11 +27,76 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import numpy as np
import time
from finn.custom_op.multithreshold import multithreshold
def test_execute_multi_thresholding():
def compare(x, y):
"""Comparison helper function for multithresholding.
Gets two values and returns 1.0 if x>=y otherwise 0.0."""
if x >= y:
return 1.0
else:
return 0.0
# naive implementation of thresholding for performance comparison
def multithreshold_elementwise(v, thresholds, out_scale=None, out_bias=None):
"""Given a set of threshold values t={t_0, t_1 ... t_n} the successive
thresholding maps any real number x to an integer in the interval [0, n],
where the returned integer is the number of thresholds x is greater than
or equal to.
The output tensor will be scaled by out_scale and biased by out_bias."""
# the inputs are expected to be in the shape (N,C,H,W) or (N, C)
# the MultiThreshold node supports a data_layout attribute that can be set
# to 'NHWC' to support (N,H,W,C) data layout mode for in-out as well
# N : Batch size
# C : Number of channels
# H : Heigth of the input images
# W : Width of the input images
#
# the thresholds are expected to be in the shape (C, B)
# C : Number of channels (must be the same value as C in input tensor
# or 1 if all channels use the same threshold value)
# B : Desired activation steps => i.e. for 4-bit activation,
# B=7 (2^(n)-1 and n=4)
# the output tensor will be scaled by out_scale and biased by out_bias
# assert threshold shape
is_global_threshold = thresholds.shape[0] == 1
assert (
v.shape[1] == thresholds.shape[0]
) or is_global_threshold, """"Threshold
shape incorrect"""
    # save the required shape sizes for the loops (N, C and B)
    num_batch = v.shape[0]
    num_channel = v.shape[1]
    num_act = thresholds.shape[1]
    # reshape inputs to enable channel-wise reading
    vr = v.reshape((v.shape[0], v.shape[1], -1))
    # save the new shape size of the images
    num_img_elem = vr.shape[2]
    # initiate output tensor
    ret = np.zeros_like(vr)
    # iterate over thresholds channel-wise
    for t in range(num_channel):
        channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
        # iterate over batches
        for b in range(num_batch):
            # iterate over image elements on which the thresholds will be applied
            for elem in range(num_img_elem):
                # iterate over the different thresholds for one channel
                for a in range(num_act):
                    # apply successive thresholding to every element
                    ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
    if out_scale is None:
        out_scale = 1.0
    if out_bias is None:
        out_bias = 0.0
    return out_scale * ret.reshape(v.shape) + out_bias
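As an aside, the shape conventions documented above (inputs in (N, C, H, W), thresholds in (C, B) with B = 2**n - 1 thresholds for an n-bit activation) can be made concrete with a short sketch; e.g. with thresholds {0.5, 1.5, 2.5}, an input value of 1.7 maps to 2. The snippet below is illustrative only: it reuses the multithreshold import from the top of this file, and its names, shapes and values are made up:

# illustrative only: 2 images, 3 channels, 4x4 pixels, 2-bit activations,
# so each channel gets B = 2**2 - 1 = 3 thresholds
example_inputs = np.random.random((2, 3, 4, 4))           # (N, C, H, W)
example_thresholds = np.array([[0.25, 0.5, 0.75]] * 3)    # (C, B) = (3, 3)
example_out = multithreshold(example_inputs, example_thresholds)
assert example_out.shape == example_inputs.shape          # output keeps the input shape
assert example_out.min() >= 0 and example_out.max() <= 3  # integer levels in [0, B]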
def test_multithreshold():
    inputs = np.ndarray(
        shape=(6, 3, 2, 2),
@@ -223,9 +288,35 @@ def test_execute_multi_thresholding():
    )
    results = multithreshold(inputs, thresholds)
    assert (results == outputs).all()
    results_scaled = multithreshold(inputs, thresholds, 2.0, -1.0)
    outputs_scaled = 2.0 * outputs - 1.0
    assert (results_scaled == outputs_scaled).all()
    # performance and random test
    np.random.seed(0)
    inputs = np.random.random((1, 256, 64, 64))
    thresholds = (np.array([[1, 2, 3, 4, 5, 6]]) - 0.5) / 6
    before = time.time()
    vec_results = multithreshold(inputs, thresholds)
    after = time.time()
    vector_runtime = after - before
    before = time.time()
    nonvec_results = multithreshold_elementwise(inputs, thresholds)
    after = time.time()
    non_vector_runtime = after - before
    assert (vec_results == nonvec_results).all()
    return vector_runtime, non_vector_runtime
if __name__ == "__main__":
    vector_runtime, non_vector_runtime = test_multithreshold()
    print("Runtime non-vectorized: ", non_vector_runtime, "s")
    print("Runtime vectorized: ", vector_runtime, "s")
    print("Speed-up: ", non_vector_runtime / vector_runtime)