# spaCy/spacy/tests/test_util.py
import pytest
from spacy.gold import Example
from .util import get_doc
from spacy.util import minibatch_by_words
@pytest.mark.parametrize(
    "doc_sizes, expected_batches",
    [
        # Each case pairs the sizes handed to get_doc() with the expected
        # number of examples in each resulting batch (target size is 1000).
        ([400, 400, 199], [3]),
        # NOTE(review): the next two cases total 1002 and 1003, slightly over
        # the requested size, yet are expected to stay in a single batch --
        # presumably minibatch_by_words tolerates a small overflow; confirm
        # against its implementation.
        ([400, 400, 199, 3], [4]),
        ([400, 400, 199, 3, 1], [5]),
        ([400, 400, 199, 3, 200], [3, 2]),
        ([400, 400, 199, 3, 1, 200], [3, 3]),
    ],
)
def test_util_minibatch(doc_sizes, expected_batches):
    """Check how minibatch_by_words splits examples for a size of 1000.

    One doc is built per entry of ``doc_sizes`` via ``get_doc``, each doc is
    wrapped in an ``Example``, and the examples are batched with
    ``minibatch_by_words(size=1000)``. The test passes when the number of
    examples in each produced batch, in order, equals ``expected_batches``.
    """
    docs = [get_doc(doc_size) for doc_size in doc_sizes]
    examples = [Example(doc=doc) for doc in docs]
    # Materialise the batches so each one's length can be inspected.
    batches = list(minibatch_by_words(examples=examples, size=1000))
    assert [len(batch) for batch in batches] == expected_batches