mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
additional test with discard_oversize=False
This commit is contained in:
parent
aa6271b16c
commit
2bf5111ecf
|
@ -12,13 +12,11 @@ from spacy.util import minibatch_by_words
|
||||||
([400, 400, 199], [3]),
|
([400, 400, 199], [3]),
|
||||||
([400, 400, 199, 3], [4]),
|
([400, 400, 199, 3], [4]),
|
||||||
([400, 400, 199, 3, 200], [3, 2]),
|
([400, 400, 199, 3, 200], [3, 2]),
|
||||||
|
|
||||||
([400, 400, 199, 3, 1], [5]),
|
([400, 400, 199, 3, 1], [5]),
|
||||||
([400, 400, 199, 3, 1, 1500], [5]), # 1500 will be discarded
|
([400, 400, 199, 3, 1, 1500], [5]), # 1500 will be discarded
|
||||||
([400, 400, 199, 3, 1, 200], [3, 3]),
|
([400, 400, 199, 3, 1, 200], [3, 3]),
|
||||||
([400, 400, 199, 3, 1, 999], [3, 3]),
|
([400, 400, 199, 3, 1, 999], [3, 3]),
|
||||||
([400, 400, 199, 3, 1, 999, 999], [3, 2, 1, 1]),
|
([400, 400, 199, 3, 1, 999, 999], [3, 2, 1, 1]),
|
||||||
|
|
||||||
([1, 2, 999], [3]),
|
([1, 2, 999], [3]),
|
||||||
([1, 2, 999, 1], [4]),
|
([1, 2, 999, 1], [4]),
|
||||||
([1, 200, 999, 1], [2, 2]),
|
([1, 200, 999, 1], [2, 2]),
|
||||||
|
@ -37,3 +35,25 @@ def test_util_minibatch(doc_sizes, expected_batches):
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
assert sum([len(example.doc) for example in batch]) < max_size
|
assert sum([len(example.doc) for example in batch]) < max_size
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"doc_sizes, expected_batches",
|
||||||
|
[
|
||||||
|
([400, 4000, 199], [1, 2]),
|
||||||
|
([400, 400, 199, 3000, 200], [1, 4]),
|
||||||
|
([400, 400, 199, 3, 1, 1500], [1, 5]),
|
||||||
|
([400, 400, 199, 3000, 2000, 200, 200], [1, 1, 3, 2]),
|
||||||
|
([1, 2, 9999], [1, 2]),
|
||||||
|
([2000, 1, 2000, 1, 1, 1, 2000], [1, 1, 1, 4]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_util_minibatch_oversize(doc_sizes, expected_batches):
|
||||||
|
""" Test that oversized documents are returned in their own batch"""
|
||||||
|
docs = [get_doc(doc_size) for doc_size in doc_sizes]
|
||||||
|
examples = [Example(doc=doc) for doc in docs]
|
||||||
|
tol = 0.2
|
||||||
|
batch_size = 1000
|
||||||
|
batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=False))
|
||||||
|
assert [len(batch) for batch in batches] == expected_batches
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user