Update tests

This commit is contained in:
Paul O'Leary McCann 2022-07-03 20:10:53 +09:00
parent fd574a89c4
commit cf33b48fe0
2 changed files with 12 additions and 12 deletions

View File

@ -207,11 +207,13 @@ def create_gold_scores(
def spans2ints(doc):
    """Convert doc.spans to a nested list of int pairs for comparison.

    Each span group becomes a list of (start_char, end_char) tuples, i.e. the
    ints are character indices. Groups are emitted in sorted key order, so the
    result does not depend on the order in which groups were added to the doc.
    This is useful for checking consistency of predictions.

    doc: object whose `spans` attribute maps group keys to iterables of spans
        exposing `start_char` and `end_char`.
    RETURNS: a list with one list of (start_char, end_char) tuples per group.
    """
    # Sort the keys first so two docs holding the same groups compare equal
    # regardless of insertion order.
    return [
        [(span.start_char, span.end_char) for span in doc.spans[key]]
        for key in sorted(doc.spans)
    ]

View File

@ -114,13 +114,15 @@ def test_overfitting_IO(nlp):
test_text = TRAIN_DATA[0][0]
doc = nlp(test_text)
for i in range(1500):
for i in range(15):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
doc = nlp(test_text)
# test the trained model, using the pred since it has heads
doc = nlp(train_examples[0].predicted)
# XXX This actually tests that it can overfit
assert spans2ints(doc) == spans2ints(train_examples[0].reference)
# Also test the results are still the same after IO
with make_tempdir() as tmp_dir:
@ -134,6 +136,7 @@ def test_overfitting_IO(nlp):
"I noticed many friends around me",
"They received it. They received the SMS.",
]
# XXX Note these have no predictions because they have no input spans
docs1 = list(nlp.pipe(texts))
docs2 = list(nlp.pipe(texts))
docs3 = [nlp(text) for text in texts]
@ -175,7 +178,7 @@ def test_tokenization_mismatch(nlp):
test_text = TRAIN_DATA[0][0]
doc = nlp(test_text)
for i in range(100):
for i in range(15):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
doc = nlp(test_text)
@ -183,12 +186,8 @@ def test_tokenization_mismatch(nlp):
# test the trained model; need to use doc with head spans on it already
test_doc = train_examples[0].predicted
doc = nlp(test_doc)
# XXX DEBUG
print("SPANS", len(doc.spans))
for key, val in doc.spans.items():
print(key, val)
print("...")
# XXX This actually tests that it can overfit
assert spans2ints(doc) == spans2ints(train_examples[0].reference)
# Also test the results are still the same after IO
with make_tempdir() as tmp_dir:
@ -209,5 +208,4 @@ def test_tokenization_mismatch(nlp):
docs3 = [nlp(text) for text in texts]
assert spans2ints(docs1[0]) == spans2ints(docs2[0])
assert spans2ints(docs1[0]) == spans2ints(docs3[0])
assert False