From b09bbc7f5eb8f04e2441f43f063492d3e2fc1d22 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Sun, 3 Jul 2022 20:11:03 +0900 Subject: [PATCH] Fix alignment issues I believe this resolves issues with tokenization mismatches. --- spacy/pipeline/span_predictor.py | 33 ++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/spacy/pipeline/span_predictor.py b/spacy/pipeline/span_predictor.py index d7e96a4b2..c9343a97e 100644 --- a/spacy/pipeline/span_predictor.py +++ b/spacy/pipeline/span_predictor.py @@ -231,16 +231,29 @@ class SpanPredictor(TrainablePipe): for eg in examples: starts = [] ends = [] + keeps = [] + sidx = 0 for key, sg in eg.reference.spans.items(): if key.startswith(self.output_prefix): - for mention in sg: - starts.append(mention.start) - ends.append(mention.end) + for ii, mention in enumerate(sg): + sidx += 1 + # convert to span in pred + sch, ech = (mention.start_char, mention.end_char) + span = eg.predicted.char_span(sch, ech) + # TODO add to errors.py + if span is None: + warnings.warn("Could not align gold span in span predictor, skipping") + continue + starts.append(span.start) + ends.append(span.end) + keeps.append(sidx - 1) starts = self.model.ops.xp.asarray(starts) ends = self.model.ops.xp.asarray(ends) - start_scores = span_scores[:, :, 0] - end_scores = span_scores[:, :, 1] + start_scores = span_scores[:, :, 0][keeps] + end_scores = span_scores[:, :, 1][keeps] + + n_classes = start_scores.shape[1] start_probs = ops.softmax(start_scores, axis=1) end_probs = ops.softmax(end_scores, axis=1) @@ -248,7 +261,14 @@ class SpanPredictor(TrainablePipe): end_targets = to_categorical(ends, n_classes) start_grads = start_probs - start_targets end_grads = end_probs - end_targets - grads = ops.xp.stack((start_grads, end_grads), axis=2) + # now return to original shape, with 0s + final_start_grads = ops.alloc2f(*span_scores[:, :, 0].shape) + final_start_grads[keeps] = start_grads + final_end_grads = ops.alloc2f(*final_start_grads.shape) + final_end_grads[keeps] = end_grads + # XXX Note this only works with fake batching + grads = ops.xp.stack((final_start_grads, final_end_grads), axis=2) + loss = float((grads**2).sum()) return loss, grads @@ -267,6 +287,7 @@ class SpanPredictor(TrainablePipe): if not ex.predicted.spans: # set placeholder for shape inference doc = ex.predicted + # TODO should be able to check if there are some valid docs in the batch assert len(doc) > 2, "Coreference requires at least two tokens" doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]] X.append(ex.predicted)