mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 04:02:20 +03:00
Fix alignment issues
I believe this resolves issues with tokenization mismatches.
This commit is contained in:
parent
cf33b48fe0
commit
b09bbc7f5e
|
@ -231,16 +231,29 @@ class SpanPredictor(TrainablePipe):
|
|||
for eg in examples:
|
||||
starts = []
|
||||
ends = []
|
||||
keeps = []
|
||||
sidx = 0
|
||||
for key, sg in eg.reference.spans.items():
|
||||
if key.startswith(self.output_prefix):
|
||||
for mention in sg:
|
||||
starts.append(mention.start)
|
||||
ends.append(mention.end)
|
||||
for ii, mention in enumerate(sg):
|
||||
sidx += 1
|
||||
# convert to span in pred
|
||||
sch, ech = (mention.start_char, mention.end_char)
|
||||
span = eg.predicted.char_span(sch, ech)
|
||||
# TODO add to errors.py
|
||||
if span is None:
|
||||
warnings.warn("Could not align gold span in span predictor, skipping")
|
||||
continue
|
||||
starts.append(span.start)
|
||||
ends.append(span.end)
|
||||
keeps.append(sidx - 1)
|
||||
|
||||
starts = self.model.ops.xp.asarray(starts)
|
||||
ends = self.model.ops.xp.asarray(ends)
|
||||
start_scores = span_scores[:, :, 0]
|
||||
end_scores = span_scores[:, :, 1]
|
||||
start_scores = span_scores[:, :, 0][keeps]
|
||||
end_scores = span_scores[:, :, 1][keeps]
|
||||
|
||||
|
||||
n_classes = start_scores.shape[1]
|
||||
start_probs = ops.softmax(start_scores, axis=1)
|
||||
end_probs = ops.softmax(end_scores, axis=1)
|
||||
|
@ -248,7 +261,14 @@ class SpanPredictor(TrainablePipe):
|
|||
end_targets = to_categorical(ends, n_classes)
|
||||
start_grads = start_probs - start_targets
|
||||
end_grads = end_probs - end_targets
|
||||
grads = ops.xp.stack((start_grads, end_grads), axis=2)
|
||||
# now return to original shape, with 0s
|
||||
final_start_grads = ops.alloc2f(*span_scores[:, :, 0].shape)
|
||||
final_start_grads[keeps] = start_grads
|
||||
final_end_grads = ops.alloc2f(*final_start_grads.shape)
|
||||
final_end_grads[keeps] = end_grads
|
||||
# XXX Note this only works with fake batching
|
||||
grads = ops.xp.stack((final_start_grads, final_end_grads), axis=2)
|
||||
|
||||
loss = float((grads**2).sum())
|
||||
return loss, grads
|
||||
|
||||
|
@ -267,6 +287,7 @@ class SpanPredictor(TrainablePipe):
|
|||
if not ex.predicted.spans:
|
||||
# set placeholder for shape inference
|
||||
doc = ex.predicted
|
||||
# TODO should be able to check if there are some valid docs in the batch
|
||||
assert len(doc) > 2, "Coreference requires at least two tokens"
|
||||
doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
|
||||
X.append(ex.predicted)
|
||||
|
|
Loading…
Reference in New Issue
Block a user