mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
nicer restore
This commit is contained in:
parent
06d680b269
commit
7ff99a3acc
|
@ -364,9 +364,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
for ex in examples:
|
||||
p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
|
||||
g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
|
||||
|
||||
cluster_info = get_cluster_info(p_clusters, g_clusters)
|
||||
|
||||
evaluator.update(cluster_info)
|
||||
|
||||
score = {
|
||||
|
@ -511,12 +509,12 @@ class SpanPredictor(TrainablePipe):
|
|||
set_dropout_rate(self.model, drop)
|
||||
|
||||
total_loss = 0
|
||||
old_spans = [eg.predicted.spans for eg in examples]
|
||||
for eg in examples:
|
||||
# replicates the EntityLinker's behaviour and
|
||||
# copies annotations over https://bit.ly/3iweDcW
|
||||
# https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313
|
||||
doc = eg.predicted
|
||||
old_spans = eg.predicted.spans
|
||||
for key, sg in eg.reference.spans.items():
|
||||
if key.startswith(self.input_prefix):
|
||||
doc.spans[key] = eg.get_aligned_spans_y2x(sg)
|
||||
|
@ -525,9 +523,8 @@ class SpanPredictor(TrainablePipe):
|
|||
total_loss += loss
|
||||
# TODO check shape here
|
||||
backprop(d_scores)
|
||||
# Restore examples
|
||||
for spans, eg in zip(old_spans, examples):
|
||||
for key, sg in spans.items():
|
||||
# Restore example
|
||||
for key, sg in old_spans.items():
|
||||
eg.predicted.spans[key] = sg
|
||||
|
||||
if sgd is not None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user