Add tests to give up on whitespace differences

Docs in Examples are allowed to have arbitrarily different whitespace.
Handling that properly would be nice but isn't required; for now,
detect the mismatch and raise an error.
This commit is contained in:
Paul O'Leary McCann 2022-07-04 19:37:42 +09:00
parent c7f333d593
commit 178feae00a
4 changed files with 49 additions and 2 deletions

View File

@@ -218,6 +218,13 @@ class CoreferenceResolver(TrainablePipe):
total_loss = 0 total_loss = 0
for eg in examples: for eg in examples:
if eg.x.text != eg.y.text:
# TODO assign error number
raise ValueError(
"""Text, including whitespace, must match between reference and
predicted docs in coref training.
"""
)
# TODO check this causes no issues (in practice it runs) # TODO check this causes no issues (in practice it runs)
preds, backprop = self.model.begin_update([eg.predicted]) preds, backprop = self.model.begin_update([eg.predicted])
score_matrix, mention_idx = preds score_matrix, mention_idx = preds
@@ -277,7 +284,7 @@ class CoreferenceResolver(TrainablePipe):
if span is None: if span is None:
# TODO log more details # TODO log more details
raise IndexError(Errors.E1043) raise IndexError(Errors.E1043)
cc.append( (span.start, span.end) ) cc.append((span.start, span.end))
clusters.append(cc) clusters.append(cc)
span_idxs = create_head_span_idxs(ops, len(example.predicted)) span_idxs = create_head_span_idxs(ops, len(example.predicted))

View File

@@ -178,6 +178,13 @@ class SpanPredictor(TrainablePipe):
total_loss = 0 total_loss = 0
for eg in examples: for eg in examples:
if eg.x.text != eg.y.text:
# TODO assign error number
raise ValueError(
"""Text, including whitespace, must match between reference and
predicted docs in span predictor training.
"""
)
span_scores, backprop = self.model.begin_update([eg.predicted]) span_scores, backprop = self.model.begin_update([eg.predicted])
# FIXME, this only happens once in the first 1000 docs of OntoNotes # FIXME, this only happens once in the first 1000 docs of OntoNotes
# and I'm not sure yet why. # and I'm not sure yet why.

View File

@@ -218,3 +218,20 @@ def test_sentence_map(snlp):
doc = snlp("I like text. This is text.") doc = snlp("I like text. This is text.")
sm = get_sentence_ids(doc) sm = get_sentence_ids(doc)
assert sm == [0, 0, 0, 0, 1, 1, 1, 1] assert sm == [0, 0, 0, 0, 1, 1, 1, 1]
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_whitespace_mismatch(nlp):
    """Coref training must fail fast when reference and predicted docs
    differ only in whitespace.

    Builds examples whose predicted doc has a leading space (so
    ``eg.x.text != eg.y.text``) and asserts that ``nlp.update`` raises
    the ValueError added in ``CoreferenceResolver.get_loss``.
    """
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        # Prepend a space so the texts differ while annotations still refer
        # to the original tokenization.
        eg.predicted = nlp.make_doc(" " + text)
        train_examples.append(eg)
    nlp.add_pipe("coref", config=CONFIG)
    optimizer = nlp.initialize()
    # NOTE(review): the original also ran `doc = nlp(TRAIN_DATA[0][0])`,
    # but `doc` was never used — dead code copied from another test, removed.
    with pytest.raises(ValueError, match="whitespace"):
        nlp.update(train_examples, sgd=optimizer)

View File

@@ -106,7 +106,7 @@ def test_overfitting_IO(nlp):
pred = eg.predicted pred = eg.predicted
for key, spans in ref.spans.items(): for key, spans in ref.spans.items():
if key.startswith("coref_head_clusters"): if key.startswith("coref_head_clusters"):
pred.spans[key] = [pred[span.start:span.end] for span in spans] pred.spans[key] = [pred[span.start : span.end] for span in spans]
train_examples.append(eg) train_examples.append(eg)
nlp.add_pipe("span_predictor", config=CONFIG) nlp.add_pipe("span_predictor", config=CONFIG)
@@ -209,3 +209,19 @@ def test_tokenization_mismatch(nlp):
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs2[0]) assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs2[0])
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs3[0]) assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_whitespace_mismatch(nlp):
    """Span-predictor training must fail fast when reference and predicted
    docs differ only in whitespace.

    Builds examples whose predicted doc has a leading space (so
    ``eg.x.text != eg.y.text``) and asserts that ``nlp.update`` raises
    the ValueError added in ``SpanPredictor.get_loss``.
    """
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        # Prepend a space so the texts differ while annotations still refer
        # to the original tokenization.
        eg.predicted = nlp.make_doc(" " + text)
        train_examples.append(eg)
    nlp.add_pipe("span_predictor", config=CONFIG)
    optimizer = nlp.initialize()
    # NOTE(review): the original also ran `doc = nlp(TRAIN_DATA[0][0])`,
    # but `doc` was never used — dead code copied from another test, removed.
    with pytest.raises(ValueError, match="whitespace"):
        nlp.update(train_examples, sgd=optimizer)