Refactor _scores2guesses

This commit is contained in:
richardpaulhudson 2022-12-22 13:50:53 +01:00
parent eef3d950b4
commit 278181ab59

View File

@ -160,27 +160,18 @@ class EditTreeLemmatizer(TrainablePipe):
def _scores2guesses(self, docs, scores): def _scores2guesses(self, docs, scores):
guesses = [] guesses = []
for doc, doc_scores in zip(docs, scores): for doc, doc_scores in zip(docs, scores):
if self.top_k == 1:
doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1)
else:
doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1]
if not isinstance(doc_guesses, np.ndarray):
doc_guesses = doc_guesses.get()
doc_compat_guesses = [] doc_compat_guesses = []
for token, candidates in zip(doc, doc_guesses): for i, token in enumerate(doc):
tree_id = -1 for _ in range(self.top_k):
for candidate in candidates: candidate = doc_scores[i].argmax()
candidate_tree_id = self.cfg["labels"][candidate] candidate_tree_id = self.cfg["labels"][candidate]
if self.trees.apply(candidate_tree_id, token.text) is not None: if self.trees.apply(candidate_tree_id, token.text) is not None:
tree_id = candidate_tree_id doc_compat_guesses.append(candidate_tree_id)
break break
doc_compat_guesses.append(tree_id) doc_scores[i, candidate] = -1
else:
doc_compat_guesses.append(-1)
guesses.append(np.array(doc_compat_guesses)) guesses.append(np.array(doc_compat_guesses))
return guesses return guesses
def set_annotations(self, docs: Iterable[Doc], batch_tree_ids): def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):