From 7304604edd6238d16f156b3f30db40d809f1a440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A1d=C3=A1r=20=C3=81kos?= Date: Fri, 25 Mar 2022 18:29:33 +0100 Subject: [PATCH] make sure predicted and reference stay aligned --- spacy/pipeline/coref.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index eb05011ec..99bb611ff 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -130,7 +130,6 @@ class CoreferenceResolver(TrainablePipe): DOCS: https://spacy.io/api/coref#predict (TODO) """ - #print("DOCS", docs) out = [] for doc in docs: scores, idxs = self.model.predict([doc]) @@ -212,7 +211,6 @@ class CoreferenceResolver(TrainablePipe): # TODO check this causes no issues (in practice it runs) preds, backprop = self.model.begin_update([eg.predicted]) score_matrix, mention_idx = preds - loss, d_scores = self.get_loss([eg], score_matrix, mention_idx) total_loss += loss # TODO check shape here @@ -518,7 +516,8 @@ class SpanPredictor(TrainablePipe): # takes 'coref_head_clusters' from the reference. for key, sg in eg.reference.spans.items(): if key.startswith(self.input_prefix): - doc.spans[key] = [doc[span.start:span.end] for span in sg] + aligned_spans = eg.get_aligned_spans_x2y(sg) + doc.spans[key] = [doc[span.start:span.end] for span in aligned_spans] span_scores, backprop = self.model.begin_update([doc]) loss, d_scores = self.get_loss([eg], span_scores) total_loss += loss @@ -600,7 +599,7 @@ class SpanPredictor(TrainablePipe): *, nlp: Optional[Language] = None, ) -> None: - validate_get_examples(get_examples, "CoreferenceResolver.initialize") + validate_get_examples(get_examples, "SpanPredictor.initialize") X = [] Y = []