mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Clean tests.
This commit is contained in:
parent
79720886fa
commit
5192ac1617
|
@ -36,6 +36,7 @@ TRAIN_DATA = [
|
||||||
|
|
||||||
def spans2ints(doc):
|
def spans2ints(doc):
|
||||||
"""Convert doc.spans to nested list of ints for comparison.
|
"""Convert doc.spans to nested list of ints for comparison.
|
||||||
|
The ints are token indices.
|
||||||
|
|
||||||
This is useful for checking consistency of predictions.
|
This is useful for checking consistency of predictions.
|
||||||
"""
|
"""
|
||||||
|
@ -98,21 +99,14 @@ def test_coref_serialization(nlp):
|
||||||
assert nlp.pipe_names == ["coref"]
|
assert nlp.pipe_names == ["coref"]
|
||||||
text = "She gave me her pen."
|
text = "She gave me her pen."
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
spans_result = doc.spans
|
|
||||||
|
|
||||||
with make_tempdir() as tmp_dir:
|
with make_tempdir() as tmp_dir:
|
||||||
nlp.to_disk(tmp_dir)
|
nlp.to_disk(tmp_dir)
|
||||||
nlp2 = spacy.load(tmp_dir)
|
nlp2 = spacy.load(tmp_dir)
|
||||||
assert nlp2.pipe_names == ["coref"]
|
assert nlp2.pipe_names == ["coref"]
|
||||||
doc2 = nlp2(text)
|
doc2 = nlp2(text)
|
||||||
spans_result2 = doc2.spans
|
|
||||||
print(1, [(k, len(v)) for k, v in spans_result.items()])
|
assert spans2ints(doc) == spans2ints(doc2)
|
||||||
print(2, [(k, len(v)) for k, v in spans_result2.items()])
|
|
||||||
# Note: spans do not compare equal because docs are different and docs
|
|
||||||
# use object identity for equality
|
|
||||||
for k, v in spans_result.items():
|
|
||||||
assert str(spans_result[k]) == str(spans_result2[k])
|
|
||||||
# assert spans_result == spans_result2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||||
|
@ -148,18 +142,11 @@ def test_overfitting_IO(nlp):
|
||||||
"I noticed many friends around me",
|
"I noticed many friends around me",
|
||||||
"They received it. They received the SMS.",
|
"They received it. They received the SMS.",
|
||||||
]
|
]
|
||||||
docs = list(nlp.pipe(texts))
|
docs1 = list(nlp.pipe(texts))
|
||||||
batch_deps_1 = [doc.spans for doc in docs]
|
docs2 = list(nlp.pipe(texts))
|
||||||
print(batch_deps_1)
|
docs3 = [nlp(text) for text in texts]
|
||||||
docs = list(nlp.pipe(texts))
|
assert spans2ints(docs1[0]) == spans2ints(docs2[0])
|
||||||
batch_deps_2 = [doc.spans for doc in docs]
|
assert spans2ints(docs1[0]) == spans2ints(docs3[0])
|
||||||
print(batch_deps_2)
|
|
||||||
docs = [nlp(text) for text in texts]
|
|
||||||
no_batch_deps = [doc.spans for doc in docs]
|
|
||||||
print(no_batch_deps)
|
|
||||||
print("FINISH")
|
|
||||||
# assert_equal(batch_deps_1, batch_deps_2)
|
|
||||||
# assert_equal(batch_deps_1, no_batch_deps)
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||||
def test_tokenization_mismatch(nlp):
|
def test_tokenization_mismatch(nlp):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user