Mirror of https://github.com/explosion/spaCy.git
Clean tests.
parent 79720886fa
commit 5192ac1617
@@ -36,6 +36,7 @@ TRAIN_DATA = [
def spans2ints(doc):
    """Convert doc.spans to nested list of ints for comparison.
    The ints are token indices.

    This is useful for checking consistency of predictions.
    """
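For reference, a minimal sketch of what a helper with this contract could look like (an assumption for illustration, not the body from this commit; it only relies on spaCy spans exposing .start and .end token offsets):

    # Hypothetical sketch: build a Doc-independent representation of doc.spans
    # using token indices only, so predictions from different Doc objects can
    # be compared with a plain ==.
    def spans2ints(doc):
        out = []
        for key, cluster in doc.spans.items():
            out.append([[span.start, span.end] for span in cluster])
        return out

Because the result is nested lists of ints, it compares by value rather than by the owning Doc, which is what the tests below rely on.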
@@ -98,21 +99,14 @@ def test_coref_serialization(nlp):
    assert nlp.pipe_names == ["coref"]
    text = "She gave me her pen."
    doc = nlp(text)
    spans_result = doc.spans

    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = spacy.load(tmp_dir)
        assert nlp2.pipe_names == ["coref"]
        doc2 = nlp2(text)
        spans_result2 = doc2.spans
        print(1, [(k, len(v)) for k, v in spans_result.items()])
        print(2, [(k, len(v)) for k, v in spans_result2.items()])
        # Note: spans do not compare equal because docs are different and docs
        # use object identity for equality
        for k, v in spans_result.items():
            assert str(spans_result[k]) == str(spans_result2[k])
        # assert spans_result == spans_result2

        assert spans2ints(doc) == spans2ints(doc2)


@pytest.mark.skipif(not has_torch, reason="Torch not available")
@@ -148,18 +142,11 @@ def test_overfitting_IO(nlp):
        "I noticed many friends around me",
        "They received it. They received the SMS.",
    ]
    docs = list(nlp.pipe(texts))
    batch_deps_1 = [doc.spans for doc in docs]
    print(batch_deps_1)
    docs = list(nlp.pipe(texts))
    batch_deps_2 = [doc.spans for doc in docs]
    print(batch_deps_2)
    docs = [nlp(text) for text in texts]
    no_batch_deps = [doc.spans for doc in docs]
    print(no_batch_deps)
    print("FINISH")
    # assert_equal(batch_deps_1, batch_deps_2)
    # assert_equal(batch_deps_1, no_batch_deps)
    docs1 = list(nlp.pipe(texts))
    docs2 = list(nlp.pipe(texts))
    docs3 = [nlp(text) for text in texts]
    assert spans2ints(docs1[0]) == spans2ints(docs2[0])
    assert spans2ints(docs1[0]) == spans2ints(docs3[0])


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_tokenization_mismatch(nlp):