prepare for aligned heads-spans training

This commit is contained in:
Kádár Ákos 2022-04-04 15:26:15 +02:00
parent 63a41ba50a
commit a1d0219903

View File

@ -503,29 +503,20 @@ class SpanPredictor(TrainablePipe):
losses = {}
losses.setdefault(self.name, 0.0)
validate_examples(examples, "SpanPredictor.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
# Handle cases where there are no tokens in any docs.
return losses
set_dropout_rate(self.model, drop)
total_loss = 0
for eg in examples:
# replicates the EntityLinker's behaviour and
# copies annotations over https://bit.ly/3iweDcW
# https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313
doc = eg.predicted
old_spans = eg.predicted.spans
for key, sg in eg.reference.spans.items():
if key.startswith(self.input_prefix):
doc.spans[key] = eg.get_aligned_spans_y2x(sg)
span_scores, backprop = self.model.begin_update([doc])
# For update we use the gold coref_head_clusters
# in the reference.
span_scores, backprop = self.model.begin_update([eg.reference])
loss, d_scores = self.get_loss([eg], span_scores)
total_loss += loss
# TODO check shape here
backprop(d_scores)
# Restore example
for key, sg in old_spans.items():
eg.predicted.spans[key] = sg
if sgd is not None:
self.finish_update(sgd)
@ -570,17 +561,14 @@ class SpanPredictor(TrainablePipe):
# span_scores is a Floats3d; axes appear to be (mention, token, start/end), with index 0 of the last axis holding the start scores.
for eg in examples:
# get gold data
gold = doc2clusters(eg.predicted, self.input_prefix)
# flatten the gold data
starts = []
ends = []
for cluster in gold:
for mention in cluster:
starts.append(mention[0])
# XXX I think this was missing here
ends.append(mention[1] - 1)
for key, sg in eg.reference.spans.items():
if key.startswith(self.output_prefix):
for mention in sg:
starts.append(mention.start)
ends.append(mention.end)
starts = self.model.ops.xp.asarray(starts)
ends = self.model.ops.xp.asarray(ends)
start_scores = span_scores[:, :, 0]