From 4daf5e9b81313bcea1e4837f37ffea091c03a114 Mon Sep 17 00:00:00 2001
From: richardpaulhudson
Date: Fri, 23 Dec 2022 15:36:45 +0100
Subject: [PATCH] Changes based on review comments

---
 spacy/pipeline/edit_tree_lemmatizer.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py
index 9c5ca0846..1483096d1 100644
--- a/spacy/pipeline/edit_tree_lemmatizer.py
+++ b/spacy/pipeline/edit_tree_lemmatizer.py
@@ -5,8 +5,8 @@ from itertools import islice
 import numpy as np
 
 import srsly
-from thinc.api import Config, Model, SequenceCategoricalCrossentropy
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.api import Config, Model, SequenceCategoricalCrossentropy, NumpyOps
+from thinc.types import Floats2d, Ints2d
 
 from ._edit_tree_internals.edit_trees import EditTrees
 from ._edit_tree_internals.schemas import validate_edit_tree
@@ -115,6 +115,7 @@ class EditTreeLemmatizer(TrainablePipe):
 
         self.cfg: Dict[str, Any] = {"labels": []}
         self.scorer = scorer
+        self.numpy_ops = NumpyOps()
 
     def get_loss(
         self, examples: Iterable[Example], scores: List[Floats2d]
@@ -159,17 +160,18 @@ class EditTreeLemmatizer(TrainablePipe):
 
     def _scores2guesses(self, docs, scores):
         guesses = []
+        predictions_to_consider = min(self.top_k, len(self.labels))
         for doc, doc_scores in zip(docs, scores):
-            NumpyOps().asarray(doc_scores)
+            doc_scores = self.numpy_ops.asarray(doc_scores)
             doc_compat_guesses = []
             for i, token in enumerate(doc):
-                for _ in range(self.top_k):
+                for _ in range(predictions_to_consider):
                     candidate = int(doc_scores[i].argmax())
                     candidate_tree_id = self.cfg["labels"][candidate]
                     if self.trees.apply(candidate_tree_id, token.text) is not None:
                         doc_compat_guesses.append(candidate_tree_id)
                         break
-                    doc_scores[i, candidate] = -1
+                    doc_scores[i, candidate] = np.finfo(np.float32).min
                 else:
                     doc_compat_guesses.append(-1)
             guesses.append(np.array(doc_compat_guesses))