From b0941313b762c3fd90a8f59d9a62e4cc47ad7462 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 1 Nov 2019 15:01:47 +0000
Subject: [PATCH] [MultiThres] support applying same threshold to all channels

---
 src/finn/core/multithreshold.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/finn/core/multithreshold.py b/src/finn/core/multithreshold.py
index 23b5cca5a..009259c57 100755
--- a/src/finn/core/multithreshold.py
+++ b/src/finn/core/multithreshold.py
@@ -17,11 +17,13 @@ def execute(v, thresholds):
     # W : Width of the input images
     #
     # the thresholds are expected to be in the shape (C, B)
-    # C : Number of channels (must be the same value as C in input tensor)
+    # C : Number of channels (must be the same value as C in input tensor or 1
+    #     if all channels use the same threshold value)
     # B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4)
 
-    # assert if channel sizes do not match
-    assert v.shape[1] == thresholds.shape[0]
+    # assert threshold shape
+    is_global_threshold = thresholds.shape[0] == 1
+    assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold
 
     # save the required shape sizes for the loops (N, C and B)
     num_batch = v.shape[0]
@@ -40,7 +42,7 @@ def execute(v, thresholds):
 
     # iterate over thresholds channel-wise
     for t in range(num_channel):
-        channel_thresh = thresholds[t]
+        channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
 
         # iterate over batches
         for b in range(num_batch):
-- 
GitLab