diff --git a/spacy/tests/test_util.py b/spacy/tests/test_util.py index 382a8f548..93201eb4b 100644 --- a/spacy/tests/test_util.py +++ b/spacy/tests/test_util.py @@ -11,13 +11,13 @@ from spacy.util import minibatch_by_words [ ([400, 400, 199], [3]), ([400, 400, 199, 3], [4]), + ([400, 400, 199, 3, 1], [5]), ([400, 400, 199, 3, 250], [3, 2]), + ([400, 400, 199, 3, 1, 250], [3, 3]), ], ) def test_util_minibatch(doc_sizes, expected_batches): docs = [get_doc(doc_size) for doc_size in doc_sizes] - examples = [Example(doc=doc) for doc in docs] - batches = list(minibatch_by_words(examples=examples, size=1000)) assert [len(batch) for batch in batches] == expected_batches diff --git a/spacy/util.py b/spacy/util.py index f5ca49637..8ac2fd370 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -670,7 +670,9 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o target_size = next(size_) tol_size = target_size * tolerance batch = [] + overflow = [] current_size = 0 + overflow_size = 0 for example in examples: n_words = count_words(example.doc) @@ -681,10 +683,15 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o yield [example] # add the example to the current batch if it still fits - elif (current_size + n_words) < (target_size + tol_size): + elif (current_size + n_words) < target_size: batch.append(example) current_size += n_words + # add the example to the overflow buffer if it fits in the tolerance margins + elif (current_size + n_words) < (target_size + tol_size): + overflow.append(example) + overflow_size += n_words + # yield the previous batch and start a new one else: yield batch @@ -692,11 +699,15 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o tol_size = target_size * tolerance # In theory it may happen that the current example now exceeds the new target_size, # but that seems like an unimportant edge case if batch sizes are variable anyway? - batch = [example] - current_size = n_words + batch = overflow + batch.append(example) + current_size = overflow_size + n_words + overflow = [] + overflow_size = 0 # yield the final batch if batch: + batch.extend(overflow) yield batch