mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
one more losses fix
This commit is contained in:
parent
478a14a619
commit
44e14ccae8
|
@ -227,10 +227,13 @@ class Tagger(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/tagger#rehearse
|
DOCS: https://nightly.spacy.io/api/tagger#rehearse
|
||||||
"""
|
"""
|
||||||
|
if losses is None:
|
||||||
|
losses = {}
|
||||||
|
losses.setdefault(self.name, 0.0)
|
||||||
validate_examples(examples, "Tagger.rehearse")
|
validate_examples(examples, "Tagger.rehearse")
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
if self._rehearsal_model is None:
|
if self._rehearsal_model is None:
|
||||||
return
|
return losses
|
||||||
if not any(len(doc) for doc in docs):
|
if not any(len(doc) for doc in docs):
|
||||||
# Handle cases where there are no tokens in any docs.
|
# Handle cases where there are no tokens in any docs.
|
||||||
return losses
|
return losses
|
||||||
|
@ -240,9 +243,7 @@ class Tagger(TrainablePipe):
|
||||||
gradient = guesses - target
|
gradient = guesses - target
|
||||||
backprop(gradient)
|
backprop(gradient)
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
if losses is not None:
|
losses[self.name] += (gradient**2).sum()
|
||||||
losses.setdefault(self.name, 0.0)
|
|
||||||
losses[self.name] += (gradient**2).sum()
|
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def get_loss(self, examples, scores):
|
def get_loss(self, examples, scores):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user