mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-19 08:31:59 +03:00
Add update method to senter
This commit is contained in:
parent
949c36b876
commit
6fbb31a136
|
@ -95,6 +95,49 @@ class SentenceRecognizer(Tagger):
|
|||
else:
|
||||
doc.c[j].sent_start = -1
|
||||
|
||||
def update(self, examples, *, drop=0., sgd=None, losses=None, set_annotations=False):
    """Learn from a batch of documents and gold-standard information,
    updating the pipe's model. Delegates to the model's begin_update and
    to get_loss.

    examples (Iterable[Example]): A batch of Example objects.
    drop (float): The dropout rate.
    set_annotations (bool): Whether or not to update the Example objects
        with the predictions.
    sgd (thinc.api.Optimizer): The optimizer.
    losses (Dict[str, float]): Optional record of the loss during training.
        Updated using the component name as the key.
    RETURNS (Dict[str, float]): The updated losses dictionary. It is also
        returned (with this component's entry initialised) when the batch
        is skipped because it has no tokens or no sentence annotations.
    RAISES (ValueError): Errors.E940 if any of the score arrays produced by
        the model sums to NaN.

    DOCS: https://nightly.spacy.io/api/sentencerecognizer#update
    """
    if losses is None:
        losses = {}
    losses.setdefault(self.name, 0.0)
    validate_examples(examples, "SentenceRecognizer.update")
    if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
        # Handle cases where there are no tokens in any docs. Return the
        # losses dict rather than None so callers can rely on the
        # documented return type.
        return losses
    if not any(eg.reference.is_sentenced for eg in examples):
        # Handle cases where none of the reference docs carry sentence
        # boundary annotations: there is nothing to learn from.
        return losses
    set_dropout_rate(self.model, drop)
    docs = [eg.predicted for eg in examples]
    tag_scores, bp_tag_scores = self.model.begin_update(docs)
    for sc in tag_scores:
        # Fail before backprop if the forward pass produced NaN scores.
        if self.model.ops.xp.isnan(sc.sum()):
            raise ValueError(Errors.E940)
    loss, d_tag_scores = self.get_loss(examples, tag_scores)
    bp_tag_scores(d_tag_scores)
    # Only apply the accumulated gradients when an optimizer was supplied.
    if sgd not in (None, False):
        self.model.finish_update(sgd)

    losses[self.name] += loss
    if set_annotations:
        self.set_annotations(docs, self._scores2guesses(tag_scores))
    return losses
|
||||
|
||||
|
||||
|
||||
def get_loss(self, examples, scores):
|
||||
"""Find the loss and gradient of loss for the batch of documents and
|
||||
their predicted scores.
|
||||
|
|
Loading…
Reference in New Issue
Block a user