diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py index 56b238c2f..e045ad31b 100644 --- a/spacy/ml/models/coref_util.py +++ b/spacy/ml/models/coref_util.py @@ -26,18 +26,18 @@ def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters: return out -def topk(xp, arr, k, axis=None): - """Given and array and a k value, give the top values and idxs for each row.""" +def topk(xp, arr, k, axis=1): + """Given an array and a k value, give the top values and idxs for each row.""" - part = xp.argpartition(arr, -k, axis=1) + part = xp.argpartition(arr, -k, axis=axis) idxs = xp.flip(part)[:, :k] - vals = xp.take_along_axis(arr, idxs, axis=1) + vals = xp.take_along_axis(arr, idxs, axis=axis) - sidxs = xp.argsort(-vals, axis=1) + sidxs = xp.argsort(-vals, axis=axis) # map these idxs back to the original - oidxs = xp.take_along_axis(idxs, sidxs, axis=1) - svals = xp.take_along_axis(vals, sidxs, axis=1) + oidxs = xp.take_along_axis(idxs, sidxs, axis=axis) + svals = xp.take_along_axis(vals, sidxs, axis=axis) return svals, oidxs