mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 12:20:20 +03:00
Changes based on review comments
This commit is contained in:
parent
361b64e648
commit
4daf5e9b81
|
@ -5,8 +5,8 @@ from itertools import islice
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import srsly
|
import srsly
|
||||||
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
|
from thinc.api import Config, Model, SequenceCategoricalCrossentropy, NumpyOps
|
||||||
from thinc.types import Floats2d, Ints1d, Ints2d
|
from thinc.types import Floats2d, Ints2d
|
||||||
|
|
||||||
from ._edit_tree_internals.edit_trees import EditTrees
|
from ._edit_tree_internals.edit_trees import EditTrees
|
||||||
from ._edit_tree_internals.schemas import validate_edit_tree
|
from ._edit_tree_internals.schemas import validate_edit_tree
|
||||||
|
@ -115,6 +115,7 @@ class EditTreeLemmatizer(TrainablePipe):
|
||||||
|
|
||||||
self.cfg: Dict[str, Any] = {"labels": []}
|
self.cfg: Dict[str, Any] = {"labels": []}
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
self.numpy_ops = NumpyOps()
|
||||||
|
|
||||||
def get_loss(
|
def get_loss(
|
||||||
self, examples: Iterable[Example], scores: List[Floats2d]
|
self, examples: Iterable[Example], scores: List[Floats2d]
|
||||||
|
@ -159,17 +160,18 @@ class EditTreeLemmatizer(TrainablePipe):
|
||||||
|
|
||||||
def _scores2guesses(self, docs, scores):
|
def _scores2guesses(self, docs, scores):
|
||||||
guesses = []
|
guesses = []
|
||||||
|
predictions_to_consider = min(self.top_k, len(self.labels))
|
||||||
for doc, doc_scores in zip(docs, scores):
|
for doc, doc_scores in zip(docs, scores):
|
||||||
NumpyOps().asarray(doc_scores)
|
doc_scores = self.numpy_ops.asarray(doc_scores)
|
||||||
doc_compat_guesses = []
|
doc_compat_guesses = []
|
||||||
for i, token in enumerate(doc):
|
for i, token in enumerate(doc):
|
||||||
for _ in range(self.top_k):
|
for _ in range(predictions_to_consider):
|
||||||
candidate = int(doc_scores[i].argmax())
|
candidate = int(doc_scores[i].argmax())
|
||||||
candidate_tree_id = self.cfg["labels"][candidate]
|
candidate_tree_id = self.cfg["labels"][candidate]
|
||||||
if self.trees.apply(candidate_tree_id, token.text) is not None:
|
if self.trees.apply(candidate_tree_id, token.text) is not None:
|
||||||
doc_compat_guesses.append(candidate_tree_id)
|
doc_compat_guesses.append(candidate_tree_id)
|
||||||
break
|
break
|
||||||
doc_scores[i, candidate] = -1
|
doc_scores[i, candidate] = np.finfo(np.float32).min
|
||||||
else:
|
else:
|
||||||
doc_compat_guesses.append(-1)
|
doc_compat_guesses.append(-1)
|
||||||
guesses.append(np.array(doc_compat_guesses))
|
guesses.append(np.array(doc_compat_guesses))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user