diff --git a/src/finn/core/multi_thresholding.py b/src/finn/core/multi_thresholding.py
index 2bd6aee2c551bef77839e358c5e9cf0c6a4f1ec5..f3b64f0dd90e319a3521be853d8c51180e936a04 100755
--- a/src/finn/core/multi_thresholding.py
+++ b/src/finn/core/multi_thresholding.py
@@ -2,38 +2,45 @@ import numpy as np
 
 
 def compare(x, y):
-    if x >= y:
-        return 1.0
-    else:
-        return 0.0
+    if x >= y:
+        return 1.0
+    else:
+        return 0.0
 
 
 def execute(v, thresholds):
-    # reshape inputs to enable channel-wise reading
-    vr = v.reshape((thresholds.shape[1], -1))
-
-    # calculate the channelinterval for the for loops
-    num_channels = thresholds.shape[0]
-    channel_interval = int(vr.shape[1] / num_channels)
-
-    # initiate output tensor
-    ret = np.zeros_like(vr)
-
-    # initiate helper variable i for channel-wise thresholding
-    i = -1
-
-    # iterate over thresholds channel-wise
-    for t in thresholds:
-        print(t)
-        i += 1
-
-        # calculate the lower and upper limit in which elements belong to one channel
-        ce1_low_lim = i * channel_interval
-        ce1_up_lim = (i + 1) * channel_interval
-
-        # iterate in ascending order over the thresholds belonging to one channel
-        for c in range(thresholds.shape[1]):
-            for ce0 in range(vr.shape[0]):
-                for ce1 in range(ce1_low_lim, ce1_up_lim):
-                    ret[ce0][ce1] += compare(vr[ce0][ce1], t[c])
-    return ret.reshape(v.shape)
+
+    # the inputs are expected to be in the shape (N, C, H, W)
+    # N : Batch size
+    # C : Number of channels
+    # H : Height of the input images
+    # W : Width of the input images
+    #
+    # the thresholds are expected to be in the shape (C, B)
+    # C : Number of channels (must be the same value as C in the input tensor)
+    # B : Desired number of activation steps, i.e. B = 2^n - 1 for n-bit activation (e.g. B = 7 for 3-bit)
+
+    # the channel count of the input must match the channel count of the thresholds
+    assert v.shape[1] == thresholds.shape[0]
+
+    num_batch = v.shape[0]
+    num_channel = v.shape[1]
+
+    num_act = thresholds.shape[1]
+
+    # reshape inputs to enable channel-wise reading
+    vr = v.reshape((v.shape[0], v.shape[1], -1))
+
+    num_img_elem = vr.shape[2]
+
+    # initialize output tensor
+    ret = np.zeros_like(vr)
+
+    # iterate over thresholds channel-wise
+    for t in range(num_channel):
+        channel_thresh = thresholds[t]
+        for b in range(num_batch):
+            for elem in range(num_img_elem):
+                for a in range(num_act):
+                    ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a])
+    return ret.reshape(v.shape)
diff --git a/tests/test_multi_thresholding.py b/tests/test_multi_thresholding.py
deleted file mode 100644
index 7e1033cf8dbe38ce7269be685cc66bd7d1d877d0..0000000000000000000000000000000000000000
--- a/tests/test_multi_thresholding.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import numpy as np
-
-import finn.core.multi_thresholding as multi_thresh
-
-
-def test_execute_multi_thresholding():
-    inputs = np.genfromtxt(
-        "../src/finn/data/multi-thresholding/input.csv", delimiter=","
-    )
-    inputs = inputs.reshape(7, 3, 2, 2)
-
-    thresholds = np.genfromtxt(
-        "../src/finn/data/multi-thresholding/thresholds.csv", delimiter=","
-    )
-    thresholds = thresholds.reshape(3, 7)
-
-    outputs = np.genfromtxt(
-        "../src/finn/data/multi-thresholding/output.csv", delimiter=","
-    )
-    outputs = outputs.reshape(7, 3, 2, 2)
-
-    results = multi_thresh.execute(inputs, thresholds)
-
-    assert np.isclose(outputs, results, atol=1e-3).all()
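
Since the CSV-based regression test is deleted above, the snippet below is a minimal usage sketch of the new execute() interface, not a test from the repository: it substitutes seeded random data for the removed CSV fixtures, mirrors the shapes of the deleted test (input (7, 3, 2, 2), thresholds (3, 7)), and only checks output shape and value range rather than comparing against reference outputs.

import numpy as np

import finn.core.multi_thresholding as multi_thresh

# hypothetical stand-ins for the removed CSV fixtures, seeded for repeatability
rng = np.random.default_rng(42)
inputs = rng.standard_normal((7, 3, 2, 2))         # (N, C, H, W)
thresholds = np.sort(rng.standard_normal((3, 7)))  # (C, B), ascending along each channel row

results = multi_thresh.execute(inputs, thresholds)

# output keeps the input shape; each element counts how many of its channel's
# thresholds it reached, so values lie in [0, B]
assert results.shape == inputs.shape
assert results.min() >= 0.0 and results.max() <= thresholds.shape[1]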