From 5192ac16170c51e4a3ed0c8d930a4988853c4dd2 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Sun, 3 Jul 2022 14:48:42 +0900 Subject: [PATCH] Clean tests. --- spacy/tests/pipeline/test_coref.py | 31 +++++++++--------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/spacy/tests/pipeline/test_coref.py b/spacy/tests/pipeline/test_coref.py index 584db99b8..73c09b48e 100644 --- a/spacy/tests/pipeline/test_coref.py +++ b/spacy/tests/pipeline/test_coref.py @@ -35,7 +35,8 @@ TRAIN_DATA = [ # fmt: on def spans2ints(doc): - """Convert doc.spans to nested list of ints for comparison. + """Convert doc.spans to nested list of ints for comparison. + The ints are token indices. This is useful for checking consistency of predictions. """ @@ -98,21 +99,14 @@ def test_coref_serialization(nlp): assert nlp.pipe_names == ["coref"] text = "She gave me her pen." doc = nlp(text) - spans_result = doc.spans with make_tempdir() as tmp_dir: nlp.to_disk(tmp_dir) nlp2 = spacy.load(tmp_dir) assert nlp2.pipe_names == ["coref"] doc2 = nlp2(text) - spans_result2 = doc2.spans - print(1, [(k, len(v)) for k, v in spans_result.items()]) - print(2, [(k, len(v)) for k, v in spans_result2.items()]) - # Note: spans do not compare equal because docs are different and docs - # use object identity for equality - for k, v in spans_result.items(): - assert str(spans_result[k]) == str(spans_result2[k]) - # assert spans_result == spans_result2 + + assert spans2ints(doc) == spans2ints(doc2) @pytest.mark.skipif(not has_torch, reason="Torch not available") @@ -148,18 +142,11 @@ def test_overfitting_IO(nlp): "I noticed many friends around me", "They received it. They received the SMS.", ] - docs = list(nlp.pipe(texts)) - batch_deps_1 = [doc.spans for doc in docs] - print(batch_deps_1) - docs = list(nlp.pipe(texts)) - batch_deps_2 = [doc.spans for doc in docs] - print(batch_deps_2) - docs = [nlp(text) for text in texts] - no_batch_deps = [doc.spans for doc in docs] - print(no_batch_deps) - print("FINISH") - # assert_equal(batch_deps_1, batch_deps_2) - # assert_equal(batch_deps_1, no_batch_deps) + docs1 = list(nlp.pipe(texts)) + docs2 = list(nlp.pipe(texts)) + docs3 = [nlp(text) for text in texts] + assert spans2ints(docs1[0]) == spans2ints(docs2[0]) + assert spans2ints(docs1[0]) == spans2ints(docs3[0]) @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_tokenization_mismatch(nlp):