Clean tests.

This commit is contained in:
Paul O'Leary McCann 2022-07-03 14:48:42 +09:00
parent 79720886fa
commit 5192ac1617

View File

@@ -35,7 +35,8 @@ TRAIN_DATA = [
# fmt: on
def spans2ints(doc):
"""Convert doc.spans to nested list of ints for comparison.
"""Convert doc.spans to nested list of ints for comparison.
The ints are token indices.
This is useful for checking consistency of predictions.
"""
@@ -98,21 +99,14 @@ def test_coref_serialization(nlp):
assert nlp.pipe_names == ["coref"]
text = "She gave me her pen."
doc = nlp(text)
spans_result = doc.spans
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
nlp2 = spacy.load(tmp_dir)
assert nlp2.pipe_names == ["coref"]
doc2 = nlp2(text)
spans_result2 = doc2.spans
print(1, [(k, len(v)) for k, v in spans_result.items()])
print(2, [(k, len(v)) for k, v in spans_result2.items()])
# Note: spans do not compare equal because docs are different and docs
# use object identity for equality
for k, v in spans_result.items():
assert str(spans_result[k]) == str(spans_result2[k])
# assert spans_result == spans_result2
assert spans2ints(doc) == spans2ints(doc2)
@pytest.mark.skipif(not has_torch, reason="Torch not available")
@@ -148,18 +142,11 @@ def test_overfitting_IO(nlp):
"I noticed many friends around me",
"They received it. They received the SMS.",
]
docs = list(nlp.pipe(texts))
batch_deps_1 = [doc.spans for doc in docs]
print(batch_deps_1)
docs = list(nlp.pipe(texts))
batch_deps_2 = [doc.spans for doc in docs]
print(batch_deps_2)
docs = [nlp(text) for text in texts]
no_batch_deps = [doc.spans for doc in docs]
print(no_batch_deps)
print("FINISH")
# assert_equal(batch_deps_1, batch_deps_2)
# assert_equal(batch_deps_1, no_batch_deps)
docs1 = list(nlp.pipe(texts))
docs2 = list(nlp.pipe(texts))
docs3 = [nlp(text) for text in texts]
assert spans2ints(docs1[0]) == spans2ints(docs2[0])
assert spans2ints(docs1[0]) == spans2ints(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_tokenization_mismatch(nlp):