From 3a7f275c02ea75f91c54c7897a900637f8ebea5d Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Thu, 9 Jul 2020 14:38:41 +0200
Subject: [PATCH] Add extra batch util

---
 spacy/util.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 49 insertions(+), 4 deletions(-)

diff --git a/spacy/util.py b/spacy/util.py
index 4a17b7f24..a721eb85b 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -722,6 +722,50 @@ def minibatch(items, size=8):
         yield list(batch)
 
 
+def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False):
+    if isinstance(size, int):
+        size_ = itertools.repeat(size)
+    else:
+        size_ = size
+    for outer_batch in minibatch(docs, buffer):
+        outer_batch = list(outer_batch)
+        target_size = next(size_)
+        for indices in _batch_by_length(outer_batch, target_size):
+            subbatch = [outer_batch[i] for i in indices]
+            padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
+            if discard_oversize and padded_size >= target_size:
+                pass
+            else:
+                yield subbatch
+
+
+def _batch_by_length(seqs, max_words):
+    """Given a list of sequences, return a batched list of indices into the
+    list, where the batches are grouped by length, in descending order.
+
+    Batches may be at most max_words in size, defined as max sequence length * batch size.
+    """
+    # Sort by length, keeping each sequence's original index alongside it.
+    lengths_indices = [(len(seq), i) for i, seq in enumerate(seqs)]
+    lengths_indices.sort()
+    batches = []
+    batch = []
+    for length, i in lengths_indices:
+        if not batch:
+            batch.append(i)
+        elif length * (len(batch) + 1) <= max_words:
+            batch.append(i)
+        else:
+            batches.append(batch)
+            batch = [i]
+    if batch:
+        batches.append(batch)
+    # Check that every sequence was assigned to a batch
+    assert sum(len(b) for b in batches) == len(seqs)
+    batches = [list(sorted(batch)) for batch in batches]
+    batches.reverse()
+    return batches
+
 def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
     """Create minibatches of roughly a given number of words. If any examples
     are longer than the specified batch length, they will appear in a batch by
@@ -768,7 +812,8 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
 
         # yield the previous batch and start a new one. The new one gets the overflow examples.
         else:
-            yield batch
+            if batch:
+                yield batch
             target_size = next(size_)
             tol_size = target_size * tolerance
             batch = overflow
@@ -788,15 +833,15 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
 
             # this example does not fit with the previous overflow: start another new batch
             else:
-                yield batch
+                if batch:
+                    yield batch
                 target_size = next(size_)
                 tol_size = target_size * tolerance
                 batch = [doc]
                 batch_size = n_words
 
-    # yield the final batch
+    batch.extend(overflow)
     if batch:
-        batch.extend(overflow)
         yield batch
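
Not part of the patch: a minimal usage sketch of the new minibatch_by_padded_size
helper, assuming the hunks above are applied. The sequences and sizes below are
invented for illustration, and plain Python lists stand in for Doc objects, since
the batching helpers only call len() on each item.

    from spacy.util import minibatch_by_padded_size

    # Toy "documents": anything with a length works for this sketch.
    seqs = [list(range(n)) for n in (3, 5, 8, 2, 9, 4, 7, 1)]

    # buffer controls how many items are pulled into each outer chunk before
    # being grouped by length; within a chunk, subbatches are built so that
    # max sequence length * number of sequences stays within size where possible.
    # A single overly long sequence still gets its own subbatch unless
    # discard_oversize=True.
    for subbatch in minibatch_by_padded_size(seqs, size=16, buffer=4):
        print([len(seq) for seq in subbatch])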