mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Fix axis handling in topk
In practice this is only ever used with axis=1, so it wasn't causing issues, even though it was wrong.
This commit is contained in:
parent
f2e0e9dc28
commit
d74fa82c80
|
@ -26,18 +26,18 @@ def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def topk(xp, arr, k, axis=None):
|
def topk(xp, arr, k, axis=1):
|
||||||
"""Given and array and a k value, give the top values and idxs for each row."""
|
"""Given an array and a k value, give the top values and idxs for each row."""
|
||||||
|
|
||||||
part = xp.argpartition(arr, -k, axis=1)
|
part = xp.argpartition(arr, -k, axis=axis)
|
||||||
idxs = xp.flip(part)[:, :k]
|
idxs = xp.flip(part)[:, :k]
|
||||||
|
|
||||||
vals = xp.take_along_axis(arr, idxs, axis=1)
|
vals = xp.take_along_axis(arr, idxs, axis=axis)
|
||||||
|
|
||||||
sidxs = xp.argsort(-vals, axis=1)
|
sidxs = xp.argsort(-vals, axis=axis)
|
||||||
# map these idxs back to the original
|
# map these idxs back to the original
|
||||||
oidxs = xp.take_along_axis(idxs, sidxs, axis=1)
|
oidxs = xp.take_along_axis(idxs, sidxs, axis=axis)
|
||||||
svals = xp.take_along_axis(vals, sidxs, axis=1)
|
svals = xp.take_along_axis(vals, sidxs, axis=axis)
|
||||||
return svals, oidxs
|
return svals, oidxs
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user