diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index 25a353405..1c0e56521 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -503,29 +503,20 @@ class SpanPredictor(TrainablePipe): losses = {} losses.setdefault(self.name, 0.0) validate_examples(examples, "SpanPredictor.update") - if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): + if not any(len(eg.reference) if eg.reference else 0 for eg in examples): # Handle cases where there are no tokens in any docs. return losses set_dropout_rate(self.model, drop) total_loss = 0 for eg in examples: - # replicates the EntityLinker's behaviour and - # copies annotations over https://bit.ly/3iweDcW - # https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313 - doc = eg.predicted - old_spans = eg.predicted.spans - for key, sg in eg.reference.spans.items(): - if key.startswith(self.input_prefix): - doc.spans[key] = eg.get_aligned_spans_y2x(sg) - span_scores, backprop = self.model.begin_update([doc]) + # For update we use the gold coref_head_clusters + # in the reference. + span_scores, backprop = self.model.begin_update([eg.reference]) loss, d_scores = self.get_loss([eg], span_scores) total_loss += loss # TODO check shape here backprop(d_scores) - # Restore example - for key, sg in old_spans.items(): - eg.predicted.spans[key] = sg if sgd is not None: self.finish_update(sgd) @@ -570,17 +561,14 @@ class SpanPredictor(TrainablePipe): # span_scores is a Floats3d. What are the axes? mention x token x start/end for eg in examples: - - # get gold data - gold = doc2clusters(eg.predicted, self.input_prefix) - # flatten the gold data starts = [] ends = [] - for cluster in gold: - for mention in cluster: - starts.append(mention[0]) - # XXX I think this was missing here - ends.append(mention[1] - 1) + for key, sg in eg.reference.spans.items(): + if key.startswith(self.output_prefix): + for mention in sg: + starts.append(mention.start) + ends.append(mention.end) + starts = self.model.ops.xp.asarray(starts) ends = self.model.ops.xp.asarray(ends) start_scores = span_scores[:, :, 0]