mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Track loss in tagger
This commit is contained in:
parent
8875590081
commit
c1d3ff517a
|
@ -294,6 +294,8 @@ class NeuralTagger(BaseThincComponent):
|
||||||
doc.is_tagged = True
|
doc.is_tagged = True
|
||||||
|
|
||||||
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
|
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
|
||||||
|
if losses is not None and self.name not in losses:
|
||||||
|
losses[self.name] = 0.
|
||||||
docs, tokvecs = docs_tokvecs
|
docs, tokvecs = docs_tokvecs
|
||||||
|
|
||||||
if self.model.nI is None:
|
if self.model.nI is None:
|
||||||
|
@ -302,6 +304,8 @@ class NeuralTagger(BaseThincComponent):
|
||||||
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
|
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
|
||||||
|
|
||||||
d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)
|
d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)
|
||||||
|
if losses is not None:
|
||||||
|
losses[self.name] += loss
|
||||||
return d_tokvecs
|
return d_tokvecs
|
||||||
|
|
||||||
def get_loss(self, docs, golds, scores):
|
def get_loss(self, docs, golds, scores):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user