Add test for crossing spans

This should maybe go elsewhere?
This commit is contained in:
Paul O'Leary McCann 2021-06-28 18:21:00 +09:00
parent 4f377d8de8
commit b02df61eb9

View File

@ -6,6 +6,7 @@ from spacy.training import Example
from spacy.lang.en import English
from spacy.tests.util import make_tempdir
from spacy.pipeline.coref import DEFAULT_CLUSTERS_PREFIX
from spacy.ml.models.coref_util import select_non_crossing_spans
# fmt: off
TRAIN_DATA = [
@ -142,3 +143,14 @@ def test_overfitting_IO(nlp):
print(no_batch_deps)
# assert_equal(batch_deps_1, batch_deps_2)
# assert_equal(batch_deps_1, no_batch_deps)
def test_crossing_spans():
starts = [ 6, 10, 0, 1, 0, 1, 0, 1, 2, 2, 2]
ends = [12, 12, 2, 3, 3, 4, 4, 4, 3, 4, 5]
idxs = list(range(len(starts)))
limit = 5
gold = sorted([0 , 1, 2, 4, 6])
guess = select_non_crossing_spans(idxs, starts, ends, limit)
guess = sorted(guess)
assert gold == guess