diff --git a/spacy/tests/pipeline/test_coref.py b/spacy/tests/pipeline/test_coref.py index d252cfa83..8a20e43a4 100644 --- a/spacy/tests/pipeline/test_coref.py +++ b/spacy/tests/pipeline/test_coref.py @@ -6,7 +6,11 @@ from spacy.training import Example from spacy.lang.en import English from spacy.tests.util import make_tempdir from spacy.pipeline.coref import DEFAULT_CLUSTERS_PREFIX -from spacy.ml.models.coref_util import select_non_crossing_spans, get_candidate_mentions +from spacy.ml.models.coref_util import ( + select_non_crossing_spans, + get_candidate_mentions, + get_sentence_map, +) # fmt: off TRAIN_DATA = [ @@ -35,6 +39,13 @@ def nlp(): return English() +@pytest.fixture +def snlp(): + en = English() + en.add_pipe("sentencizer") + return en + + def test_add_pipe(nlp): nlp.add_pipe("coref") assert nlp.pipe_names == ["coref"] @@ -158,17 +169,21 @@ def test_crossing_spans(): assert gold == guess -def test_mention_generator(): - # don't use the fixture because we want the sentencizer - nlp = English() +def test_mention_generator(snlp): + nlp = snlp doc = nlp("I like text.") # four tokens max_width = 20 mentions = get_candidate_mentions(doc, max_width) assert len(mentions[0]) == 10 # check multiple sentences - nlp.add_pipe("sentencizer") doc = nlp("I like text. This is text.") # eight tokens, two sents max_width = 20 mentions = get_candidate_mentions(doc, max_width) assert len(mentions[0]) == 20 + + +def test_sentence_map(snlp): + doc = snlp("I like text. This is text.") + sm = get_sentence_map(doc) + assert sm == [0, 0, 0, 0, 1, 1, 1, 1]