diff --git a/src/finn/core/multithreshold.py b/src/finn/core/multithreshold.py index 23b5cca5a137ad9d6be5e06436ffdef212e7828c..009259c577879a8aa09ac44ace704af55ca2593d 100755 --- a/src/finn/core/multithreshold.py +++ b/src/finn/core/multithreshold.py @@ -17,11 +17,13 @@ def execute(v, thresholds): # W : Width of the input images # # the thresholds are expected to be in the shape (C, B) - # C : Number of channels (must be the same value as C in input tensor) + # C : Number of channels (must be the same value as C in input tensor or 1 + # if all channels use the same threshold value) # B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4) - # assert if channel sizes do not match - assert v.shape[1] == thresholds.shape[0] + # assert threshold shape + is_global_threshold = thresholds.shape[0] == 1 + assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold # save the required shape sizes for the loops (N, C and B) num_batch = v.shape[0] @@ -40,7 +42,7 @@ def execute(v, thresholds): # iterate over thresholds channel-wise for t in range(num_channel): - channel_thresh = thresholds[t] + channel_thresh = thresholds[0] if is_global_threshold else thresholds[t] # iterate over batches for b in range(num_batch):