addressing suggestions by @polm

This commit is contained in:
Kádár Ákos 2022-03-28 14:31:51 +02:00
parent e4b4b67ef6
commit 06d680b269

View File

@ -511,20 +511,24 @@ class SpanPredictor(TrainablePipe):
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
total_loss = 0 total_loss = 0
docs = [eg.predicted for eg in examples] old_spans = [eg.predicted.spans for eg in examples]
for doc, eg in zip(docs, examples): for eg in examples:
# replicates the EntityLinker's behaviour and # replicates the EntityLinker's behaviour and
# copies annotations over https://bit.ly/3iweDcW # copies annotations over https://bit.ly/3iweDcW
# takes 'coref_head_clusters' from the reference. # https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313
doc = eg.predicted
for key, sg in eg.reference.spans.items(): for key, sg in eg.reference.spans.items():
if key.startswith(self.input_prefix): if key.startswith(self.input_prefix):
aligned_spans = eg.get_aligned_spans_x2y(sg) doc.spans[key] = eg.get_aligned_spans_y2x(sg)
doc.spans[key] = [doc[span.start:span.end] for span in aligned_spans]
span_scores, backprop = self.model.begin_update([doc]) span_scores, backprop = self.model.begin_update([doc])
loss, d_scores = self.get_loss([eg], span_scores) loss, d_scores = self.get_loss([eg], span_scores)
total_loss += loss total_loss += loss
# TODO check shape here # TODO check shape here
backprop(d_scores) backprop(d_scores)
# Restore examples
for spans, eg in zip(old_spans, examples):
for key, sg in spans.items():
eg.predicted.spans[key] = sg
if sgd is not None: if sgd is not None:
self.finish_update(sgd) self.finish_update(sgd)