Fix alignment issues

I believe this resolves issues with tokenization mismatches.
This commit is contained in:
Paul O'Leary McCann 2022-07-03 20:11:03 +09:00
parent cf33b48fe0
commit b09bbc7f5e

View File

@ -231,16 +231,29 @@ class SpanPredictor(TrainablePipe):
for eg in examples:
starts = []
ends = []
keeps = []
sidx = 0
for key, sg in eg.reference.spans.items():
if key.startswith(self.output_prefix):
for mention in sg:
starts.append(mention.start)
ends.append(mention.end)
for ii, mention in enumerate(sg):
sidx += 1
# convert to span in pred
sch, ech = (mention.start_char, mention.end_char)
span = eg.predicted.char_span(sch, ech)
# TODO add to errors.py
if span is None:
warnings.warn("Could not align gold span in span predictor, skipping")
continue
starts.append(span.start)
ends.append(span.end)
keeps.append(sidx - 1)
starts = self.model.ops.xp.asarray(starts)
ends = self.model.ops.xp.asarray(ends)
start_scores = span_scores[:, :, 0]
end_scores = span_scores[:, :, 1]
start_scores = span_scores[:, :, 0][keeps]
end_scores = span_scores[:, :, 1][keeps]
n_classes = start_scores.shape[1]
start_probs = ops.softmax(start_scores, axis=1)
end_probs = ops.softmax(end_scores, axis=1)
@ -248,7 +261,14 @@ class SpanPredictor(TrainablePipe):
end_targets = to_categorical(ends, n_classes)
start_grads = start_probs - start_targets
end_grads = end_probs - end_targets
grads = ops.xp.stack((start_grads, end_grads), axis=2)
# now return to original shape, with 0s
final_start_grads = ops.alloc2f(*span_scores[:, :, 0].shape)
final_start_grads[keeps] = start_grads
final_end_grads = ops.alloc2f(*final_start_grads.shape)
final_end_grads[keeps] = end_grads
# XXX Note this only works with fake batching
grads = ops.xp.stack((final_start_grads, final_end_grads), axis=2)
loss = float((grads**2).sum())
return loss, grads
@ -267,6 +287,7 @@ class SpanPredictor(TrainablePipe):
if not ex.predicted.spans:
# set placeholder for shape inference
doc = ex.predicted
# TODO should be able to check if there are some valid docs in the batch
assert len(doc) > 2, "Coreference requires at least two tokens"
doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
X.append(ex.predicted)