diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx
index 00664131b..fc05c9c82 100644
--- a/spacy/pipeline/senter.pyx
+++ b/spacy/pipeline/senter.pyx
@@ -95,6 +95,49 @@ class SentenceRecognizer(Tagger):
                     else:
                         doc.c[j].sent_start = -1
 
+    def update(self, examples, *, drop=0., sgd=None, losses=None, set_annotations=False):
+        """Learn from a batch of documents and gold-standard information,
+        updating the pipe's model. Delegates to predict and get_loss.
+
+        examples (Iterable[Example]): A batch of Example objects.
+        drop (float): The dropout rate.
+        set_annotations (bool): Whether or not to update the Example objects
+            with the predictions.
+        sgd (thinc.api.Optimizer): The optimizer.
+        losses (Dict[str, float]): Optional record of the loss during training.
+            Updated using the component name as the key.
+        RETURNS (Dict[str, float]): The updated losses dictionary.
+
+        DOCS: https://nightly.spacy.io/api/sentencerecognizer#update
+        """
+        if losses is None:
+            losses = {}
+        losses.setdefault(self.name, 0.0)
+        validate_examples(examples, "SentenceRecognizer.update")
+        if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
+            # Handle cases where there are no tokens in any docs.
+            return losses
+        if not any(eg.reference.is_sentenced for eg in examples):
+            # Handle cases where there are no annotated sentence boundaries in any docs.
+            return losses
+        set_dropout_rate(self.model, drop)
+        tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples])
+        for sc in tag_scores:
+            if self.model.ops.xp.isnan(sc.sum()):
+                raise ValueError(Errors.E940)
+        loss, d_tag_scores = self.get_loss(examples, tag_scores)
+        bp_tag_scores(d_tag_scores)
+        if sgd not in (None, False):
+            self.model.finish_update(sgd)
+
+        losses[self.name] += loss
+        if set_annotations:
+            docs = [eg.predicted for eg in examples]
+            self.set_annotations(docs, self._scores2guesses(tag_scores))
+        return losses
+
     def get_loss(self, examples, scores):
         """Find the loss and gradient of loss for the batch of documents and their
         predicted scores.
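
For reference, a minimal usage sketch (not part of the diff) of how the overridden update could be exercised on its own. The example text, the sent_starts values, and the variable names are illustrative assumptions; the surrounding calls (spacy.blank, add_pipe("senter"), Example.from_dict, nlp.initialize, the pipe's update) assume spaCy's v3 training API.

    import spacy
    from spacy.training import Example

    nlp = spacy.blank("en")
    senter = nlp.add_pipe("senter")

    # One gold-annotated Example: 1 marks a sentence-initial token, 0 any other token.
    doc = nlp.make_doc("Hello there. How are you?")  # illustrative text
    example = Example.from_dict(doc, {"sent_starts": [1, 0, 0, 1, 0, 0, 0]})
    examples = [example]

    # initialize() sets up the model from the examples and returns an optimizer.
    optimizer = nlp.initialize(get_examples=lambda: examples)

    # Delegates to predict and get_loss; returns the updated losses dict.
    losses = senter.update(examples, sgd=optimizer, losses={})
    print(losses)  # e.g. {"senter": 1.2}

In regular training this path is reached via nlp.update, which calls each component's update internally; calling it directly on the pipe is mainly useful for tests and debugging.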