diff --git a/src/finn/core/multithreshold.py b/src/finn/core/multithreshold.py index 009259c577879a8aa09ac44ace704af55ca2593d..2a26829803d0bbdbd0fb088dca13d38103f836a2 100755 --- a/src/finn/core/multithreshold.py +++ b/src/finn/core/multithreshold.py @@ -8,7 +8,7 @@ def compare(x, y): return 0.0 -def execute(v, thresholds): +def execute(v, thresholds, out_scale=1.0, out_bias=0.0): # the inputs are expected to be in the shape (N,C,H,W) # N : Batch size @@ -21,6 +21,8 @@ def execute(v, thresholds): # if all channels use the same threshold value) # B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4) + # the output tensor will be scaled by out_scale and biased by out_bias + # assert threshold shape is_global_threshold = thresholds.shape[0] == 1 assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold @@ -55,4 +57,4 @@ def execute(v, thresholds): # apply successive thresholding to every element of one channel ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a]) - return ret.reshape(v.shape) + return out_scale * ret.reshape(v.shape) + out_bias