From 1dacecbbfbf1648ec0d6a44d0c53d722de0c2c40 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Sun, 3 Jul 2022 14:49:02 +0900 Subject: [PATCH] Run black --- spacy/tests/pipeline/test_coref.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/spacy/tests/pipeline/test_coref.py b/spacy/tests/pipeline/test_coref.py index 73c09b48e..4b8ca1653 100644 --- a/spacy/tests/pipeline/test_coref.py +++ b/spacy/tests/pipeline/test_coref.py @@ -34,17 +34,19 @@ TRAIN_DATA = [ ] # fmt: on + def spans2ints(doc): - """Convert doc.spans to nested list of ints for comparison. + """Convert doc.spans to nested list of ints for comparison. The ints are token indices. This is useful for checking consistency of predictions. """ out = [] for key, cluster in doc.spans.items(): - out.append( [(ss.start, ss.end) for ss in cluster] ) + out.append([(ss.start, ss.end) for ss in cluster]) return out + @pytest.fixture def nlp(): return English() @@ -70,6 +72,7 @@ def test_not_initialized(nlp): with pytest.raises(ValueError, match="E109"): nlp(text) + @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_initialized(nlp): nlp.add_pipe("coref") @@ -148,6 +151,7 @@ def test_overfitting_IO(nlp): assert spans2ints(docs1[0]) == spans2ints(docs2[0]) assert spans2ints(docs1[0]) == spans2ints(docs3[0]) + @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_tokenization_mismatch(nlp): train_examples = [] @@ -158,7 +162,7 @@ def test_tokenization_mismatch(nlp): for key, cluster in ref.spans.items(): char_spans[key] = [] for span in cluster: - char_spans[key].append( (span[0].idx, span[-1].idx + len(span[-1])) ) + char_spans[key].append((span[0].idx, span[-1].idx + len(span[-1]))) with ref.retokenize() as retokenizer: # merge "many friends" retokenizer.merge(ref[5:7]) @@ -203,6 +207,7 @@ def test_tokenization_mismatch(nlp): assert spans2ints(docs1[0]) == spans2ints(docs2[0]) assert spans2ints(docs1[0]) == spans2ints(docs3[0]) + @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_crossing_spans(): starts = [6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2] @@ -215,6 +220,7 @@ def test_crossing_spans(): guess = sorted(guess) assert gold == guess + @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_sentence_map(snlp): doc = snlp("I like text. This is text.")