Run black

This commit is contained in:
Paul O'Leary McCann 2022-07-03 14:49:02 +09:00
parent 5192ac1617
commit 1dacecbbfb

View File

@ -34,6 +34,7 @@ TRAIN_DATA = [
] ]
# fmt: on # fmt: on
def spans2ints(doc): def spans2ints(doc):
"""Convert doc.spans to nested list of ints for comparison. """Convert doc.spans to nested list of ints for comparison.
The ints are token indices. The ints are token indices.
@ -45,6 +46,7 @@ def spans2ints(doc):
out.append([(ss.start, ss.end) for ss in cluster]) out.append([(ss.start, ss.end) for ss in cluster])
return out return out
@pytest.fixture @pytest.fixture
def nlp(): def nlp():
return English() return English()
@ -70,6 +72,7 @@ def test_not_initialized(nlp):
with pytest.raises(ValueError, match="E109"): with pytest.raises(ValueError, match="E109"):
nlp(text) nlp(text)
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_initialized(nlp): def test_initialized(nlp):
nlp.add_pipe("coref") nlp.add_pipe("coref")
@ -148,6 +151,7 @@ def test_overfitting_IO(nlp):
assert spans2ints(docs1[0]) == spans2ints(docs2[0]) assert spans2ints(docs1[0]) == spans2ints(docs2[0])
assert spans2ints(docs1[0]) == spans2ints(docs3[0]) assert spans2ints(docs1[0]) == spans2ints(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_tokenization_mismatch(nlp): def test_tokenization_mismatch(nlp):
train_examples = [] train_examples = []
@ -203,6 +207,7 @@ def test_tokenization_mismatch(nlp):
assert spans2ints(docs1[0]) == spans2ints(docs2[0]) assert spans2ints(docs1[0]) == spans2ints(docs2[0])
assert spans2ints(docs1[0]) == spans2ints(docs3[0]) assert spans2ints(docs1[0]) == spans2ints(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_crossing_spans(): def test_crossing_spans():
starts = [6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2] starts = [6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2]
@ -215,6 +220,7 @@ def test_crossing_spans():
guess = sorted(guess) guess = sorted(guess)
assert gold == guess assert gold == guess
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_sentence_map(snlp): def test_sentence_map(snlp):
doc = snlp("I like text. This is text.") doc = snlp("I like text. This is text.")