prepare for aligned heads-spans training

This commit is contained in:
Kádár Ákos 2022-04-04 15:26:15 +02:00
parent 63a41ba50a
commit a1d0219903

View File

@ -503,29 +503,20 @@ class SpanPredictor(TrainablePipe):
losses = {} losses = {}
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
validate_examples(examples, "SpanPredictor.update") validate_examples(examples, "SpanPredictor.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return losses return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
total_loss = 0 total_loss = 0
for eg in examples: for eg in examples:
# replicates the EntityLinker's behaviour and # For update we use the gold coref_head_clusters
# copies annotations over https://bit.ly/3iweDcW # in the reference.
# https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313 span_scores, backprop = self.model.begin_update([eg.reference])
doc = eg.predicted
old_spans = eg.predicted.spans
for key, sg in eg.reference.spans.items():
if key.startswith(self.input_prefix):
doc.spans[key] = eg.get_aligned_spans_y2x(sg)
span_scores, backprop = self.model.begin_update([doc])
loss, d_scores = self.get_loss([eg], span_scores) loss, d_scores = self.get_loss([eg], span_scores)
total_loss += loss total_loss += loss
# TODO check shape here # TODO check shape here
backprop(d_scores) backprop(d_scores)
# Restore example
for key, sg in old_spans.items():
eg.predicted.spans[key] = sg
if sgd is not None: if sgd is not None:
self.finish_update(sgd) self.finish_update(sgd)
@ -570,17 +561,14 @@ class SpanPredictor(TrainablePipe):
# span_scores is a Floats3d. What are the axes? mention x token x start/end # span_scores is a Floats3d. What are the axes? mention x token x start/end
for eg in examples: for eg in examples:
# get gold data
gold = doc2clusters(eg.predicted, self.input_prefix)
# flatten the gold data
starts = [] starts = []
ends = [] ends = []
for cluster in gold: for key, sg in eg.reference.spans.items():
for mention in cluster: if key.startswith(self.output_prefix):
starts.append(mention[0]) for mention in sg:
# XXX I think this was missing here starts.append(mention.start)
ends.append(mention[1] - 1) ends.append(mention.end)
starts = self.model.ops.xp.asarray(starts) starts = self.model.ops.xp.asarray(starts)
ends = self.model.ops.xp.asarray(ends) ends = self.model.ops.xp.asarray(ends)
start_scores = span_scores[:, :, 0] start_scores = span_scores[:, :, 0]