# Patch summary (commentary placed ahead of the "diff --git" header).
#
# Target file: spacy/tests/test_util.py
#
# NOTE(review): this patch is stored on a single physical line; the original
# line breaks appear to have been lost, so it presumably has to be re-wrapped
# before a patch tool can apply it -- confirm against the original commit.
#
# Hunk 1 (@@ -12,13 +12,11 @@): shrinks the parametrize case list of the
# existing minibatch test by two lines. The hunk header counts (13 old lines,
# 11 new lines) only add up if the two "-" markers are removed blank separator
# lines; the tuples that follow them, ([400, 400, 199, 3, 1], [5]) and
# ([1, 2, 999], [3]), would then be unchanged context -- TODO confirm, since
# the collapsed layout makes this ambiguous to the eye.
#
# Hunk 2 (@@ -37,3 +35,25 @@): appends a new parametrized test,
# test_util_minibatch_oversize, after the end of test_util_minibatch. The new
# test:
#   * builds one doc per entry of doc_sizes via get_doc() and wraps each in
#     Example(doc=doc);
#   * calls minibatch_by_words(examples=..., size=1000, tolerance=0.2,
#     discard_oversize=False);
#   * asserts that the list of batch lengths equals expected_batches.
# Its docstring states the intent: oversized documents are returned in their
# own batch. Six (doc_sizes, expected_batches) cases are supplied.
diff --git a/spacy/tests/test_util.py b/spacy/tests/test_util.py index 207805c81..6b6e84a17 100644 --- a/spacy/tests/test_util.py +++ b/spacy/tests/test_util.py @@ -12,13 +12,11 @@ from spacy.util import minibatch_by_words ([400, 400, 199], [3]), ([400, 400, 199, 3], [4]), ([400, 400, 199, 3, 200], [3, 2]), - ([400, 400, 199, 3, 1], [5]), ([400, 400, 199, 3, 1, 1500], [5]), # 1500 will be discarded ([400, 400, 199, 3, 1, 200], [3, 3]), ([400, 400, 199, 3, 1, 999], [3, 3]), ([400, 400, 199, 3, 1, 999, 999], [3, 2, 1, 1]), - ([1, 2, 999], [3]), ([1, 2, 999, 1], [4]), ([1, 200, 999, 1], [2, 2]), @@ -37,3 +35,25 @@ def test_util_minibatch(doc_sizes, expected_batches): for batch in batches: assert sum([len(example.doc) for example in batch]) < max_size + +@pytest.mark.parametrize( + "doc_sizes, expected_batches", + [ + ([400, 4000, 199], [1, 2]), + ([400, 400, 199, 3000, 200], [1, 4]), + ([400, 400, 199, 3, 1, 1500], [1, 5]), + ([400, 400, 199, 3000, 2000, 200, 200], [1, 1, 3, 2]), + ([1, 2, 9999], [1, 2]), + ([2000, 1, 2000, 1, 1, 1, 2000], [1, 1, 1, 4]), + ], +) +def test_util_minibatch_oversize(doc_sizes, expected_batches): + """ Test that oversized documents are returned in their own batch""" + docs = [get_doc(doc_size) for doc_size in doc_sizes] + examples = [Example(doc=doc) for doc in docs] + tol = 0.2 + batch_size = 1000 + batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=False)) + assert [len(batch) for batch in batches] == expected_batches + +