From 06d680b269c87059ca1fd0381f025a2bcc60c5ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A1d=C3=A1r=20=C3=81kos?= Date: Mon, 28 Mar 2022 14:31:51 +0200 Subject: [PATCH] addressing suggestions by @polm --- spacy/pipeline/coref.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index 5a4fa1ab9..340dde470 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -511,20 +511,24 @@ class SpanPredictor(TrainablePipe): set_dropout_rate(self.model, drop) total_loss = 0 - docs = [eg.predicted for eg in examples] - for doc, eg in zip(docs, examples): + old_spans = [eg.predicted.spans for eg in examples] + for eg in examples: # replicates the EntityLinker's behaviour and # copies annotations over https://bit.ly/3iweDcW - # takes 'coref_head_clusters' from the reference. + # https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313 + doc = eg.predicted for key, sg in eg.reference.spans.items(): if key.startswith(self.input_prefix): - aligned_spans = eg.get_aligned_spans_x2y(sg) - doc.spans[key] = [doc[span.start:span.end] for span in aligned_spans] + doc.spans[key] = eg.get_aligned_spans_y2x(sg) span_scores, backprop = self.model.begin_update([doc]) loss, d_scores = self.get_loss([eg], span_scores) total_loss += loss # TODO check shape here backprop(d_scores) + # Restore examples + for spans, eg in zip(old_spans, examples): + for key, sg in spans.items(): + eg.predicted.spans[key] = sg if sgd is not None: self.finish_update(sgd)