using overflow buffer for examples within the tolerance margin

This commit is contained in:
svlandeg 2020-06-02 19:43:39 +02:00
parent 85b0597ed5
commit 6651fafd5c
2 changed files with 16 additions and 5 deletions

View File

@ -11,13 +11,13 @@ from spacy.util import minibatch_by_words
[ [
([400, 400, 199], [3]), ([400, 400, 199], [3]),
([400, 400, 199, 3], [4]), ([400, 400, 199, 3], [4]),
([400, 400, 199, 3, 1], [5]),
([400, 400, 199, 3, 250], [3, 2]), ([400, 400, 199, 3, 250], [3, 2]),
([400, 400, 199, 3, 1, 250], [3, 3]),
], ],
) )
def test_util_minibatch(doc_sizes, expected_batches): def test_util_minibatch(doc_sizes, expected_batches):
docs = [get_doc(doc_size) for doc_size in doc_sizes] docs = [get_doc(doc_size) for doc_size in doc_sizes]
examples = [Example(doc=doc) for doc in docs] examples = [Example(doc=doc) for doc in docs]
batches = list(minibatch_by_words(examples=examples, size=1000)) batches = list(minibatch_by_words(examples=examples, size=1000))
assert [len(batch) for batch in batches] == expected_batches assert [len(batch) for batch in batches] == expected_batches

View File

@ -670,7 +670,9 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
target_size = next(size_) target_size = next(size_)
tol_size = target_size * tolerance tol_size = target_size * tolerance
batch = [] batch = []
overflow = []
current_size = 0 current_size = 0
overflow_size = 0
for example in examples: for example in examples:
n_words = count_words(example.doc) n_words = count_words(example.doc)
@ -681,10 +683,15 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
yield [example] yield [example]
# add the example to the current batch if it still fits # add the example to the current batch if it still fits
elif (current_size + n_words) < (target_size + tol_size): elif (current_size + n_words) < target_size:
batch.append(example) batch.append(example)
current_size += n_words current_size += n_words
# add the example to the overflow buffer if it fits in the tolerance margins
elif (current_size + n_words) < (target_size + tol_size):
overflow.append(example)
overflow_size += n_words
# yield the previous batch and start a new one # yield the previous batch and start a new one
else: else:
yield batch yield batch
@ -692,11 +699,15 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
tol_size = target_size * tolerance tol_size = target_size * tolerance
# In theory it may happen that the current example now exceeds the new target_size, # In theory it may happen that the current example now exceeds the new target_size,
# but that seems like an unimportant edge case if batch sizes are variable anyway? # but that seems like an unimportant edge case if batch sizes are variable anyway?
batch = [example] batch = overflow
current_size = n_words batch.append(example)
current_size = overflow_size + n_words
overflow = []
overflow_size = 0
# yield the final batch # yield the final batch
if batch: if batch:
batch.extend(overflow)
yield batch yield batch