Add update method to senter

This commit is contained in:
Matthew Honnibal 2020-09-12 19:55:12 +02:00
parent 949c36b876
commit 6fbb31a136

View File

@ -95,6 +95,49 @@ class SentenceRecognizer(Tagger):
else:
doc.c[j].sent_start = -1
def update(self, examples, *, drop=0., sgd=None, losses=None, set_annotations=False):
"""Learn from a batch of documents and gold-standard information,
updating the pipe's model. Delegates to predict and get_loss.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
set_annotations (bool): Whether or not to update the Example objects
with the predictions.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://nightly.spacy.io/api/tagger#update
"""
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
validate_examples(examples, "Tagger.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs.
return
if not any(eg.reference.is_sentenced for eg in examples):
# Handle cases where there are no tagged tokens in any docs.
return
set_dropout_rate(self.model, drop)
tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples])
for sc in tag_scores:
if self.model.ops.xp.isnan(sc.sum()):
raise ValueError(Errors.E940)
loss, d_tag_scores = self.get_loss(examples, tag_scores)
bp_tag_scores(d_tag_scores)
if sgd not in (None, False):
self.model.finish_update(sgd)
losses[self.name] += loss
if set_annotations:
docs = [eg.predicted for eg in examples]
self.set_annotations(docs, self._scores2guesses(tag_scores))
return losses
def get_loss(self, examples, scores):
"""Find the loss and gradient of loss for the batch of documents and
their predicted scores.