Run black

This commit is contained in:
Paul O'Leary McCann 2022-07-03 14:49:02 +09:00
parent 5192ac1617
commit 1dacecbbfb

View File

@@ -34,6 +34,7 @@ TRAIN_DATA = [
]
# fmt: on
def spans2ints(doc):
    """Convert doc.spans to nested list of ints for comparison.
    The ints are token indices.
@@ -42,9 +43,10 @@ def spans2ints(doc):
    """
    out = []
    for key, cluster in doc.spans.items():
-        out.append( [(ss.start, ss.end) for ss in cluster] )
+        out.append([(ss.start, ss.end) for ss in cluster])
    return out
@pytest.fixture
def nlp():
    return English()
@@ -70,6 +72,7 @@ def test_not_initialized(nlp):
    with pytest.raises(ValueError, match="E109"):
        nlp(text)
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_initialized(nlp):
    nlp.add_pipe("coref")
@@ -148,6 +151,7 @@ def test_overfitting_IO(nlp):
    assert spans2ints(docs1[0]) == spans2ints(docs2[0])
    assert spans2ints(docs1[0]) == spans2ints(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_tokenization_mismatch(nlp):
    train_examples = []
@@ -158,7 +162,7 @@ def test_tokenization_mismatch(nlp):
        for key, cluster in ref.spans.items():
            char_spans[key] = []
            for span in cluster:
-                char_spans[key].append( (span[0].idx, span[-1].idx + len(span[-1])) )
+                char_spans[key].append((span[0].idx, span[-1].idx + len(span[-1])))
        with ref.retokenize() as retokenizer:
            # merge "many friends"
            retokenizer.merge(ref[5:7])
@@ -203,6 +207,7 @@ def test_tokenization_mismatch(nlp):
    assert spans2ints(docs1[0]) == spans2ints(docs2[0])
    assert spans2ints(docs1[0]) == spans2ints(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_crossing_spans():
    starts = [6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2]
@@ -215,6 +220,7 @@ def test_crossing_spans():
    guess = sorted(guess)
    assert gold == guess
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_sentence_map(snlp):
    doc = snlp("I like text. This is text.")