diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index c1db23d68..1b062ed9a 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -457,7 +457,6 @@ class SpanPredictor(TrainablePipe): for doc in docs: # TODO check shape here span_scores = self.model.predict([doc]) - print(span_scores) if span_scores.size: # the information about clustering has to come from the input docs # first let's convert the scores to a list of span idxs @@ -513,7 +512,7 @@ class SpanPredictor(TrainablePipe): for eg in examples: # For update we use the gold coref_head_clusters # in the reference. - span_scores, backprop = self.model.begin_update([eg.reference]) + span_scores, backprop = self.model.begin_update([eg.predicted]) loss, d_scores = self.get_loss([eg], span_scores) total_loss += loss # TODO check shape here @@ -622,10 +621,9 @@ class SpanPredictor(TrainablePipe): ref = eg.reference pred = eg.predicted for key, gold_sg in ref.spans.items(): - if key.startswith(self.input_prefix): + if key.startswith(self.output_prefix): cluster_id = key.split('_')[-1] - # FIXME THIS DOESN'T WORK BECAUSE pred.spans are empty? - pred_sg = pred.spans[f"{self.output_prefix}_{cluster_id}"] + pred_sg = pred.spans[key] for gold_mention, pred_mention in zip(gold_sg, pred_sg): starts.append(gold_mention.start) ends.append(gold_mention.end)