mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
prepare for aligned heads-spans training
This commit is contained in:
parent
63a41ba50a
commit
a1d0219903
|
@ -503,29 +503,20 @@ class SpanPredictor(TrainablePipe):
|
||||||
losses = {}
|
losses = {}
|
||||||
losses.setdefault(self.name, 0.0)
|
losses.setdefault(self.name, 0.0)
|
||||||
validate_examples(examples, "SpanPredictor.update")
|
validate_examples(examples, "SpanPredictor.update")
|
||||||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
|
||||||
# Handle cases where there are no tokens in any docs.
|
# Handle cases where there are no tokens in any docs.
|
||||||
return losses
|
return losses
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
|
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
# replicates the EntityLinker's behaviour and
|
# For update we use the gold coref_head_clusters
|
||||||
# copies annotations over https://bit.ly/3iweDcW
|
# in the reference.
|
||||||
# https://github.com/explosion/spaCy/blob/master/spacy/pipeline/entity_linker.py#L313
|
span_scores, backprop = self.model.begin_update([eg.reference])
|
||||||
doc = eg.predicted
|
|
||||||
old_spans = eg.predicted.spans
|
|
||||||
for key, sg in eg.reference.spans.items():
|
|
||||||
if key.startswith(self.input_prefix):
|
|
||||||
doc.spans[key] = eg.get_aligned_spans_y2x(sg)
|
|
||||||
span_scores, backprop = self.model.begin_update([doc])
|
|
||||||
loss, d_scores = self.get_loss([eg], span_scores)
|
loss, d_scores = self.get_loss([eg], span_scores)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
# TODO check shape here
|
# TODO check shape here
|
||||||
backprop(d_scores)
|
backprop(d_scores)
|
||||||
# Restore example
|
|
||||||
for key, sg in old_spans.items():
|
|
||||||
eg.predicted.spans[key] = sg
|
|
||||||
|
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
|
@ -570,17 +561,14 @@ class SpanPredictor(TrainablePipe):
|
||||||
# span_scores is a Floats3d. What are the axes? mention x token x start/end
|
# span_scores is a Floats3d. What are the axes? mention x token x start/end
|
||||||
|
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
|
|
||||||
# get gold data
|
|
||||||
gold = doc2clusters(eg.predicted, self.input_prefix)
|
|
||||||
# flatten the gold data
|
|
||||||
starts = []
|
starts = []
|
||||||
ends = []
|
ends = []
|
||||||
for cluster in gold:
|
for key, sg in eg.reference.spans.items():
|
||||||
for mention in cluster:
|
if key.startswith(self.output_prefix):
|
||||||
starts.append(mention[0])
|
for mention in sg:
|
||||||
# XXX I think this was missing here
|
starts.append(mention.start)
|
||||||
ends.append(mention[1] - 1)
|
ends.append(mention.end)
|
||||||
|
|
||||||
starts = self.model.ops.xp.asarray(starts)
|
starts = self.model.ops.xp.asarray(starts)
|
||||||
ends = self.model.ops.xp.asarray(ends)
|
ends = self.model.ops.xp.asarray(ends)
|
||||||
start_scores = span_scores[:, :, 0]
|
start_scores = span_scores[:, :, 0]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user