Clean tests.

This commit is contained in:
Paul O'Leary McCann 2022-07-03 14:48:42 +09:00
parent 79720886fa
commit 5192ac1617

View File

@@ -36,6 +36,7 @@ TRAIN_DATA = [
def spans2ints(doc): def spans2ints(doc):
"""Convert doc.spans to nested list of ints for comparison. """Convert doc.spans to nested list of ints for comparison.
The ints are token indices.
This is useful for checking consistency of predictions. This is useful for checking consistency of predictions.
""" """
@@ -98,21 +99,14 @@ def test_coref_serialization(nlp):
assert nlp.pipe_names == ["coref"] assert nlp.pipe_names == ["coref"]
text = "She gave me her pen." text = "She gave me her pen."
doc = nlp(text) doc = nlp(text)
spans_result = doc.spans
with make_tempdir() as tmp_dir: with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir) nlp.to_disk(tmp_dir)
nlp2 = spacy.load(tmp_dir) nlp2 = spacy.load(tmp_dir)
assert nlp2.pipe_names == ["coref"] assert nlp2.pipe_names == ["coref"]
doc2 = nlp2(text) doc2 = nlp2(text)
spans_result2 = doc2.spans
print(1, [(k, len(v)) for k, v in spans_result.items()]) assert spans2ints(doc) == spans2ints(doc2)
print(2, [(k, len(v)) for k, v in spans_result2.items()])
# Note: spans do not compare equal because docs are different and docs
# use object identity for equality
for k, v in spans_result.items():
assert str(spans_result[k]) == str(spans_result2[k])
# assert spans_result == spans_result2
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")
@@ -148,18 +142,11 @@ def test_overfitting_IO(nlp):
"I noticed many friends around me", "I noticed many friends around me",
"They received it. They received the SMS.", "They received it. They received the SMS.",
] ]
docs = list(nlp.pipe(texts)) docs1 = list(nlp.pipe(texts))
batch_deps_1 = [doc.spans for doc in docs] docs2 = list(nlp.pipe(texts))
print(batch_deps_1) docs3 = [nlp(text) for text in texts]
docs = list(nlp.pipe(texts)) assert spans2ints(docs1[0]) == spans2ints(docs2[0])
batch_deps_2 = [doc.spans for doc in docs] assert spans2ints(docs1[0]) == spans2ints(docs3[0])
print(batch_deps_2)
docs = [nlp(text) for text in texts]
no_batch_deps = [doc.spans for doc in docs]
print(no_batch_deps)
print("FINISH")
# assert_equal(batch_deps_1, batch_deps_2)
# assert_equal(batch_deps_1, no_batch_deps)
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_tokenization_mismatch(nlp): def test_tokenization_mismatch(nlp):