mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-03 10:55:52 +03:00
Merge pull request #5895 from explosion/docs/batchers
Draft docstrings for batchers
This commit is contained in:
commit
fd20f84927
|
@ -19,6 +19,22 @@ def configure_minibatch_by_padded_size(
|
||||||
discard_oversize: bool,
|
discard_oversize: bool,
|
||||||
get_length: Optional[Callable[[ItemT], int]] = None
|
get_length: Optional[Callable[[ItemT], int]] = None
|
||||||
) -> BatcherT:
|
) -> BatcherT:
|
||||||
|
"""Create a batcher that uses the `batch_by_padded_size` strategy.
|
||||||
|
|
||||||
|
The padded size is defined as the maximum length of sequences within the
|
||||||
|
batch multiplied by the number of sequences in the batch.
|
||||||
|
|
||||||
|
size (int or Iterable[int]): The largest padded size to batch sequences into.
|
||||||
|
Can be a single integer, or a sequence, allowing for variable batch sizes.
|
||||||
|
buffer (int): The number of sequences to accumulate before sorting by length.
|
||||||
|
A larger buffer will result in more even sizing, but if the buffer is
|
||||||
|
very large, the iteration order will be less random, which can result
|
||||||
|
in suboptimal training.
|
||||||
|
discard_oversize (bool): Whether to discard sequences that are by themselves
|
||||||
|
longer than the largest padded batch size.
|
||||||
|
get_length (Callable or None): Function to get the length of a sequence item.
|
||||||
|
The `len` function is used by default.
|
||||||
|
"""
|
||||||
# Avoid displacing optional values from the underlying function.
|
# Avoid displacing optional values from the underlying function.
|
||||||
optionals = {"get_length": get_length} if get_length is not None else {}
|
optionals = {"get_length": get_length} if get_length is not None else {}
|
||||||
return partial(
|
return partial(
|
||||||
|
@ -38,6 +54,16 @@ def configure_minibatch_by_words(
|
||||||
discard_oversize: bool,
|
discard_oversize: bool,
|
||||||
get_length: Optional[Callable[[ItemT], int]] = None
|
get_length: Optional[Callable[[ItemT], int]] = None
|
||||||
) -> BatcherT:
|
) -> BatcherT:
|
||||||
|
"""Create a batcher that uses the "minibatch by words" strategy.
|
||||||
|
|
||||||
|
size (int or Iterable[int]): The target number of words per batch.
|
||||||
|
Can be a single integer, or a sequence, allowing for variable batch sizes.
|
||||||
|
tolerance (float): The fraction of the size by which batches may exceed it (e.g. 0.2 allows 20% over).
|
||||||
|
discard_oversize (bool): Whether to discard sequences that by themselves
|
||||||
|
exceed the tolerated size.
|
||||||
|
get_length (Callable or None): Function to get the length of a sequence
|
||||||
|
item. The `len` function is used by default.
|
||||||
|
"""
|
||||||
optionals = {"get_length": get_length} if get_length is not None else {}
|
optionals = {"get_length": get_length} if get_length is not None else {}
|
||||||
return partial(
|
return partial(
|
||||||
minibatch_by_words, size=size, discard_oversize=discard_oversize, **optionals
|
minibatch_by_words, size=size, discard_oversize=discard_oversize, **optionals
|
||||||
|
@ -48,22 +74,43 @@ def configure_minibatch_by_words(
|
||||||
def configure_minibatch(
|
def configure_minibatch(
|
||||||
size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
|
size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
|
||||||
) -> BatcherT:
|
) -> BatcherT:
|
||||||
|
"""Create a batcher that creates batches of the specified size.
|
||||||
|
|
||||||
|
size (int or Iterable[int]): The target number of items per batch.
|
||||||
|
Can be a single integer, or a sequence, allowing for variable batch sizes.
|
||||||
|
"""
|
||||||
optionals = {"get_length": get_length} if get_length is not None else {}
|
optionals = {"get_length": get_length} if get_length is not None else {}
|
||||||
return partial(minibatch, size=size, **optionals)
|
return partial(minibatch, size=size, **optionals)
|
||||||
|
|
||||||
|
|
||||||
def minibatch_by_padded_size(
|
def minibatch_by_padded_size(
|
||||||
docs: Iterator["Doc"],
|
seqs: Iterable[ItemT],
|
||||||
size: Sizing,
|
size: Sizing,
|
||||||
buffer: int = 256,
|
buffer: int = 256,
|
||||||
discard_oversize: bool = False,
|
discard_oversize: bool = False,
|
||||||
get_length: Callable = len,
|
get_length: Callable = len,
|
||||||
) -> Iterator[Iterator["Doc"]]:
|
) -> Iterable[List[ItemT]]:
|
||||||
|
"""Minibatch a sequence by the size of padded batches that would result,
|
||||||
|
with sequences binned by length within a window.
|
||||||
|
|
||||||
|
The padded size is defined as the maximum length of sequences within the
|
||||||
|
batch multiplied by the number of sequences in the batch.
|
||||||
|
|
||||||
|
size (int or Iterable[int]): The largest padded size to batch sequences into.
|
||||||
|
buffer (int): The number of sequences to accumulate before sorting by length.
|
||||||
|
A larger buffer will result in more even sizing, but if the buffer is
|
||||||
|
very large, the iteration order will be less random, which can result
|
||||||
|
in suboptimal training.
|
||||||
|
discard_oversize (bool): Whether to discard sequences that are by themselves
|
||||||
|
longer than the largest padded batch size.
|
||||||
|
get_length (Callable): Function to get the length of a sequence item.
|
||||||
|
The `len` function is used by default.
|
||||||
|
"""
|
||||||
if isinstance(size, int):
|
if isinstance(size, int):
|
||||||
size_ = itertools.repeat(size)
|
size_ = itertools.repeat(size)
|
||||||
else:
|
else:
|
||||||
size_ = size
|
size_ = size
|
||||||
for outer_batch in minibatch(docs, size=buffer):
|
for outer_batch in minibatch(seqs, size=buffer):
|
||||||
outer_batch = list(outer_batch)
|
outer_batch = list(outer_batch)
|
||||||
target_size = next(size_)
|
target_size = next(size_)
|
||||||
for indices in _batch_by_length(outer_batch, target_size, get_length):
|
for indices in _batch_by_length(outer_batch, target_size, get_length):
|
||||||
|
@ -76,12 +123,20 @@ def minibatch_by_padded_size(
|
||||||
|
|
||||||
|
|
||||||
def minibatch_by_words(
|
def minibatch_by_words(
|
||||||
docs, size, tolerance=0.2, discard_oversize=False, get_length=len
|
seqs: Iterable[ItemT], size: Sizing, tolerance=0.2, discard_oversize=False, get_length=len
|
||||||
):
|
) -> Iterable[List[ItemT]]:
|
||||||
"""Create minibatches of roughly a given number of words. If any examples
|
"""Create minibatches of roughly a given number of words. If any examples
|
||||||
are longer than the specified batch length, they will appear in a batch by
|
are longer than the specified batch length, they will appear in a batch by
|
||||||
themselves, or be discarded if discard_oversize=True.
|
themselves, or be discarded if discard_oversize=True.
|
||||||
The argument 'docs' can be a list of strings, Docs or Examples.
|
|
||||||
|
seqs (Iterable[Sequence]): The sequences to minibatch.
|
||||||
|
size (int or Iterable[int]): The target number of words per batch.
|
||||||
|
Can be a single integer, or a sequence, allowing for variable batch sizes.
|
||||||
|
tolerance (float): The fraction of the size by which batches may exceed it (e.g. 0.2 allows 20% over).
|
||||||
|
discard_oversize (bool): Whether to discard sequences that by themselves
|
||||||
|
exceed the tolerated size.
|
||||||
|
get_length (Callable): Function to get the length of a sequence
|
||||||
|
item. The `len` function is used by default.
|
||||||
"""
|
"""
|
||||||
if isinstance(size, int):
|
if isinstance(size, int):
|
||||||
size_ = itertools.repeat(size)
|
size_ = itertools.repeat(size)
|
||||||
|
@ -95,20 +150,20 @@ def minibatch_by_words(
|
||||||
overflow = []
|
overflow = []
|
||||||
batch_size = 0
|
batch_size = 0
|
||||||
overflow_size = 0
|
overflow_size = 0
|
||||||
for doc in docs:
|
for seq in seqs:
|
||||||
n_words = get_length(doc)
|
n_words = get_length(seq)
|
||||||
# if the current example exceeds the maximum batch size, it is returned separately
|
# if the current example exceeds the maximum batch size, it is returned separately
|
||||||
# but only if discard_oversize=False.
|
# but only if discard_oversize=False.
|
||||||
if n_words > target_size + tol_size:
|
if n_words > target_size + tol_size:
|
||||||
if not discard_oversize:
|
if not discard_oversize:
|
||||||
yield [doc]
|
yield [seq]
|
||||||
# add the example to the current batch if there's no overflow yet and it still fits
|
# add the example to the current batch if there's no overflow yet and it still fits
|
||||||
elif overflow_size == 0 and (batch_size + n_words) <= target_size:
|
elif overflow_size == 0 and (batch_size + n_words) <= target_size:
|
||||||
batch.append(doc)
|
batch.append(seq)
|
||||||
batch_size += n_words
|
batch_size += n_words
|
||||||
# add the example to the overflow buffer if it fits in the tolerance margin
|
# add the example to the overflow buffer if it fits in the tolerance margin
|
||||||
elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
|
elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
|
||||||
overflow.append(doc)
|
overflow.append(seq)
|
||||||
overflow_size += n_words
|
overflow_size += n_words
|
||||||
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
||||||
else:
|
else:
|
||||||
|
@ -122,11 +177,11 @@ def minibatch_by_words(
|
||||||
overflow_size = 0
|
overflow_size = 0
|
||||||
# this example still fits
|
# this example still fits
|
||||||
if (batch_size + n_words) <= target_size:
|
if (batch_size + n_words) <= target_size:
|
||||||
batch.append(doc)
|
batch.append(seq)
|
||||||
batch_size += n_words
|
batch_size += n_words
|
||||||
# this example fits in overflow
|
# this example fits in overflow
|
||||||
elif (batch_size + n_words) <= (target_size + tol_size):
|
elif (batch_size + n_words) <= (target_size + tol_size):
|
||||||
overflow.append(doc)
|
overflow.append(seq)
|
||||||
overflow_size += n_words
|
overflow_size += n_words
|
||||||
# this example does not fit with the previous overflow: start another new batch
|
# this example does not fit with the previous overflow: start another new batch
|
||||||
else:
|
else:
|
||||||
|
@ -134,7 +189,7 @@ def minibatch_by_words(
|
||||||
yield batch
|
yield batch
|
||||||
target_size = next(size_)
|
target_size = next(size_)
|
||||||
tol_size = target_size * tolerance
|
tol_size = target_size * tolerance
|
||||||
batch = [doc]
|
batch = [seq]
|
||||||
batch_size = n_words
|
batch_size = n_words
|
||||||
batch.extend(overflow)
|
batch.extend(overflow)
|
||||||
if batch:
|
if batch:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user