Fix axis handling in topk

In practice this is only ever used with axis=1, so it wasn't causing
issues, even though it was wrong.
This commit is contained in:
Paul O'Leary McCann 2021-07-03 18:39:25 +09:00
parent f2e0e9dc28
commit d74fa82c80

View File

@ -26,18 +26,18 @@ def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
return out
def topk(xp, arr, k, axis=None):
"""Given and array and a k value, give the top values and idxs for each row."""
def topk(xp, arr, k, axis=1):
"""Given an array and a k value, give the top values and idxs for each row."""
part = xp.argpartition(arr, -k, axis=1)
part = xp.argpartition(arr, -k, axis=axis)
idxs = xp.flip(part)[:, :k]
vals = xp.take_along_axis(arr, idxs, axis=1)
vals = xp.take_along_axis(arr, idxs, axis=axis)
sidxs = xp.argsort(-vals, axis=1)
sidxs = xp.argsort(-vals, axis=axis)
# map these idxs back to the original
oidxs = xp.take_along_axis(idxs, sidxs, axis=1)
svals = xp.take_along_axis(vals, sidxs, axis=1)
oidxs = xp.take_along_axis(idxs, sidxs, axis=axis)
svals = xp.take_along_axis(vals, sidxs, axis=axis)
return svals, oidxs