mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 12:18:04 +03:00
Merge pull request #5533 from svlandeg/bugfix/minibatch-oversize
add oversize examples before StopIteration returns
This commit is contained in:
commit
f74784575c
59
spacy/tests/test_util.py
Normal file
59
spacy/tests/test_util.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
import pytest
|
||||||
|
from spacy.gold import Example
|
||||||
|
|
||||||
|
from .util import get_random_doc
|
||||||
|
|
||||||
|
from spacy.util import minibatch_by_words
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"doc_sizes, expected_batches",
|
||||||
|
[
|
||||||
|
([400, 400, 199], [3]),
|
||||||
|
([400, 400, 199, 3], [4]),
|
||||||
|
([400, 400, 199, 3, 200], [3, 2]),
|
||||||
|
([400, 400, 199, 3, 1], [5]),
|
||||||
|
([400, 400, 199, 3, 1, 1500], [5]), # 1500 will be discarded
|
||||||
|
([400, 400, 199, 3, 1, 200], [3, 3]),
|
||||||
|
([400, 400, 199, 3, 1, 999], [3, 3]),
|
||||||
|
([400, 400, 199, 3, 1, 999, 999], [3, 2, 1, 1]),
|
||||||
|
([1, 2, 999], [3]),
|
||||||
|
([1, 2, 999, 1], [4]),
|
||||||
|
([1, 200, 999, 1], [2, 2]),
|
||||||
|
([1, 999, 200, 1], [2, 2]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_util_minibatch(doc_sizes, expected_batches):
|
||||||
|
docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
|
||||||
|
examples = [Example(doc=doc) for doc in docs]
|
||||||
|
tol = 0.2
|
||||||
|
batch_size = 1000
|
||||||
|
batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=True))
|
||||||
|
assert [len(batch) for batch in batches] == expected_batches
|
||||||
|
|
||||||
|
max_size = batch_size + batch_size * tol
|
||||||
|
for batch in batches:
|
||||||
|
assert sum([len(example.doc) for example in batch]) < max_size
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"doc_sizes, expected_batches",
|
||||||
|
[
|
||||||
|
([400, 4000, 199], [1, 2]),
|
||||||
|
([400, 400, 199, 3000, 200], [1, 4]),
|
||||||
|
([400, 400, 199, 3, 1, 1500], [1, 5]),
|
||||||
|
([400, 400, 199, 3000, 2000, 200, 200], [1, 1, 3, 2]),
|
||||||
|
([1, 2, 9999], [1, 2]),
|
||||||
|
([2000, 1, 2000, 1, 1, 1, 2000], [1, 1, 1, 4]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_util_minibatch_oversize(doc_sizes, expected_batches):
|
||||||
|
""" Test that oversized documents are returned in their own batch"""
|
||||||
|
docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
|
||||||
|
examples = [Example(doc=doc) for doc in docs]
|
||||||
|
tol = 0.2
|
||||||
|
batch_size = 1000
|
||||||
|
batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=False))
|
||||||
|
assert [len(batch) for batch in batches] == expected_batches
|
||||||
|
|
||||||
|
|
|
@ -92,6 +92,13 @@ def get_batch(batch_size):
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_doc(n_words):
|
||||||
|
vocab = Vocab()
|
||||||
|
# Make the words numbers, so that they're easy to track.
|
||||||
|
numbers = [str(i) for i in range(0, n_words)]
|
||||||
|
return Doc(vocab, words=numbers)
|
||||||
|
|
||||||
|
|
||||||
def apply_transition_sequence(parser, doc, sequence):
|
def apply_transition_sequence(parser, doc, sequence):
|
||||||
"""Perform a series of pre-specified transitions, to put the parser in a
|
"""Perform a series of pre-specified transitions, to put the parser in a
|
||||||
desired state."""
|
desired state."""
|
||||||
|
|
|
@ -656,41 +656,73 @@ def decaying(start, stop, decay):
|
||||||
curr -= decay
|
curr -= decay
|
||||||
|
|
||||||
|
|
||||||
def minibatch_by_words(examples, size, tuples=True, count_words=len, tolerance=0.2):
|
def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
|
||||||
"""Create minibatches of roughly a given number of words. If any examples
|
"""Create minibatches of roughly a given number of words. If any examples
|
||||||
are longer than the specified batch length, they will appear in a batch by
|
are longer than the specified batch length, they will appear in a batch by
|
||||||
themselves."""
|
themselves, or be discarded if discard_oversize=True."""
|
||||||
if isinstance(size, int):
|
if isinstance(size, int):
|
||||||
size_ = itertools.repeat(size)
|
size_ = itertools.repeat(size)
|
||||||
elif isinstance(size, List):
|
elif isinstance(size, List):
|
||||||
size_ = iter(size)
|
size_ = iter(size)
|
||||||
else:
|
else:
|
||||||
size_ = size
|
size_ = size
|
||||||
examples = iter(examples)
|
|
||||||
oversize = []
|
target_size = next(size_)
|
||||||
while True:
|
tol_size = target_size * tolerance
|
||||||
batch_size = next(size_)
|
|
||||||
tol_size = batch_size * 0.2
|
|
||||||
batch = []
|
batch = []
|
||||||
if oversize:
|
overflow = []
|
||||||
example = oversize.pop(0)
|
batch_size = 0
|
||||||
|
overflow_size = 0
|
||||||
|
|
||||||
|
for example in examples:
|
||||||
n_words = count_words(example.doc)
|
n_words = count_words(example.doc)
|
||||||
|
# if the current example exceeds the maximum batch size, it is returned separately
|
||||||
|
# but only if discard_oversize=False.
|
||||||
|
if n_words > target_size + tol_size:
|
||||||
|
if not discard_oversize:
|
||||||
|
yield [example]
|
||||||
|
|
||||||
|
# add the example to the current batch if there's no overflow yet and it still fits
|
||||||
|
elif overflow_size == 0 and (batch_size + n_words) <= target_size:
|
||||||
batch.append(example)
|
batch.append(example)
|
||||||
batch_size -= n_words
|
batch_size += n_words
|
||||||
while batch_size >= 1:
|
|
||||||
try:
|
# add the example to the overflow buffer if it fits in the tolerance margin
|
||||||
example = next(examples)
|
elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
|
||||||
except StopIteration:
|
overflow.append(example)
|
||||||
if batch:
|
overflow_size += n_words
|
||||||
yield batch
|
|
||||||
return
|
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
||||||
n_words = count_words(example.doc)
|
|
||||||
if n_words < (batch_size + tol_size):
|
|
||||||
batch_size -= n_words
|
|
||||||
batch.append(example)
|
|
||||||
else:
|
else:
|
||||||
oversize.append(example)
|
yield batch
|
||||||
|
target_size = next(size_)
|
||||||
|
tol_size = target_size * tolerance
|
||||||
|
batch = overflow
|
||||||
|
batch_size = overflow_size
|
||||||
|
overflow = []
|
||||||
|
overflow_size = 0
|
||||||
|
|
||||||
|
# this example still fits
|
||||||
|
if (batch_size + n_words) <= target_size:
|
||||||
|
batch.append(example)
|
||||||
|
batch_size += n_words
|
||||||
|
|
||||||
|
# this example fits in overflow
|
||||||
|
elif (batch_size + n_words) <= (target_size + tol_size):
|
||||||
|
overflow.append(example)
|
||||||
|
overflow_size += n_words
|
||||||
|
|
||||||
|
# this example does not fit with the previous overflow: start another new batch
|
||||||
|
else:
|
||||||
|
yield batch
|
||||||
|
target_size = next(size_)
|
||||||
|
tol_size = target_size * tolerance
|
||||||
|
batch = [example]
|
||||||
|
batch_size = n_words
|
||||||
|
|
||||||
|
# yield the final batch
|
||||||
if batch:
|
if batch:
|
||||||
|
batch.extend(overflow)
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user