From ef141ad3995410d64cd27a615b3f17ee21d59dd2 Mon Sep 17 00:00:00 2001
From: Kádár Ákos
Date: Mon, 4 Apr 2022 18:10:09 +0200
Subject: [PATCH] span accuracy score

---
 spacy/pipeline/coref.py | 52 +++++++++++++++++++++++------------------
 1 file changed, 29 insertions(+), 23 deletions(-)

diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py
index 1c0e56521..c1db23d68 100644
--- a/spacy/pipeline/coref.py
+++ b/spacy/pipeline/coref.py
@@ -457,6 +457,7 @@ class SpanPredictor(TrainablePipe):
         for doc in docs:
             # TODO check shape here
             span_scores = self.model.predict([doc])
+            print(span_scores)
             if span_scores.size:
                 # the information about clustering has to come from the input docs
                 # first let's convert the scores to a list of span idxs
@@ -608,30 +609,35 @@ class SpanPredictor(TrainablePipe):
         self.model.initialize(X=X, Y=Y)
 
     def score(self, examples, **kwargs):
-        """Score a batch of examples."""
-        # TODO This is basically the same as the main coref component - factor out?
-
+        """
+        Evaluate on reconstructing the correct spans around
+        gold heads.
+        """
         scores = []
-        for metric in (b_cubed, muc, ceafe):
-            evaluator = Evaluator(metric)
+        for eg in examples:
+            starts = []
+            ends = []
+            pred_starts = []
+            pred_ends = []
+            ref = eg.reference
+            pred = eg.predicted
+            for key, gold_sg in ref.spans.items():
+                if key.startswith(self.input_prefix):
+                    cluster_id = key.split('_')[-1]
+                    # FIXME THIS DOESN'T WORK BECAUSE pred.spans are empty?
+                    pred_sg = pred.spans[f"{self.output_prefix}_{cluster_id}"]
+                    for gold_mention, pred_mention in zip(gold_sg, pred_sg):
+                        starts.append(gold_mention.start)
+                        ends.append(gold_mention.end)
+                        pred_starts.append(pred_mention.start)
+                        pred_ends.append(pred_mention.end)
 
-        for ex in examples:
-            # XXX this is the only different part
-            p_clusters = doc2clusters(ex.predicted, self.output_prefix)
-            g_clusters = doc2clusters(ex.reference, self.output_prefix)
-            cluster_info = get_cluster_info(p_clusters, g_clusters)
+            starts = self.model.ops.xp.asarray(starts)
+            ends = self.model.ops.xp.asarray(ends)
+            pred_starts = self.model.ops.xp.asarray(pred_starts)
+            pred_ends = self.model.ops.xp.asarray(pred_ends)
+            correct = ((starts == pred_starts) * (ends == pred_ends)).sum()
+            scores.append(correct)
 
-            evaluator.update(cluster_info)
-
-            score = {
-                "coref_span_f": evaluator.get_f1(),
-                "coref_span_p": evaluator.get_precision(),
-                "coref_span_r": evaluator.get_recall(),
-            }
-            scores.append(score)
-
-        out = {}
-        for field in ("f", "p", "r"):
-            fname = f"coref_span_{field}"
-            out[fname] = mean([ss[fname] for ss in scores])
+        out = {"span_accuracy": mean(scores)}
         return out
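
Note (not part of the patch): a minimal standalone sketch of the exact-match check the new score method performs, using plain NumPy and made-up span offsets rather than spaCy Example objects. A predicted span only counts as correct when both its start and end match the gold boundaries; the patch sums these matches per example and averages the sums, while the snippet below additionally shows a normalized accuracy for clarity.

    import numpy as np

    # Hypothetical gold and predicted (start, end) token offsets for one doc.
    gold_spans = [(0, 2), (5, 9), (12, 14)]
    pred_spans = [(0, 2), (5, 8), (12, 14)]

    starts = np.asarray([s for s, _ in gold_spans])
    ends = np.asarray([e for _, e in gold_spans])
    pred_starts = np.asarray([s for s, _ in pred_spans])
    pred_ends = np.asarray([e for _, e in pred_spans])

    # A span is correct only if both boundaries match (elementwise AND via *).
    correct = ((starts == pred_starts) * (ends == pred_ends)).sum()
    accuracy = correct / len(gold_spans)
    print(correct, accuracy)  # 2 0.6666...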