Skip to content
Snippets Groups Projects
Commit b0941313 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[MultiThres] support applying same threshold to all channels

parent ba43d9c9
No related branches found
No related tags found
No related merge requests found
......@@ -17,11 +17,13 @@ def execute(v, thresholds):
# W : Width of the input images
#
# the thresholds are expected to be in the shape (C, B)
# C : Number of channels (must be the same value as C in input tensor)
# C : Number of channels (must be the same value as C in input tensor or 1
# if all channels use the same threshold value)
# B : Desired activation steps => i.e. for 4-bit activation, B=7 (2^(n)-1 and n=4)
# assert if channel sizes do not match
assert v.shape[1] == thresholds.shape[0]
# assert threshold shape
is_global_threshold = thresholds.shape[0] == 1
assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold
# save the required shape sizes for the loops (N, C and B)
num_batch = v.shape[0]
......@@ -40,7 +42,7 @@ def execute(v, thresholds):
# iterate over thresholds channel-wise
for t in range(num_channel):
channel_thresh = thresholds[t]
channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
# iterate over batches
for b in range(num_batch):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment