Add docstrings for batchers

Matthew Honnibal 2020-08-07 18:51:02 +02:00
parent 3901b088ff
commit f5c4e0b751


@@ -19,6 +19,22 @@ def configure_minibatch_by_padded_size(
     discard_oversize: bool,
     get_length: Optional[Callable[[ItemT], int]] = None
 ) -> BatcherT:
"""Create a batcher that uses the `batch_by_padded_size` strategy.
The padded size is defined as the maximum length of sequences within the
batch multiplied by the number of sequences in the batch.
size (int or Iterable[int]): The largest padded size to batch sequences into.
Can be a single integer, or a sequence, allowing for variable batch sizes.
buffer (int): The number of sequences to accumulate before sorting by length.
A larger buffer will result in more even sizing, but if the buffer is
very large, the iteration order will be less random, which can result
in suboptimal training.
discard_oversize (bool): Whether to discard sequences that are by themselves
longer than the largest padded batch size.
get_length (Callable or None): Function to get the length of a sequence item.
The `len` function is used by default.
"""
     # Avoid displacing optional values from the underlying function.
     optionals = {"get_length": get_length} if get_length is not None else {}
     return partial(
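
For orientation, a minimal usage sketch of the factory documented above (not part of this commit; the import path and the sample texts are assumptions):

# Sketch only: illustrates the contract described in the docstring above.
# The import path is an assumption and may differ between spaCy versions.
from spacy.training.batchers import configure_minibatch_by_padded_size

batcher = configure_minibatch_by_padded_size(size=64, buffer=256, discard_oversize=False)
# With the default get_length=len, string length stands in for sequence length.
texts = ["a short text", "a somewhat longer example text", "tiny", "mid-sized text"]
for batch in batcher(texts):
    longest = max(len(text) for text in batch)
    # Per the docstring, longest * len(batch) should stay within size=64.
    print(len(batch), longest * len(batch))
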
@@ -38,6 +54,16 @@ def configure_minibatch_by_words(
     discard_oversize: bool,
     get_length: Optional[Callable[[ItemT], int]] = None
 ) -> BatcherT:
"""Create a batcher that uses the "minibatch by words" strategy.
size (int or Iterable[int]): The target number of words per batch.
Can be a single integer, or a sequence, allowing for variable batch sizes.
tolerance (float): What percentage of the size to allow batches to exceed.
discard_oversize (bool): Whether to discard sequences that by themselves
exceed the tolerated size.
get_length (Callable or None): Function to get the length of a sequence
item. The `len` function is used by default.
"""
     optionals = {"get_length": get_length} if get_length is not None else {}
     return partial(
         minibatch_by_words, size=size, discard_oversize=discard_oversize, **optionals
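
A similar sketch for the words-based factory (again not part of the commit): `size` can also be an iterable, so the word budget can follow a schedule. The import path and length function are assumptions:

# Sketch only. Import path is an assumption.
import itertools
from spacy.training.batchers import configure_minibatch_by_words

batcher = configure_minibatch_by_words(
    size=itertools.count(start=100, step=100),  # word budgets 100, 200, 300, ...
    tolerance=0.2,
    discard_oversize=False,
    get_length=lambda text: len(text.split()),  # count whitespace-separated words
)
texts = ["one two three", "four five", "six seven eight nine"]
for batch in batcher(texts):
    print(sum(len(text.split()) for text in batch))
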
@@ -48,22 +74,43 @@ def configure_minibatch_by_words(
 def configure_minibatch(
     size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
 ) -> BatcherT:
+    """Create a batcher that creates batches of the specified size.
+
+    size (int or Iterable[int]): The target number of items per batch.
+        Can be a single integer, or a sequence, allowing for variable batch sizes.
+    """
     optionals = {"get_length": get_length} if get_length is not None else {}
     return partial(minibatch, size=size, **optionals)
 def minibatch_by_padded_size(
-    docs: Iterator["Doc"],
+    seqs: Iterable[ItemT],
     size: Sizing,
     buffer: int = 256,
     discard_oversize: bool = False,
     get_length: Callable = len,
-) -> Iterator[Iterator["Doc"]]:
+) -> Iterable[List[ItemT]]:
+    """Minibatch a sequence by the size of padded batches that would result,
+    with sequences binned by length within a window.
+
+    The padded size is defined as the maximum length of sequences within the
+    batch multiplied by the number of sequences in the batch.
+
+    size (int): The largest padded size to batch sequences into.
+    buffer (int): The number of sequences to accumulate before sorting by length.
+        A larger buffer will result in more even sizing, but if the buffer is
+        very large, the iteration order will be less random, which can result
+        in suboptimal training.
+    discard_oversize (bool): Whether to discard sequences that are by themselves
+        longer than the largest padded batch size.
+    get_length (Callable or None): Function to get the length of a sequence item.
+        The `len` function is used by default.
+    """
     if isinstance(size, int):
         size_ = itertools.repeat(size)
     else:
         size_ = size
-    for outer_batch in minibatch(docs, size=buffer):
+    for outer_batch in minibatch(seqs, size=buffer):
         outer_batch = list(outer_batch)
         target_size = next(size_)
         for indices in _batch_by_length(outer_batch, target_size, get_length):
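
To make the "padded size" definition concrete, a small sketch (not part of this commit) that calls the generator above directly; the import path and the toy sequences are assumptions:

# Sketch only: padded size = max(len(seq) for seq in batch) * len(batch).
from spacy.training.batchers import minibatch_by_padded_size  # path is an assumption

seqs = [["tok"] * n for n in (3, 7, 5, 2, 6)]  # token lists of length 3, 7, 5, 2, 6
for batch in minibatch_by_padded_size(seqs, size=16, buffer=256):
    longest = max(len(seq) for seq in batch)
    # Per the docstring, each batch keeps longest * len(batch) within size=16.
    print([len(seq) for seq in batch], "padded size:", longest * len(batch))
# By contrast, the plain `minibatch` strategy (configure_minibatch above) only
# counts items per batch and ignores their lengths.
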
@@ -76,12 +123,20 @@ def minibatch_by_padded_size(
 def minibatch_by_words(
-    docs, size, tolerance=0.2, discard_oversize=False, get_length=len
-):
+    seqs: Iterable[ItemT], size: Sizing, tolerance=0.2, discard_oversize=False, get_length=len
+) -> Iterable[List[ItemT]]:
     """Create minibatches of roughly a given number of words. If any examples
     are longer than the specified batch length, they will appear in a batch by
     themselves, or be discarded if discard_oversize=True.
-    The argument 'docs' can be a list of strings, Docs or Examples.
+
+    seqs (Iterable[Sequence]): The sequences to minibatch.
+    size (int or Iterable[int]): The target number of words per batch.
+        Can be a single integer, or a sequence, allowing for variable batch sizes.
+    tolerance (float): What percentage of the size to allow batches to exceed.
+    discard_oversize (bool): Whether to discard sequences that by themselves
+        exceed the tolerated size.
+    get_length (Callable or None): Function to get the length of a sequence
+        item. The `len` function is used by default.
     """
     if isinstance(size, int):
         size_ = itertools.repeat(size)
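
A brief sketch of the oversize behaviour described in this docstring (not part of the commit; the texts and length function are assumptions):

# Sketch only: an item longer than size plus the tolerance margin is yielded
# in a batch by itself, or dropped entirely when discard_oversize=True.
from spacy.training.batchers import minibatch_by_words  # path is an assumption

texts = [
    "one two three",
    "four five",
    "a very long example sentence that is well over the word budget by itself",
]
for batch in minibatch_by_words(
    texts, size=6, tolerance=0.2, discard_oversize=False,
    get_length=lambda text: len(text.split()),
):
    print([len(text.split()) for text in batch])
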
@@ -95,20 +150,20 @@ def minibatch_by_words(
     overflow = []
     batch_size = 0
     overflow_size = 0
-    for doc in docs:
-        n_words = get_length(doc)
+    for seq in seqs:
+        n_words = get_length(seq)
         # if the current example exceeds the maximum batch size, it is returned separately
         # but only if discard_oversize=False.
         if n_words > target_size + tol_size:
             if not discard_oversize:
-                yield [doc]
+                yield [seq]
         # add the example to the current batch if there's no overflow yet and it still fits
         elif overflow_size == 0 and (batch_size + n_words) <= target_size:
-            batch.append(doc)
+            batch.append(seq)
             batch_size += n_words
         # add the example to the overflow buffer if it fits in the tolerance margin
         elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
-            overflow.append(doc)
+            overflow.append(seq)
             overflow_size += n_words
         # yield the previous batch and start a new one. The new one gets the overflow examples.
         else:
@@ -122,11 +177,11 @@ def minibatch_by_words(
             overflow_size = 0
             # this example still fits
             if (batch_size + n_words) <= target_size:
-                batch.append(doc)
+                batch.append(seq)
                 batch_size += n_words
             # this example fits in overflow
             elif (batch_size + n_words) <= (target_size + tol_size):
-                overflow.append(doc)
+                overflow.append(seq)
                 overflow_size += n_words
             # this example does not fit with the previous overflow: start another new batch
             else:
@@ -134,7 +189,7 @@ def minibatch_by_words(
                 yield batch
                 target_size = next(size_)
                 tol_size = target_size * tolerance
-                batch = [doc]
+                batch = [seq]
                 batch_size = n_words
     batch.extend(overflow)
     if batch:
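
Finally, a numeric sketch of the overflow logic above (not part of this commit): with size=10 and tolerance=0.2, batches are filled up to ten words, and the overflow buffer can hold roughly two more before a new batch is started. The toy sequences and import path are assumptions.

# Sketch only: assumes minibatch_by_words is importable as below.
from spacy.training.batchers import minibatch_by_words

seqs = [["w"] * n for n in (4, 4, 3, 2, 9)]  # word counts 4, 4, 3, 2, 9
for batch in minibatch_by_words(seqs, size=10, tolerance=0.2):
    # Here every yielded batch stays within the 10-word target plus the 2-word margin.
    print([len(seq) for seq in batch], "total:", sum(len(seq) for seq in batch))
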