Changes based on review comments

This commit is contained in:
richardpaulhudson 2022-12-23 15:36:45 +01:00
parent 361b64e648
commit 4daf5e9b81

View File

@ -5,8 +5,8 @@ from itertools import islice
import numpy as np
import srsly
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
from thinc.types import Floats2d, Ints1d, Ints2d
from thinc.api import Config, Model, SequenceCategoricalCrossentropy, NumpyOps
from thinc.types import Floats2d, Ints2d
from ._edit_tree_internals.edit_trees import EditTrees
from ._edit_tree_internals.schemas import validate_edit_tree
@ -115,6 +115,7 @@ class EditTreeLemmatizer(TrainablePipe):
self.cfg: Dict[str, Any] = {"labels": []}
self.scorer = scorer
self.numpy_ops = NumpyOps()
def get_loss(
self, examples: Iterable[Example], scores: List[Floats2d]
@ -159,17 +160,18 @@ class EditTreeLemmatizer(TrainablePipe):
def _scores2guesses(self, docs, scores):
guesses = []
predictions_to_consider = min(self.top_k, len(self.labels))
for doc, doc_scores in zip(docs, scores):
NumpyOps().asarray(doc_scores)
doc_scores = self.numpy_ops.asarray(doc_scores)
doc_compat_guesses = []
for i, token in enumerate(doc):
for _ in range(self.top_k):
for _ in range(predictions_to_consider):
candidate = int(doc_scores[i].argmax())
candidate_tree_id = self.cfg["labels"][candidate]
if self.trees.apply(candidate_tree_id, token.text) is not None:
doc_compat_guesses.append(candidate_tree_id)
break
doc_scores[i, candidate] = -1
doc_scores[i, candidate] = np.finfo(np.float32).min
else:
doc_compat_guesses.append(-1)
guesses.append(np.array(doc_compat_guesses))