From 85b0597ed5f8e23de337f56966e4b342827a99c3 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Tue, 2 Jun 2020 18:26:21 +0200 Subject: [PATCH] add test for minibatch util --- spacy/tests/test_util.py | 23 +++++++++++++++++++++++ spacy/tests/util.py | 7 +++++++ 2 files changed, 30 insertions(+) create mode 100644 spacy/tests/test_util.py diff --git a/spacy/tests/test_util.py b/spacy/tests/test_util.py new file mode 100644 index 000000000..382a8f548 --- /dev/null +++ b/spacy/tests/test_util.py @@ -0,0 +1,23 @@ +import pytest +from spacy.gold import Example + +from .util import get_doc + +from spacy.util import minibatch_by_words + + +@pytest.mark.parametrize( + "doc_sizes, expected_batches", + [ + ([400, 400, 199], [3]), + ([400, 400, 199, 3], [4]), + ([400, 400, 199, 3, 250], [3, 2]), + ], +) +def test_util_minibatch(doc_sizes, expected_batches): + docs = [get_doc(doc_size) for doc_size in doc_sizes] + + examples = [Example(doc=doc) for doc in docs] + + batches = list(minibatch_by_words(examples=examples, size=1000)) + assert [len(batch) for batch in batches] == expected_batches diff --git a/spacy/tests/util.py b/spacy/tests/util.py index e29342268..73650a6f7 100644 --- a/spacy/tests/util.py +++ b/spacy/tests/util.py @@ -92,6 +92,13 @@ def get_batch(batch_size): return docs +def get_doc(n_words): + vocab = Vocab() + # Make the words numbers, so that they're easy to track. + numbers = [str(i) for i in range(0, n_words)] + return Doc(vocab, words=numbers) + + def apply_transition_sequence(parser, doc, sequence): """Perform a series of pre-specified transitions, to put the parser in a desired state."""