From 3ba913109d27827639eaa2bf91c1693bed7f33f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A1d=C3=A1r=20=C3=81kos?= Date: Thu, 7 Apr 2022 13:20:12 +0200 Subject: [PATCH] update with eg.predited as other components --- spacy/pipeline/coref.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index c1db23d68..1b062ed9a 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -457,7 +457,6 @@ class SpanPredictor(TrainablePipe): for doc in docs: # TODO check shape here span_scores = self.model.predict([doc]) - print(span_scores) if span_scores.size: # the information about clustering has to come from the input docs # first let's convert the scores to a list of span idxs @@ -513,7 +512,7 @@ class SpanPredictor(TrainablePipe): for eg in examples: # For update we use the gold coref_head_clusters # in the reference. - span_scores, backprop = self.model.begin_update([eg.reference]) + span_scores, backprop = self.model.begin_update([eg.predicted]) loss, d_scores = self.get_loss([eg], span_scores) total_loss += loss # TODO check shape here @@ -622,10 +621,9 @@ class SpanPredictor(TrainablePipe): ref = eg.reference pred = eg.predicted for key, gold_sg in ref.spans.items(): - if key.startswith(self.input_prefix): + if key.startswith(self.output_prefix): cluster_id = key.split('_')[-1] - # FIXME THIS DOESN'T WORK BECAUSE pred.spans are empty? - pred_sg = pred.spans[f"{self.output_prefix}_{cluster_id}"] + pred_sg = pred.spans[key] for gold_mention, pred_mention in zip(gold_sg, pred_sg): starts.append(gold_mention.start) ends.append(gold_mention.end)