Skip to content
Snippets Groups Projects
Commit 6953ea00 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Exec] enable optional scale/shift args for multithresholding

parent e709d19d
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,7 @@ def compare(x, y):
return 0.0
def execute(v, thresholds):
def execute(v, thresholds, out_scale=1.0, out_bias=0.0):
# the inputs are expected to be in the shape (N,C,H,W)
# N : Batch size
......@@ -21,6 +21,8 @@ def execute(v, thresholds):
# if all channels use the same threshold value)
# B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4)
# the output tensor will be scaled by out_scale and biased by out_bias
# assert threshold shape
is_global_threshold = thresholds.shape[0] == 1
assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold
......@@ -55,4 +57,4 @@ def execute(v, thresholds):
# apply successive thresholding to every element of one channel
ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
return ret.reshape(v.shape)
return out_scale * ret.reshape(v.shape) + out_bias
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment