From b0941313b762c3fd90a8f59d9a62e4cc47ad7462 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Fri, 1 Nov 2019 15:01:47 +0000 Subject: [PATCH] [MultiThres] support applying same threshold to all channels --- src/finn/core/multithreshold.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/finn/core/multithreshold.py b/src/finn/core/multithreshold.py index 23b5cca5a..009259c57 100755 --- a/src/finn/core/multithreshold.py +++ b/src/finn/core/multithreshold.py @@ -17,11 +17,13 @@ def execute(v, thresholds): # W : Width of the input images # # the thresholds are expected to be in the shape (C, B) - # C : Number of channels (must be the same value as C in input tensor) + # C : Number of channels (must be the same value as C in input tensor or 1 + # if all channels use the same threshold value) # B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4) - # assert if channel sizes do not match - assert v.shape[1] == thresholds.shape[0] + # assert threshold shape + is_global_threshold = thresholds.shape[0] == 1 + assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold # save the required shape sizes for the loops (N, C and B) num_batch = v.shape[0] @@ -40,7 +42,7 @@ def execute(v, thresholds): # iterate over thresholds channel-wise for t in range(num_channel): - channel_thresh = thresholds[t] + channel_thresh = thresholds[0] if is_global_threshold else thresholds[t] # iterate over batches for b in range(num_batch): -- GitLab