From 6aedd98d02b55672469556f4d61f2ad6254f3759 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=A1d=C3=A1r=20=C3=81kos?=
Date: Mon, 11 Apr 2022 16:10:14 +0200
Subject: [PATCH] fixing scorer

---
 spacy/pipeline/coref.py | 37 +++++++++++++++++--------------------
 1 file changed, 17 insertions(+), 20 deletions(-)

diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py
index 02c93f712..fc04d1a3e 100644
--- a/spacy/pipeline/coref.py
+++ b/spacy/pipeline/coref.py
@@ -511,10 +511,13 @@ class SpanPredictor(TrainablePipe):
         total_loss = 0
         for eg in examples:
             span_scores, backprop = self.model.begin_update([eg.predicted])
-            loss, d_scores = self.get_loss([eg], span_scores)
-            total_loss += loss
-            # TODO check shape here
-            backprop((d_scores))
+            # FIXME, this only happens once in the first 1000 docs of OntoNotes
+            # and I'm not sure yet why.
+            if span_scores.size:
+                loss, d_scores = self.get_loss([eg], span_scores)
+                total_loss += loss
+                # TODO check shape here
+                backprop((d_scores))
 
         if sgd is not None:
             self.finish_update(sgd)
@@ -557,7 +560,6 @@ class SpanPredictor(TrainablePipe):
         assert len(examples) == 1, "Only fake batching is supported."
         # starts and ends are gold starts and ends (Ints1d)
         # span_scores is a Floats3d. What are the axes? mention x token x start/end
-
         for eg in examples:
             starts = []
             ends = []
@@ -610,8 +612,8 @@ class SpanPredictor(TrainablePipe):
 
         Evaluate on reconstructing the correct spans around gold heads.
         """
-        start_scores = []
-        end_scores = []
+        scores = []
+        xp = self.model.ops.xp
         for eg in examples:
             starts = []
             ends = []
@@ -628,16 +630,11 @@ class SpanPredictor(TrainablePipe):
                         pred_starts.append(pred_mention.start)
                         pred_ends.append(pred_mention.end)
 
-            starts = self.model.ops.xp.asarray(starts)
-            ends = self.model.ops.xp.asarray(ends)
-            pred_starts = self.model.ops.xp.asarray(pred_starts)
-            pred_ends = self.model.ops.xp.asarray(pred_ends)
-            start_accuracy = (starts == pred_starts).mean()
-            end_accuracy = (ends == pred_ends).mean()
-            start_scores.append(float(start_accuracy))
-            end_scores.append(float(end_accuracy))
-        out = {
-            "span_start_accuracy": mean(start_scores),
-            "span_end_accuracy": mean(end_scores)
-        }
-        return out
+            starts = xp.asarray(starts)
+            ends = xp.asarray(ends)
+            pred_starts = xp.asarray(pred_starts)
+            pred_ends = xp.asarray(pred_ends)
+            correct = (starts == pred_starts) * (ends == pred_ends)
+            accuracy = correct.mean()
+            scores.append(float(accuracy))
+        return {"span_accuracy": mean(scores)}
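
Below is a minimal standalone sketch of the combined span accuracy computed in the new scorer, using numpy in place of `self.model.ops.xp` and made-up boundary values for three mentions:

    import numpy as np

    # Hypothetical gold and predicted boundaries for three mentions.
    starts = np.asarray([0, 4, 9])
    ends = np.asarray([2, 6, 11])
    pred_starts = np.asarray([0, 4, 9])
    pred_ends = np.asarray([2, 7, 11])

    # The elementwise product of two boolean arrays acts as a logical AND:
    # a mention counts as correct only if both its start and its end match.
    correct = (starts == pred_starts) * (ends == pred_ends)
    print(float(correct.mean()))  # 0.666... (one mention has the wrong end)

Scoring the joint condition is stricter than averaging separate start and end accuracies, which could overstate performance when the correct starts and the correct ends come from different mentions.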