Clean tests.

This commit is contained in:
Paul O'Leary McCann 2022-07-03 14:48:42 +09:00
parent 79720886fa
commit 5192ac1617

View File

@ -36,6 +36,7 @@ TRAIN_DATA = [
def spans2ints(doc):
"""Convert doc.spans to nested list of ints for comparison.
The ints are token indices.
This is useful for checking consistency of predictions.
"""
@ -98,21 +99,14 @@ def test_coref_serialization(nlp):
assert nlp.pipe_names == ["coref"]
text = "She gave me her pen."
doc = nlp(text)
spans_result = doc.spans
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
nlp2 = spacy.load(tmp_dir)
assert nlp2.pipe_names == ["coref"]
doc2 = nlp2(text)
spans_result2 = doc2.spans
print(1, [(k, len(v)) for k, v in spans_result.items()])
print(2, [(k, len(v)) for k, v in spans_result2.items()])
# Note: spans do not compare equal because docs are different and docs
# use object identity for equality
for k, v in spans_result.items():
assert str(spans_result[k]) == str(spans_result2[k])
# assert spans_result == spans_result2
assert spans2ints(doc) == spans2ints(doc2)
@pytest.mark.skipif(not has_torch, reason="Torch not available")
@ -148,18 +142,11 @@ def test_overfitting_IO(nlp):
"I noticed many friends around me",
"They received it. They received the SMS.",
]
docs = list(nlp.pipe(texts))
batch_deps_1 = [doc.spans for doc in docs]
print(batch_deps_1)
docs = list(nlp.pipe(texts))
batch_deps_2 = [doc.spans for doc in docs]
print(batch_deps_2)
docs = [nlp(text) for text in texts]
no_batch_deps = [doc.spans for doc in docs]
print(no_batch_deps)
print("FINISH")
# assert_equal(batch_deps_1, batch_deps_2)
# assert_equal(batch_deps_1, no_batch_deps)
docs1 = list(nlp.pipe(texts))
docs2 = list(nlp.pipe(texts))
docs3 = [nlp(text) for text in texts]
assert spans2ints(docs1[0]) == spans2ints(docs2[0])
assert spans2ints(docs1[0]) == spans2ints(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_tokenization_mismatch(nlp):