mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Add docstrings for batchers
This commit is contained in:
parent
3901b088ff
commit
f5c4e0b751
|
@ -19,6 +19,22 @@ def configure_minibatch_by_padded_size(
|
|||
discard_oversize: bool,
|
||||
get_length: Optional[Callable[[ItemT], int]] = None
|
||||
) -> BatcherT:
|
||||
"""Create a batcher that uses the `batch_by_padded_size` strategy.
|
||||
|
||||
The padded size is defined as the maximum length of sequences within the
|
||||
batch multiplied by the number of sequences in the batch.
|
||||
|
||||
size (int or Iterable[int]): The largest padded size to batch sequences into.
|
||||
Can be a single integer, or a sequence, allowing for variable batch sizes.
|
||||
buffer (int): The number of sequences to accumulate before sorting by length.
|
||||
A larger buffer will result in more even sizing, but if the buffer is
|
||||
very large, the iteration order will be less random, which can result
|
||||
in suboptimal training.
|
||||
discard_oversize (bool): Whether to discard sequences that are by themselves
|
||||
longer than the largest padded batch size.
|
||||
get_length (Callable or None): Function to get the length of a sequence item.
|
||||
The `len` function is used by default.
|
||||
"""
|
||||
# Avoid displacing optional values from the underlying function.
|
||||
optionals = {"get_length": get_length} if get_length is not None else {}
|
||||
return partial(
|
||||
|
@ -38,6 +54,16 @@ def configure_minibatch_by_words(
|
|||
discard_oversize: bool,
|
||||
get_length: Optional[Callable[[ItemT], int]] = None
|
||||
) -> BatcherT:
|
||||
"""Create a batcher that uses the "minibatch by words" strategy.
|
||||
|
||||
size (int or Iterable[int]): The target number of words per batch.
|
||||
Can be a single integer, or a sequence, allowing for variable batch sizes.
|
||||
tolerance (float): What percentage of the size to allow batches to exceed.
|
||||
discard_oversize (bool): Whether to discard sequences that by themselves
|
||||
exceed the tolerated size.
|
||||
get_length (Callable or None): Function to get the length of a sequence
|
||||
item. The `len` function is used by default.
|
||||
"""
|
||||
optionals = {"get_length": get_length} if get_length is not None else {}
|
||||
return partial(
|
||||
minibatch_by_words, size=size, discard_oversize=discard_oversize, **optionals
|
||||
|
@ -48,22 +74,43 @@ def configure_minibatch_by_words(
|
|||
def configure_minibatch(
|
||||
size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
|
||||
) -> BatcherT:
|
||||
"""Create a batcher that creates batches of the specified size.
|
||||
|
||||
size (int or Iterable[int]): The target number of items per batch.
|
||||
Can be a single integer, or a sequence, allowing for variable batch sizes.
|
||||
"""
|
||||
optionals = {"get_length": get_length} if get_length is not None else {}
|
||||
return partial(minibatch, size=size, **optionals)
|
||||
|
||||
|
||||
def minibatch_by_padded_size(
|
||||
docs: Iterator["Doc"],
|
||||
seqs: Iterable[ItemT],
|
||||
size: Sizing,
|
||||
buffer: int = 256,
|
||||
discard_oversize: bool = False,
|
||||
get_length: Callable = len,
|
||||
) -> Iterator[Iterator["Doc"]]:
|
||||
) -> Iterable[List[ItemT]]:
|
||||
"""Minibatch a sequence by the size of padded batches that would result,
|
||||
with sequences binned by length within a window.
|
||||
|
||||
The padded size is defined as the maximum length of sequences within the
|
||||
batch multiplied by the number of sequences in the batch.
|
||||
|
||||
size (int): The largest padded size to batch sequences into.
|
||||
buffer (int): The number of sequences to accumulate before sorting by length.
|
||||
A larger buffer will result in more even sizing, but if the buffer is
|
||||
very large, the iteration order will be less random, which can result
|
||||
in suboptimal training.
|
||||
discard_oversize (bool): Whether to discard sequences that are by themselves
|
||||
longer than the largest padded batch size.
|
||||
get_length (Callable or None): Function to get the length of a sequence item.
|
||||
The `len` function is used by default.
|
||||
"""
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
else:
|
||||
size_ = size
|
||||
for outer_batch in minibatch(docs, size=buffer):
|
||||
for outer_batch in minibatch(seqs, size=buffer):
|
||||
outer_batch = list(outer_batch)
|
||||
target_size = next(size_)
|
||||
for indices in _batch_by_length(outer_batch, target_size, get_length):
|
||||
|
@ -76,12 +123,20 @@ def minibatch_by_padded_size(
|
|||
|
||||
|
||||
def minibatch_by_words(
|
||||
docs, size, tolerance=0.2, discard_oversize=False, get_length=len
|
||||
):
|
||||
seqs: Iterable[ItemT], size: Sizing, tolerance=0.2, discard_oversize=False, get_length=len
|
||||
) -> Iterable[List[ItemT]]:
|
||||
"""Create minibatches of roughly a given number of words. If any examples
|
||||
are longer than the specified batch length, they will appear in a batch by
|
||||
themselves, or be discarded if discard_oversize=True.
|
||||
The argument 'docs' can be a list of strings, Docs or Examples.
|
||||
|
||||
seqs (Iterable[Sequence]): The sequences to minibatch.
|
||||
size (int or Iterable[int]): The target number of words per batch.
|
||||
Can be a single integer, or a sequence, allowing for variable batch sizes.
|
||||
tolerance (float): What percentage of the size to allow batches to exceed.
|
||||
discard_oversize (bool): Whether to discard sequences that by themselves
|
||||
exceed the tolerated size.
|
||||
get_length (Callable or None): Function to get the length of a sequence
|
||||
item. The `len` function is used by default.
|
||||
"""
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
|
@ -95,20 +150,20 @@ def minibatch_by_words(
|
|||
overflow = []
|
||||
batch_size = 0
|
||||
overflow_size = 0
|
||||
for doc in docs:
|
||||
n_words = get_length(doc)
|
||||
for seq in seqs:
|
||||
n_words = get_length(seq)
|
||||
# if the current example exceeds the maximum batch size, it is returned separately
|
||||
# but only if discard_oversize=False.
|
||||
if n_words > target_size + tol_size:
|
||||
if not discard_oversize:
|
||||
yield [doc]
|
||||
yield [seq]
|
||||
# add the example to the current batch if there's no overflow yet and it still fits
|
||||
elif overflow_size == 0 and (batch_size + n_words) <= target_size:
|
||||
batch.append(doc)
|
||||
batch.append(seq)
|
||||
batch_size += n_words
|
||||
# add the example to the overflow buffer if it fits in the tolerance margin
|
||||
elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
|
||||
overflow.append(doc)
|
||||
overflow.append(seq)
|
||||
overflow_size += n_words
|
||||
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
||||
else:
|
||||
|
@ -122,11 +177,11 @@ def minibatch_by_words(
|
|||
overflow_size = 0
|
||||
# this example still fits
|
||||
if (batch_size + n_words) <= target_size:
|
||||
batch.append(doc)
|
||||
batch.append(seq)
|
||||
batch_size += n_words
|
||||
# this example fits in overflow
|
||||
elif (batch_size + n_words) <= (target_size + tol_size):
|
||||
overflow.append(doc)
|
||||
overflow.append(seq)
|
||||
overflow_size += n_words
|
||||
# this example does not fit with the previous overflow: start another new batch
|
||||
else:
|
||||
|
@ -134,7 +189,7 @@ def minibatch_by_words(
|
|||
yield batch
|
||||
target_size = next(size_)
|
||||
tol_size = target_size * tolerance
|
||||
batch = [doc]
|
||||
batch = [seq]
|
||||
batch_size = n_words
|
||||
batch.extend(overflow)
|
||||
if batch:
|
||||
|
|
Loading…
Reference in New Issue
Block a user