Change topk to sort descending

Shouldn't change correctness but is a little clearer
This commit is contained in:
Paul O'Leary McCann 2021-06-13 19:42:24 +09:00
parent d71198ed36
commit 96be7e8858

View File

@ -34,7 +34,7 @@ def topk(xp, arr, k, axis=None):
vals = xp.take_along_axis(arr, idxs, axis=1)
sidxs = xp.argsort(vals, axis=1)
sidxs = xp.argsort(-vals, axis=1)
# map these idxs back to the original
oidxs = xp.take_along_axis(idxs, sidxs, axis=1)
svals = xp.take_along_axis(vals, sidxs, axis=1)