diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index a56c9975e..c65770955 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -160,27 +160,18 @@ class EditTreeLemmatizer(TrainablePipe): def _scores2guesses(self, docs, scores): guesses = [] for doc, doc_scores in zip(docs, scores): - if self.top_k == 1: - doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1) - else: - doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1] - - if not isinstance(doc_guesses, np.ndarray): - doc_guesses = doc_guesses.get() - doc_compat_guesses = [] - for token, candidates in zip(doc, doc_guesses): - tree_id = -1 - for candidate in candidates: + for i, token in enumerate(doc): + for _ in range(self.top_k): + candidate = doc_scores[i].argmax() candidate_tree_id = self.cfg["labels"][candidate] - if self.trees.apply(candidate_tree_id, token.text) is not None: - tree_id = candidate_tree_id + doc_compat_guesses.append(candidate_tree_id) break - doc_compat_guesses.append(tree_id) - + doc_scores[i, candidate] = -1 + else: + doc_compat_guesses.append(-1) guesses.append(np.array(doc_compat_guesses)) - return guesses def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):