Update tests

Paul O'Leary McCann 2022-07-03 20:10:53 +09:00
parent fd574a89c4
commit cf33b48fe0
2 changed files with 12 additions and 12 deletions

View File

@@ -207,11 +207,13 @@ def create_gold_scores(
 def spans2ints(doc):
     """Convert doc.spans to nested list of ints for comparison.
-    The ints are token indices.
+    The ints are character indices, and the spans groups are sorted by key first.
     This is useful for checking consistency of predictions.
     """
     out = []
-    for key, cluster in doc.spans.items():
-        out.append([(ss.start, ss.end) for ss in cluster])
+    keys = sorted([key for key in doc.spans])
+    for key in keys:
+        cluster = doc.spans[key]
+        out.append([(ss.start_char, ss.end_char) for ss in cluster])
     return out
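For reference, a minimal sketch of what the updated helper returns, assuming spaCy 3.x and a blank English pipeline; the span group key "clusters_1" below is made up for illustration, not a key the coref component actually produces:

import spacy

# Standalone copy of the updated helper above, for a quick check.
def spans2ints(doc):
    out = []
    keys = sorted([key for key in doc.spans])
    for key in keys:
        cluster = doc.spans[key]
        out.append([(ss.start_char, ss.end_char) for ss in cluster])
    return out

nlp = spacy.blank("en")
doc = nlp("Sarah said she was tired.")
# Hypothetical span group linking "Sarah" (token 0) and "she" (token 2).
doc.spans["clusters_1"] = [doc[0:1], doc[2:3]]

print(spans2ints(doc))  # [[(0, 5), (11, 14)]] -- character offsets, not token indices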

View File

@@ -114,13 +114,15 @@ def test_overfitting_IO(nlp):
     test_text = TRAIN_DATA[0][0]
     doc = nlp(test_text)
-    for i in range(1500):
+    for i in range(15):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
     # test the trained model, using the pred since it has heads
     doc = nlp(train_examples[0].predicted)
+    # XXX This actually tests that it can overfit
+    assert spans2ints(doc) == spans2ints(train_examples[0].reference)

     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
@@ -134,6 +136,7 @@ def test_overfitting_IO(nlp):
         "I noticed many friends around me",
         "They received it. They received the SMS.",
     ]
+    # XXX Note these have no predictions because they have no input spans
     docs1 = list(nlp.pipe(texts))
     docs2 = list(nlp.pipe(texts))
     docs3 = [nlp(text) for text in texts]
@@ -175,7 +178,7 @@ def test_tokenization_mismatch(nlp):
     test_text = TRAIN_DATA[0][0]
     doc = nlp(test_text)
-    for i in range(100):
+    for i in range(15):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
@@ -183,12 +186,8 @@ def test_tokenization_mismatch(nlp):
     # test the trained model; need to use doc with head spans on it already
     test_doc = train_examples[0].predicted
     doc = nlp(test_doc)
-    # XXX DEBUG
-    print("SPANS", len(doc.spans))
-    for key, val in doc.spans.items():
-        print(key, val)
-    print("...")
+    # XXX This actually tests that it can overfit
+    assert spans2ints(doc) == spans2ints(train_examples[0].reference)

     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
@@ -209,5 +208,4 @@ def test_tokenization_mismatch(nlp):
     docs3 = [nlp(text) for text in texts]
     assert spans2ints(docs1[0]) == spans2ints(docs2[0])
     assert spans2ints(docs1[0]) == spans2ints(docs3[0])
-    assert False