mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
addressing suggestions by @polm
This commit is contained in:
parent
e4b4b67ef6
commit
06d680b269
|
@ -511,20 +511,24 @@ class SpanPredictor(TrainablePipe):
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
|
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
docs = [eg.predicted for eg in examples]
|
old_spans = [eg.predicted.spans for eg in examples]
|
||||||
for doc, eg in zip(docs, examples):
|
for eg in examples:
|
||||||
# replicates the EntityLinker's behaviour and
|
# replicates the EntityLinker's behaviour and
|
||||||
# copies annotations over https://bit.ly/3iweDcW
|
# copies annotations over https://bit.ly/3iweDcW
|
||||||
# takes 'coref_head_clusters' from the reference.
|
# https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313
|
||||||
|
doc = eg.predicted
|
||||||
for key, sg in eg.reference.spans.items():
|
for key, sg in eg.reference.spans.items():
|
||||||
if key.startswith(self.input_prefix):
|
if key.startswith(self.input_prefix):
|
||||||
aligned_spans = eg.get_aligned_spans_x2y(sg)
|
doc.spans[key] = eg.get_aligned_spans_y2x(sg)
|
||||||
doc.spans[key] = [doc[span.start:span.end] for span in aligned_spans]
|
|
||||||
span_scores, backprop = self.model.begin_update([doc])
|
span_scores, backprop = self.model.begin_update([doc])
|
||||||
loss, d_scores = self.get_loss([eg], span_scores)
|
loss, d_scores = self.get_loss([eg], span_scores)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
# TODO check shape here
|
# TODO check shape here
|
||||||
backprop(d_scores)
|
backprop(d_scores)
|
||||||
|
# Restore examples
|
||||||
|
for spans, eg in zip(old_spans, examples):
|
||||||
|
for key, sg in spans.items():
|
||||||
|
eg.predicted.spans[key] = sg
|
||||||
|
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user