From 7a239f2ec7c71a494f2380686fdbcfdd421e7fa6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=A1d=C3=A1r=20=C3=81kos?=
Date: Fri, 8 Apr 2022 14:57:19 +0200
Subject: [PATCH] report start- and end-accuracies separately

---
 spacy/pipeline/coref.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py
index 1b062ed9a..02c93f712 100644
--- a/spacy/pipeline/coref.py
+++ b/spacy/pipeline/coref.py
@@ -510,13 +510,11 @@ class SpanPredictor(TrainablePipe):
         total_loss = 0
 
         for eg in examples:
-            # For update we use the gold coref_head_clusters
-            # in the reference.
             span_scores, backprop = self.model.begin_update([eg.predicted])
             loss, d_scores = self.get_loss([eg], span_scores)
             total_loss += loss
             # TODO check shape here
-            backprop(d_scores)
+            backprop((d_scores))
 
         if sgd is not None:
             self.finish_update(sgd)
@@ -612,7 +610,8 @@ class SpanPredictor(TrainablePipe):
         Evaluate on reconstructing the correct spans around
         gold heads.
         """
-        scores = []
+        start_scores = []
+        end_scores = []
         for eg in examples:
             starts = []
             ends = []
@@ -622,7 +621,6 @@ class SpanPredictor(TrainablePipe):
             pred = eg.predicted
             for key, gold_sg in ref.spans.items():
                 if key.startswith(self.output_prefix):
-                    cluster_id = key.split('_')[-1]
                     pred_sg = pred.spans[key]
                     for gold_mention, pred_mention in zip(gold_sg, pred_sg):
                         starts.append(gold_mention.start)
@@ -634,8 +632,12 @@ class SpanPredictor(TrainablePipe):
             ends = self.model.ops.xp.asarray(ends)
             pred_starts = self.model.ops.xp.asarray(pred_starts)
             pred_ends = self.model.ops.xp.asarray(pred_ends)
-            correct = ((starts == pred_starts) * (ends == pred_ends)).sum()
-            scores.append(correct)
-
-        out = {"span_accuracy": mean(scores)}
+            start_accuracy = (starts == pred_starts).mean()
+            end_accuracy = (ends == pred_ends).mean()
+            start_scores.append(float(start_accuracy))
+            end_scores.append(float(end_accuracy))
+        out = {
+            "span_start_accuracy": mean(start_scores),
+            "span_end_accuracy": mean(end_scores)
+        }
         return out
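
For reference, a minimal standalone sketch of the scoring change this patch introduces: a per-example start-accuracy and end-accuracy, each averaged across examples. It is not the spaCy code itself; numpy stands in for the pipeline's model.ops.xp, and score_spans and the (start, end) pairs are illustrative names, not part of the spaCy API.

# Standalone sketch of reporting start- and end-accuracies separately.
# Assumes gold and predicted spans are aligned pairwise, as in the zip()
# over gold_sg and pred_sg in the patched score() method.
from statistics import mean

import numpy as np


def score_spans(examples):
    """Return separate span-start and span-end accuracies over examples."""
    start_scores = []
    end_scores = []
    for gold_spans, pred_spans in examples:
        # Each span is an illustrative (start, end) token-index pair.
        starts = np.asarray([s for s, _ in gold_spans])
        ends = np.asarray([e for _, e in gold_spans])
        pred_starts = np.asarray([s for s, _ in pred_spans])
        pred_ends = np.asarray([e for _, e in pred_spans])
        # Fraction of spans whose start (resp. end) matches the gold span.
        start_scores.append(float((starts == pred_starts).mean()))
        end_scores.append(float((ends == pred_ends).mean()))
    return {
        "span_start_accuracy": mean(start_scores),
        "span_end_accuracy": mean(end_scores),
    }


if __name__ == "__main__":
    gold = [(0, 2), (5, 7), (9, 11)]
    pred = [(0, 2), (5, 8), (9, 11)]
    print(score_spans([(gold, pred)]))
    # {'span_start_accuracy': 1.0, 'span_end_accuracy': 0.666...}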