Run black

This commit is contained in:
Paul O'Leary McCann 2022-07-03 14:49:02 +09:00
parent 5192ac1617
commit 1dacecbbfb

View File

@ -34,17 +34,19 @@ TRAIN_DATA = [
]
# fmt: on
def spans2ints(doc):
    """Convert doc.spans to nested list of ints for comparison.

    The ints are token indices. This is useful for checking consistency
    of predictions.

    Returns a list (one entry per span group in doc.spans) of lists of
    (start, end) token-index tuples.
    """
    out = []
    # NOTE(review): the diff rendering had duplicated before/after lines here;
    # this is the black-formatted (post-commit) version. The group key itself
    # is not needed, only the clusters, so iterate values directly.
    for cluster in doc.spans.values():
        out.append([(ss.start, ss.end) for ss in cluster])
    return out
@pytest.fixture
def nlp():
    """Fixture providing a blank English pipeline for the tests below."""
    # Indentation restored: the diff extraction flattened this to column 0,
    # which is invalid Python.
    return English()
@ -70,6 +72,7 @@ def test_not_initialized(nlp):
with pytest.raises(ValueError, match="E109"):
nlp(text)
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_initialized(nlp):
nlp.add_pipe("coref")
@ -148,6 +151,7 @@ def test_overfitting_IO(nlp):
assert spans2ints(docs1[0]) == spans2ints(docs2[0])
assert spans2ints(docs1[0]) == spans2ints(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_tokenization_mismatch(nlp):
train_examples = []
@ -158,7 +162,7 @@ def test_tokenization_mismatch(nlp):
for key, cluster in ref.spans.items():
char_spans[key] = []
for span in cluster:
char_spans[key].append( (span[0].idx, span[-1].idx + len(span[-1])) )
char_spans[key].append((span[0].idx, span[-1].idx + len(span[-1])))
with ref.retokenize() as retokenizer:
# merge "many friends"
retokenizer.merge(ref[5:7])
@ -203,6 +207,7 @@ def test_tokenization_mismatch(nlp):
assert spans2ints(docs1[0]) == spans2ints(docs2[0])
assert spans2ints(docs1[0]) == spans2ints(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_crossing_spans():
starts = [6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2]
@ -215,6 +220,7 @@ def test_crossing_spans():
guess = sorted(guess)
assert gold == guess
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_sentence_map(snlp):
doc = snlp("I like text. This is text.")