Changes based on review comments

This commit is contained in:
richardpaulhudson 2022-12-23 15:36:45 +01:00
parent 361b64e648
commit 4daf5e9b81

View File

@@ -5,8 +5,8 @@ from itertools import islice
import numpy as np import numpy as np
import srsly import srsly
from thinc.api import Config, Model, SequenceCategoricalCrossentropy from thinc.api import Config, Model, SequenceCategoricalCrossentropy, NumpyOps
from thinc.types import Floats2d, Ints1d, Ints2d from thinc.types import Floats2d, Ints2d
from ._edit_tree_internals.edit_trees import EditTrees from ._edit_tree_internals.edit_trees import EditTrees
from ._edit_tree_internals.schemas import validate_edit_tree from ._edit_tree_internals.schemas import validate_edit_tree
@@ -115,6 +115,7 @@ class EditTreeLemmatizer(TrainablePipe):
self.cfg: Dict[str, Any] = {"labels": []} self.cfg: Dict[str, Any] = {"labels": []}
self.scorer = scorer self.scorer = scorer
self.numpy_ops = NumpyOps()
def get_loss( def get_loss(
self, examples: Iterable[Example], scores: List[Floats2d] self, examples: Iterable[Example], scores: List[Floats2d]
@@ -159,17 +160,18 @@ class EditTreeLemmatizer(TrainablePipe):
def _scores2guesses(self, docs, scores): def _scores2guesses(self, docs, scores):
guesses = [] guesses = []
predictions_to_consider = min(self.top_k, len(self.labels))
for doc, doc_scores in zip(docs, scores): for doc, doc_scores in zip(docs, scores):
NumpyOps().asarray(doc_scores) doc_scores = self.numpy_ops.asarray(doc_scores)
doc_compat_guesses = [] doc_compat_guesses = []
for i, token in enumerate(doc): for i, token in enumerate(doc):
for _ in range(self.top_k): for _ in range(predictions_to_consider):
candidate = int(doc_scores[i].argmax()) candidate = int(doc_scores[i].argmax())
candidate_tree_id = self.cfg["labels"][candidate] candidate_tree_id = self.cfg["labels"][candidate]
if self.trees.apply(candidate_tree_id, token.text) is not None: if self.trees.apply(candidate_tree_id, token.text) is not None:
doc_compat_guesses.append(candidate_tree_id) doc_compat_guesses.append(candidate_tree_id)
break break
doc_scores[i, candidate] = -1 doc_scores[i, candidate] = np.finfo(np.float32).min
else: else:
doc_compat_guesses.append(-1) doc_compat_guesses.append(-1)
guesses.append(np.array(doc_compat_guesses)) guesses.append(np.array(doc_compat_guesses))