diff --git a/spacy/tests/pipeline/test_coref.py b/spacy/tests/pipeline/test_coref.py index 7b96c5540..e09d4827d 100644 --- a/spacy/tests/pipeline/test_coref.py +++ b/spacy/tests/pipeline/test_coref.py @@ -6,6 +6,7 @@ from spacy.training import Example from spacy.lang.en import English from spacy.tests.util import make_tempdir from spacy.pipeline.coref import DEFAULT_CLUSTERS_PREFIX +from spacy.ml.models.coref_util import select_non_crossing_spans # fmt: off TRAIN_DATA = [ @@ -142,3 +143,14 @@ def test_overfitting_IO(nlp): print(no_batch_deps) # assert_equal(batch_deps_1, batch_deps_2) # assert_equal(batch_deps_1, no_batch_deps) + +def test_crossing_spans(): + starts = [ 6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2] + ends = [12, 12, 2, 3, 3, 4, 4, 4, 3, 4, 5] + idxs = list(range(len(starts))) + limit = 5 + + gold = sorted([0 , 1, 2, 4, 6]) + guess = select_non_crossing_spans(idxs, starts, ends, limit) + guess = sorted(guess) + assert gold == guess