From d74fa82c80a3ccdd6f78fbf02c824d82e5e7e2e8 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Sat, 3 Jul 2021 18:39:25 +0900 Subject: [PATCH] Fix axis handling in topk In practice this is only ever used with axis=1, so it wasn't causing issues, even though it was wrong. --- spacy/ml/models/coref_util.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py index 56b238c2f..e045ad31b 100644 --- a/spacy/ml/models/coref_util.py +++ b/spacy/ml/models/coref_util.py @@ -26,18 +26,18 @@ def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters: return out -def topk(xp, arr, k, axis=None): - """Given and array and a k value, give the top values and idxs for each row.""" +def topk(xp, arr, k, axis=1): + """Given an array and a k value, give the top values and idxs for each row.""" - part = xp.argpartition(arr, -k, axis=1) + part = xp.argpartition(arr, -k, axis=axis) idxs = xp.flip(part)[:, :k] - vals = xp.take_along_axis(arr, idxs, axis=1) + vals = xp.take_along_axis(arr, idxs, axis=axis) - sidxs = xp.argsort(-vals, axis=1) + sidxs = xp.argsort(-vals, axis=axis) # map these idxs back to the original - oidxs = xp.take_along_axis(idxs, sidxs, axis=1) - svals = xp.take_along_axis(vals, sidxs, axis=1) + oidxs = xp.take_along_axis(idxs, sidxs, axis=axis) + svals = xp.take_along_axis(vals, sidxs, axis=axis) return svals, oidxs