Run black

Paul O'Leary McCann 2022-07-03 14:49:02 +09:00
parent 5192ac1617
commit 1dacecbbfb


@@ -34,6 +34,7 @@ TRAIN_DATA = [
 ]
 # fmt: on
 
+
 def spans2ints(doc):
     """Convert doc.spans to nested list of ints for comparison.
     The ints are token indices.
@@ -42,9 +43,10 @@ def spans2ints(doc):
     """
     out = []
     for key, cluster in doc.spans.items():
-        out.append( [(ss.start, ss.end) for ss in cluster] )
+        out.append([(ss.start, ss.end) for ss in cluster])
     return out
 
+
 @pytest.fixture
 def nlp():
     return English()
@@ -70,6 +72,7 @@ def test_not_initialized(nlp):
     with pytest.raises(ValueError, match="E109"):
         nlp(text)
 
+
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_initialized(nlp):
     nlp.add_pipe("coref")
@@ -148,6 +151,7 @@ def test_overfitting_IO(nlp):
     assert spans2ints(docs1[0]) == spans2ints(docs2[0])
     assert spans2ints(docs1[0]) == spans2ints(docs3[0])
 
+
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_tokenization_mismatch(nlp):
     train_examples = []
@@ -158,7 +162,7 @@ def test_tokenization_mismatch(nlp):
         for key, cluster in ref.spans.items():
             char_spans[key] = []
             for span in cluster:
-                char_spans[key].append( (span[0].idx, span[-1].idx + len(span[-1])) )
+                char_spans[key].append((span[0].idx, span[-1].idx + len(span[-1])))
         with ref.retokenize() as retokenizer:
             # merge "many friends"
             retokenizer.merge(ref[5:7])
@@ -203,6 +207,7 @@ def test_tokenization_mismatch(nlp):
     assert spans2ints(docs1[0]) == spans2ints(docs2[0])
     assert spans2ints(docs1[0]) == spans2ints(docs3[0])
 
+
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_crossing_spans():
     starts = [6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2]
@@ -215,6 +220,7 @@ def test_crossing_spans():
     guess = sorted(guess)
     assert gold == guess
 
+
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_sentence_map(snlp):
     doc = snlp("I like text. This is text.")
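
Note (not part of the commit above, just for orientation): spans2ints, shown in full in the second hunk, collects the (start, end) token offsets of every group in doc.spans so clusters can be compared as plain Python data. A minimal sketch, assuming a hypothetical doc with a single cluster covering token spans (0, 1) and (5, 7); the span key is made up for illustration:

    # hypothetical cluster, for illustration only
    doc.spans["coref_clusters_1"] = [doc[0:1], doc[5:7]]
    spans2ints(doc)  # -> [[(0, 1), (5, 7)]]

The char_spans line reformatted in test_tokenization_mismatch does similar bookkeeping in character offsets: span[0].idx is the start character of the span's first token, and span[-1].idx + len(span[-1]) is the character just past its last token. A sketch with a made-up sentence (not the one in TRAIN_DATA):

    # hypothetical text: "I like many friends."
    span = doc[2:4]                              # "many friends"
    (span[0].idx, span[-1].idx + len(span[-1]))  # -> (7, 19)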