Fix alignment issues

This resolves gold-span alignment issues caused by tokenization mismatches: gold spans are now re-aligned to the predicted document via character offsets (`char_span`), spans that cannot be aligned are skipped with a warning, and the gradients are expanded back to the original score shape with zeros for the skipped entries.
This commit is contained in:
Paul O'Leary McCann 2022-07-03 20:11:03 +09:00
parent cf33b48fe0
commit b09bbc7f5e

View File

@ -231,16 +231,29 @@ class SpanPredictor(TrainablePipe):
for eg in examples: for eg in examples:
starts = [] starts = []
ends = [] ends = []
keeps = []
sidx = 0
for key, sg in eg.reference.spans.items(): for key, sg in eg.reference.spans.items():
if key.startswith(self.output_prefix): if key.startswith(self.output_prefix):
for mention in sg: for ii, mention in enumerate(sg):
starts.append(mention.start) sidx += 1
ends.append(mention.end) # convert to span in pred
sch, ech = (mention.start_char, mention.end_char)
span = eg.predicted.char_span(sch, ech)
# TODO add to errors.py
if span is None:
warnings.warn("Could not align gold span in span predictor, skipping")
continue
starts.append(span.start)
ends.append(span.end)
keeps.append(sidx - 1)
starts = self.model.ops.xp.asarray(starts) starts = self.model.ops.xp.asarray(starts)
ends = self.model.ops.xp.asarray(ends) ends = self.model.ops.xp.asarray(ends)
start_scores = span_scores[:, :, 0] start_scores = span_scores[:, :, 0][keeps]
end_scores = span_scores[:, :, 1] end_scores = span_scores[:, :, 1][keeps]
n_classes = start_scores.shape[1] n_classes = start_scores.shape[1]
start_probs = ops.softmax(start_scores, axis=1) start_probs = ops.softmax(start_scores, axis=1)
end_probs = ops.softmax(end_scores, axis=1) end_probs = ops.softmax(end_scores, axis=1)
@ -248,7 +261,14 @@ class SpanPredictor(TrainablePipe):
end_targets = to_categorical(ends, n_classes) end_targets = to_categorical(ends, n_classes)
start_grads = start_probs - start_targets start_grads = start_probs - start_targets
end_grads = end_probs - end_targets end_grads = end_probs - end_targets
grads = ops.xp.stack((start_grads, end_grads), axis=2) # now return to original shape, with 0s
final_start_grads = ops.alloc2f(*span_scores[:, :, 0].shape)
final_start_grads[keeps] = start_grads
final_end_grads = ops.alloc2f(*final_start_grads.shape)
final_end_grads[keeps] = end_grads
# XXX Note this only works with fake batching
grads = ops.xp.stack((final_start_grads, final_end_grads), axis=2)
loss = float((grads**2).sum()) loss = float((grads**2).sum())
return loss, grads return loss, grads
@ -267,6 +287,7 @@ class SpanPredictor(TrainablePipe):
if not ex.predicted.spans: if not ex.predicted.spans:
# set placeholder for shape inference # set placeholder for shape inference
doc = ex.predicted doc = ex.predicted
# TODO should be able to check if there are some valid docs in the batch
assert len(doc) > 2, "Coreference requires at least two tokens" assert len(doc) > 2, "Coreference requires at least two tokens"
doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]] doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
X.append(ex.predicted) X.append(ex.predicted)