Track loss in tagger

This commit is contained in:
Matthew Honnibal 2017-08-20 14:42:23 +02:00
parent 8875590081
commit c1d3ff517a

View File

@ -294,6 +294,8 @@ class NeuralTagger(BaseThincComponent):
doc.is_tagged = True
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
if losses is not None and self.name not in losses:
losses[self.name] = 0.
docs, tokvecs = docs_tokvecs
if self.model.nI is None:
@ -302,6 +304,8 @@ class NeuralTagger(BaseThincComponent):
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)
if losses is not None:
losses[self.name] += loss
return d_tokvecs
def get_loss(self, docs, golds, scores):