mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
prepare for aligned heads-spans training
This commit is contained in:
parent
63a41ba50a
commit
a1d0219903
|
@@ -503,29 +503,20 @@ class SpanPredictor(TrainablePipe):
|
|||
losses = {}
|
||||
losses.setdefault(self.name, 0.0)
|
||||
validate_examples(examples, "SpanPredictor.update")
|
||||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
||||
if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
|
||||
total_loss = 0
|
||||
for eg in examples:
|
||||
# replicates the EntityLinker's behaviour and
|
||||
# copies annotations over https://bit.ly/3iweDcW
|
||||
# https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313
|
||||
doc = eg.predicted
|
||||
old_spans = eg.predicted.spans
|
||||
for key, sg in eg.reference.spans.items():
|
||||
if key.startswith(self.input_prefix):
|
||||
doc.spans[key] = eg.get_aligned_spans_y2x(sg)
|
||||
span_scores, backprop = self.model.begin_update([doc])
|
||||
# For update we use the gold coref_head_clusters
|
||||
# in the reference.
|
||||
span_scores, backprop = self.model.begin_update([eg.reference])
|
||||
loss, d_scores = self.get_loss([eg], span_scores)
|
||||
total_loss += loss
|
||||
# TODO check shape here
|
||||
backprop(d_scores)
|
||||
# Restore example
|
||||
for key, sg in old_spans.items():
|
||||
eg.predicted.spans[key] = sg
|
||||
|
||||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
|
@@ -570,17 +561,14 @@ class SpanPredictor(TrainablePipe):
|
|||
# span_scores is a Floats3d. What are the axes? mention x token x start/end
|
||||
|
||||
for eg in examples:
|
||||
|
||||
# get gold data
|
||||
gold = doc2clusters(eg.predicted, self.input_prefix)
|
||||
# flatten the gold data
|
||||
starts = []
|
||||
ends = []
|
||||
for cluster in gold:
|
||||
for mention in cluster:
|
||||
starts.append(mention[0])
|
||||
# XXX I think this was missing here
|
||||
ends.append(mention[1] - 1)
|
||||
for key, sg in eg.reference.spans.items():
|
||||
if key.startswith(self.output_prefix):
|
||||
for mention in sg:
|
||||
starts.append(mention.start)
|
||||
ends.append(mention.end)
|
||||
|
||||
starts = self.model.ops.xp.asarray(starts)
|
||||
ends = self.model.ops.xp.asarray(ends)
|
||||
start_scores = span_scores[:, :, 0]
|
||||
|
|
Loading…
Reference in New Issue
Block a user