mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-03 11:50:19 +03:00
only restore entities after loss calculation
This commit is contained in:
parent
2b0a3e2e40
commit
42188b30c5
|
@ -375,10 +375,6 @@ class EntityLinker(TrainablePipe):
|
||||||
|
|
||||||
sentence_encodings, bp_context = self.model.begin_update(docs)
|
sentence_encodings, bp_context = self.model.begin_update(docs)
|
||||||
|
|
||||||
# now restore the ents
|
|
||||||
for doc, old in zip(docs, old_ents):
|
|
||||||
doc.ents = old
|
|
||||||
|
|
||||||
loss, d_scores = self.get_loss(
|
loss, d_scores = self.get_loss(
|
||||||
sentence_encodings=sentence_encodings, examples=examples
|
sentence_encodings=sentence_encodings, examples=examples
|
||||||
)
|
)
|
||||||
|
@ -386,6 +382,12 @@ class EntityLinker(TrainablePipe):
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
losses[self.name] += loss
|
losses[self.name] += loss
|
||||||
|
|
||||||
|
# now restore the ents
|
||||||
|
assert len(docs) == len(old_ents)
|
||||||
|
for doc, old in zip(docs, old_ents):
|
||||||
|
doc.ents = old
|
||||||
|
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
|
def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user