Convert argmax result to raw integer

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
Richard Hudson 2022-12-22 15:36:14 +01:00 committed by GitHub
parent 9430306076
commit 95a4835342
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -165,7 +165,7 @@ class EditTreeLemmatizer(TrainablePipe):
doc_compat_guesses = [] doc_compat_guesses = []
for i, token in enumerate(doc): for i, token in enumerate(doc):
for _ in range(self.top_k): for _ in range(self.top_k):
candidate = doc_scores[i].argmax() candidate = int(doc_scores[i].argmax())
candidate_tree_id = self.cfg["labels"][candidate] candidate_tree_id = self.cfg["labels"][candidate]
if self.trees.apply(candidate_tree_id, token.text) is not None: if self.trees.apply(candidate_tree_id, token.text) is not None:
doc_compat_guesses.append(candidate_tree_id) doc_compat_guesses.append(candidate_tree_id)