nicer restore

This commit is contained in:
Kádár Ákos 2022-03-28 18:16:41 +02:00
parent 06d680b269
commit 7ff99a3acc

View File

@ -364,9 +364,7 @@ class CoreferenceResolver(TrainablePipe):
for ex in examples: for ex in examples:
p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix) p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix) g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
cluster_info = get_cluster_info(p_clusters, g_clusters) cluster_info = get_cluster_info(p_clusters, g_clusters)
evaluator.update(cluster_info) evaluator.update(cluster_info)
score = { score = {
@ -511,12 +509,12 @@ class SpanPredictor(TrainablePipe):
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
total_loss = 0 total_loss = 0
old_spans = [eg.predicted.spans for eg in examples]
for eg in examples: for eg in examples:
# replicates the EntityLinker's behaviour and # replicates the EntityLinker's behaviour and
# copies annotations over https://bit.ly/3iweDcW # copies annotations over https://bit.ly/3iweDcW
# https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313 # https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313
doc = eg.predicted doc = eg.predicted
old_spans = eg.predicted.spans
for key, sg in eg.reference.spans.items(): for key, sg in eg.reference.spans.items():
if key.startswith(self.input_prefix): if key.startswith(self.input_prefix):
doc.spans[key] = eg.get_aligned_spans_y2x(sg) doc.spans[key] = eg.get_aligned_spans_y2x(sg)
@ -525,9 +523,8 @@ class SpanPredictor(TrainablePipe):
total_loss += loss total_loss += loss
# TODO check shape here # TODO check shape here
backprop(d_scores) backprop(d_scores)
# Restore examples # Restore example
for spans, eg in zip(old_spans, examples): for key, sg in old_spans.items():
for key, sg in spans.items():
eg.predicted.spans[key] = sg eg.predicted.spans[key] = sg
if sgd is not None: if sgd is not None: