mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Fix axis handling in topk
In practice this is only ever used with axis=1, so it wasn't causing issues, even though it was wrong.
This commit is contained in:
parent
f2e0e9dc28
commit
d74fa82c80
|
@ -26,18 +26,18 @@ def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
|
|||
return out
|
||||
|
||||
|
||||
def topk(xp, arr, k, axis=None):
|
||||
"""Given and array and a k value, give the top values and idxs for each row."""
|
||||
def topk(xp, arr, k, axis=1):
|
||||
"""Given an array and a k value, give the top values and idxs for each row."""
|
||||
|
||||
part = xp.argpartition(arr, -k, axis=1)
|
||||
part = xp.argpartition(arr, -k, axis=axis)
|
||||
idxs = xp.flip(part)[:, :k]
|
||||
|
||||
vals = xp.take_along_axis(arr, idxs, axis=1)
|
||||
vals = xp.take_along_axis(arr, idxs, axis=axis)
|
||||
|
||||
sidxs = xp.argsort(-vals, axis=1)
|
||||
sidxs = xp.argsort(-vals, axis=axis)
|
||||
# map these idxs back to the original
|
||||
oidxs = xp.take_along_axis(idxs, sidxs, axis=1)
|
||||
svals = xp.take_along_axis(vals, sidxs, axis=1)
|
||||
oidxs = xp.take_along_axis(idxs, sidxs, axis=axis)
|
||||
svals = xp.take_along_axis(vals, sidxs, axis=axis)
|
||||
return svals, oidxs
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user