Add tests that give up on whitespace differences

Docs in Examples are allowed to have arbitrarily different whitespace.
Handling that properly would be nice but isn't required; for now,
detect the mismatch and raise an error.
This commit is contained in:
Paul O'Leary McCann 2022-07-04 19:37:42 +09:00
parent c7f333d593
commit 178feae00a
4 changed files with 49 additions and 2 deletions

View File

@ -218,6 +218,13 @@ class CoreferenceResolver(TrainablePipe):
total_loss = 0
for eg in examples:
if eg.x.text != eg.y.text:
# TODO assign error number
raise ValueError(
"""Text, including whitespace, must match between reference and
predicted docs in coref training.
"""
)
# TODO check this causes no issues (in practice it runs)
preds, backprop = self.model.begin_update([eg.predicted])
score_matrix, mention_idx = preds
@ -277,7 +284,7 @@ class CoreferenceResolver(TrainablePipe):
if span is None:
# TODO log more details
raise IndexError(Errors.E1043)
cc.append( (span.start, span.end) )
cc.append((span.start, span.end))
clusters.append(cc)
span_idxs = create_head_span_idxs(ops, len(example.predicted))

View File

@ -178,6 +178,13 @@ class SpanPredictor(TrainablePipe):
total_loss = 0
for eg in examples:
if eg.x.text != eg.y.text:
# TODO assign error number
raise ValueError(
"""Text, including whitespace, must match between reference and
predicted docs in span predictor training.
"""
)
span_scores, backprop = self.model.begin_update([eg.predicted])
# FIXME, this only happens once in the first 1000 docs of OntoNotes
# and I'm not sure yet why.

View File

@ -218,3 +218,20 @@ def test_sentence_map(snlp):
doc = snlp("I like text. This is text.")
sm = get_sentence_ids(doc)
assert sm == [0, 0, 0, 0, 1, 1, 1, 1]
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_whitespace_mismatch(nlp):
    """Training must fail loudly on whitespace differences.

    The coref component cannot align a reference doc with a predicted doc
    whose text differs (even only in whitespace), so ``nlp.update`` should
    raise a ValueError mentioning "whitespace".
    """
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        # Prepend a space so eg.x.text != eg.y.text in the update step.
        eg.predicted = nlp.make_doc(" " + text)
        train_examples.append(eg)

    nlp.add_pipe("coref", config=CONFIG)
    optimizer = nlp.initialize()

    with pytest.raises(ValueError, match="whitespace"):
        nlp.update(train_examples, sgd=optimizer)

View File

@ -106,7 +106,7 @@ def test_overfitting_IO(nlp):
pred = eg.predicted
for key, spans in ref.spans.items():
if key.startswith("coref_head_clusters"):
pred.spans[key] = [pred[span.start:span.end] for span in spans]
pred.spans[key] = [pred[span.start : span.end] for span in spans]
train_examples.append(eg)
nlp.add_pipe("span_predictor", config=CONFIG)
@ -209,3 +209,19 @@ def test_tokenization_mismatch(nlp):
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs2[0])
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_whitespace_mismatch(nlp):
    """Training must fail loudly on whitespace differences.

    The span predictor cannot align a reference doc with a predicted doc
    whose text differs (even only in whitespace), so ``nlp.update`` should
    raise a ValueError mentioning "whitespace".
    """
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        # Prepend a space so eg.x.text != eg.y.text in the update step.
        eg.predicted = nlp.make_doc(" " + text)
        train_examples.append(eg)

    nlp.add_pipe("span_predictor", config=CONFIG)
    optimizer = nlp.initialize()

    with pytest.raises(ValueError, match="whitespace"):
        nlp.update(train_examples, sgd=optimizer)