diff --git a/src/finn/util/test.py b/src/finn/util/test.py index 3cd4248c5fbf438ac7dd7974adb38d251d389a07..e070f8d89d668f3d068cb75df83e5700a24e49b0 100644 --- a/src/finn/util/test.py +++ b/src/finn/util/test.py @@ -77,6 +77,11 @@ def get_test_model_untrained(netname, wbits, abits): return get_test_model(netname, wbits, abits, pretrained=False) +def get_topk(vec, k): + "Return indices of the top-k values in given array vec (treated as 1D)." + return np.flip(vec.flatten().argsort())[:k] + + def soft_verify_topk(invec, idxvec, k): """Check that the topK indices provided actually point to the topK largest values in the input vector"""