Add update method to senter

This commit is contained in:
Matthew Honnibal 2020-09-12 19:55:12 +02:00
parent 949c36b876
commit 6fbb31a136

View File

@ -95,6 +95,49 @@ class SentenceRecognizer(Tagger):
else: else:
doc.c[j].sent_start = -1 doc.c[j].sent_start = -1
def update(self, examples, *, drop=0., sgd=None, losses=None, set_annotations=False):
"""Learn from a batch of documents and gold-standard information,
updating the pipe's model. Delegates to predict and get_loss.
examples (Iterable[Example]): A batch of Example objects.
drop (float): The dropout rate.
set_annotations (bool): Whether or not to update the Example objects
with the predictions.
sgd (thinc.api.Optimizer): The optimizer.
losses (Dict[str, float]): Optional record of the loss during training.
Updated using the component name as the key.
RETURNS (Dict[str, float]): The updated losses dictionary.
DOCS: https://nightly.spacy.io/api/tagger#update
"""
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
validate_examples(examples, "Tagger.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs.
return
if not any(eg.reference.is_sentenced for eg in examples):
# Handle cases where there are no tagged tokens in any docs.
return
set_dropout_rate(self.model, drop)
tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples])
for sc in tag_scores:
if self.model.ops.xp.isnan(sc.sum()):
raise ValueError(Errors.E940)
loss, d_tag_scores = self.get_loss(examples, tag_scores)
bp_tag_scores(d_tag_scores)
if sgd not in (None, False):
self.model.finish_update(sgd)
losses[self.name] += loss
if set_annotations:
docs = [eg.predicted for eg in examples]
self.set_annotations(docs, self._scores2guesses(tag_scores))
return losses
def get_loss(self, examples, scores): def get_loss(self, examples, scores):
"""Find the loss and gradient of loss for the batch of documents and """Find the loss and gradient of loss for the batch of documents and
their predicted scores. their predicted scores.