mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
Add extra batch util
This commit is contained in:
parent
eb0798c421
commit
3a7f275c02
|
@ -722,6 +722,50 @@ def minibatch(items, size=8):
|
||||||
yield list(batch)
|
yield list(batch)
|
||||||
|
|
||||||
|
|
||||||
|
def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False):
|
||||||
|
if isinstance(size, int):
|
||||||
|
size_ = itertools.repeat(size)
|
||||||
|
else:
|
||||||
|
size_ = size
|
||||||
|
for outer_batch in minibatch(docs, buffer):
|
||||||
|
outer_batch = list(outer_batch)
|
||||||
|
target_size = next(size_)
|
||||||
|
for indices in _batch_by_length(outer_batch, target_size):
|
||||||
|
subbatch = [outer_batch[i] for i in indices]
|
||||||
|
padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
|
||||||
|
if discard_oversize and padded_size >= target_size:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
yield subbatch
|
||||||
|
|
||||||
|
|
||||||
|
def _batch_by_length(seqs, max_words):
|
||||||
|
"""Given a list of sequences, return a batched list of indices into the
|
||||||
|
list, where the batches are grouped by length, in descending order.
|
||||||
|
|
||||||
|
Batches may be at most max_words in size, defined as max sequence length * size.
|
||||||
|
"""
|
||||||
|
# Use negative index so we can get sort by position ascending.
|
||||||
|
lengths_indices = [(len(seq), i) for i, seq in enumerate(seqs)]
|
||||||
|
lengths_indices.sort()
|
||||||
|
batches = []
|
||||||
|
batch = []
|
||||||
|
for length, i in lengths_indices:
|
||||||
|
if not batch:
|
||||||
|
batch.append(i)
|
||||||
|
elif length * (len(batch) + 1) <= max_words:
|
||||||
|
batch.append(i)
|
||||||
|
else:
|
||||||
|
batches.append(batch)
|
||||||
|
batch = [i]
|
||||||
|
if batch:
|
||||||
|
batches.append(batch)
|
||||||
|
# Check lengths match
|
||||||
|
assert sum(len(b) for b in batches) == len(seqs)
|
||||||
|
batches = [list(sorted(batch)) for batch in batches]
|
||||||
|
batches.reverse()
|
||||||
|
return batches
|
||||||
|
|
||||||
def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
||||||
"""Create minibatches of roughly a given number of words. If any examples
|
"""Create minibatches of roughly a given number of words. If any examples
|
||||||
are longer than the specified batch length, they will appear in a batch by
|
are longer than the specified batch length, they will appear in a batch by
|
||||||
|
@ -768,6 +812,7 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
||||||
|
|
||||||
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
||||||
else:
|
else:
|
||||||
|
if batch:
|
||||||
yield batch
|
yield batch
|
||||||
target_size = next(size_)
|
target_size = next(size_)
|
||||||
tol_size = target_size * tolerance
|
tol_size = target_size * tolerance
|
||||||
|
@ -788,15 +833,15 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
||||||
|
|
||||||
# this example does not fit with the previous overflow: start another new batch
|
# this example does not fit with the previous overflow: start another new batch
|
||||||
else:
|
else:
|
||||||
|
if batch:
|
||||||
yield batch
|
yield batch
|
||||||
target_size = next(size_)
|
target_size = next(size_)
|
||||||
tol_size = target_size * tolerance
|
tol_size = target_size * tolerance
|
||||||
batch = [doc]
|
batch = [doc]
|
||||||
batch_size = n_words
|
batch_size = n_words
|
||||||
|
|
||||||
# yield the final batch
|
|
||||||
if batch:
|
|
||||||
batch.extend(overflow)
|
batch.extend(overflow)
|
||||||
|
if batch:
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user