Run black

Paul O'Leary McCann 2022-07-03 14:49:02 +09:00
parent 5192ac1617
commit 1dacecbbfb
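
Black is an automated Python formatter, so this commit only changes layout, not behavior. As a hedged illustration of the kind of rewrite involved, black's Python API can be applied to one of the fragments changed below (black.format_str is a real entry point; the snippet itself is just a sketch):

    import black

    # one of the lines reformatted below, as a standalone statement
    src = "out.append( [(ss.start, ss.end) for ss in cluster] )\n"
    # format_str applies the same rules as the `black` CLI with default settings
    print(black.format_str(src, mode=black.Mode()), end="")
    # prints: out.append([(ss.start, ss.end) for ss in cluster])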


@@ -34,17 +34,19 @@ TRAIN_DATA = [
 ]
 # fmt: on
 
+
 def spans2ints(doc):
     """Convert doc.spans to nested list of ints for comparison.
     The ints are token indices.
 
     This is useful for checking consistency of predictions.
     """
     out = []
     for key, cluster in doc.spans.items():
-        out.append( [(ss.start, ss.end) for ss in cluster] )
+        out.append([(ss.start, ss.end) for ss in cluster])
     return out
 
+
 @pytest.fixture
 def nlp():
     return English()
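
A hedged usage sketch for the spans2ints helper above; the sentence and the span-group key are invented for illustration:

    from spacy.lang.en import English

    nlp = English()
    doc = nlp("I like my cat")
    # one cluster with the mentions "I" and "my"
    doc.spans["coref_clusters_1"] = [doc[0:1], doc[2:3]]
    # token indices per mention, grouped by cluster
    assert spans2ints(doc) == [[(0, 1), (2, 3)]]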
@@ -70,6 +72,7 @@ def test_not_initialized(nlp):
     with pytest.raises(ValueError, match="E109"):
         nlp(text)
 
+
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_initialized(nlp):
     nlp.add_pipe("coref")
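
For context, the uninitialized case exercised above boils down to the following sketch; it assumes the coref component from this branch is registered and torch is available:

    import pytest
    from spacy.lang.en import English

    nlp = English()
    nlp.add_pipe("coref")  # added but never initialized or trained
    with pytest.raises(ValueError, match="E109"):
        nlp("I like text.")  # running the pipeline raises spaCy's E109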
@@ -148,6 +151,7 @@ def test_overfitting_IO(nlp):
     assert spans2ints(docs1[0]) == spans2ints(docs2[0])
     assert spans2ints(docs1[0]) == spans2ints(docs3[0])
 
+
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_tokenization_mismatch(nlp):
     train_examples = []
@@ -158,7 +162,7 @@ def test_tokenization_mismatch(nlp):
         for key, cluster in ref.spans.items():
             char_spans[key] = []
             for span in cluster:
-                char_spans[key].append( (span[0].idx, span[-1].idx + len(span[-1])) )
+                char_spans[key].append((span[0].idx, span[-1].idx + len(span[-1])))
         with ref.retokenize() as retokenizer:
             # merge "many friends"
             retokenizer.merge(ref[5:7])
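
The lines above store mention boundaries as character offsets, which survive retokenization; the test can then rebuild token spans with Doc.char_span. A minimal sketch of that round trip, with an invented sentence and span-group key:

    from spacy.lang.en import English

    nlp = English()
    doc = nlp("I have many friends")
    doc.spans["coref_clusters_1"] = [doc[2:4]]  # "many friends"

    # character offsets are stable even if token boundaries change
    char_spans = [
        (span[0].idx, span[-1].idx + len(span[-1]))
        for span in doc.spans["coref_clusters_1"]
    ]

    with doc.retokenize() as retokenizer:
        retokenizer.merge(doc[2:4])  # "many friends" becomes one token

    # rebuild token spans from the saved character offsets
    doc.spans["coref_clusters_1"] = [doc.char_span(s, e) for s, e in char_spans]
    assert [(s.start, s.end) for s in doc.spans["coref_clusters_1"]] == [(2, 3)]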
@@ -203,6 +207,7 @@ def test_tokenization_mismatch(nlp):
     assert spans2ints(docs1[0]) == spans2ints(docs2[0])
     assert spans2ints(docs1[0]) == spans2ints(docs3[0])
 
+
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_crossing_spans():
     starts = [6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2]
@@ -215,6 +220,7 @@ def test_crossing_spans():
     guess = sorted(guess)
     assert gold == guess
 
+
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_sentence_map(snlp):
     doc = snlp("I like text. This is text.")