From 7ff99a3acc38cf7202fc269f32774d3e1f613d43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A1d=C3=A1r=20=C3=81kos?= Date: Mon, 28 Mar 2022 18:16:41 +0200 Subject: [PATCH] nicer restore --- spacy/pipeline/coref.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index 340dde470..f0862c844 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -364,9 +364,7 @@ class CoreferenceResolver(TrainablePipe): for ex in examples: p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix) g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix) - cluster_info = get_cluster_info(p_clusters, g_clusters) - evaluator.update(cluster_info) score = { @@ -511,12 +509,12 @@ class SpanPredictor(TrainablePipe): set_dropout_rate(self.model, drop) total_loss = 0 - old_spans = [eg.predicted.spans for eg in examples] for eg in examples: # replicates the EntityLinker's behaviour and # copies annotations over https://bit.ly/3iweDcW # https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313 doc = eg.predicted + old_spans = eg.predicted.spans for key, sg in eg.reference.spans.items(): if key.startswith(self.input_prefix): doc.spans[key] = eg.get_aligned_spans_y2x(sg) @@ -525,9 +523,8 @@ class SpanPredictor(TrainablePipe): total_loss += loss # TODO check shape here backprop(d_scores) - # Restore examples - for spans, eg in zip(old_spans, examples): - for key, sg in spans.items(): + # Restore example + for key, sg in old_spans.items(): eg.predicted.spans[key] = sg if sgd is not None: