spaCy/spacy/training/batchers.py

231 lines
8.8 KiB
Python
Raw Normal View History

2020-08-09 23:36:23 +03:00
from typing import Union, Iterable, Sequence, TypeVar, List, Callable
from typing import Optional, Any
from functools import partial
import itertools
from ..util import registry, minibatch
# A batch-size specification: either a single integer or an iterable of
# integers (allowing the size to vary from batch to batch).
Sizing = Union[Iterable[int], int]
# Type variable standing for the kind of item being batched.
ItemT = TypeVar("ItemT")
# A batcher is a callable that takes an iterable of items and returns an
# iterable of lists (batches) of those items.
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
2020-09-03 18:30:41 +03:00
@registry.batchers("spacy.batch_by_padded.v1")
def configure_minibatch_by_padded_size(
    *,
    size: Sizing,
    buffer: int,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher based on the `minibatch_by_padded_size` strategy.

    The padded size of a batch is the length of its longest sequence
    multiplied by the number of sequences it contains.

    size (int or Iterable[int]): The largest padded size to batch sequences
        into. Can be a single integer, or a sequence, allowing for variable
        batch sizes.
    buffer (int): The number of sequences to accumulate before sorting by
        length. A larger buffer will result in more even sizing, but if the
        buffer is very large, the iteration order will be less random, which
        can result in suboptimal training.
    discard_oversize (bool): Whether to discard sequences that are by
        themselves longer than the largest padded batch size.
    get_length (Callable or None): Function to get the length of a sequence
        item. The `len` function is used by default.
    RETURNS (Callable): `minibatch_by_padded_size` with the settings above
        bound as keyword arguments.
    """
    settings = {
        "size": size,
        "buffer": buffer,
        "discard_oversize": discard_oversize,
    }
    # Only forward get_length when one was supplied, so that the default of
    # the underlying function is not displaced by an explicit None.
    if get_length is not None:
        settings["get_length"] = get_length
    return partial(minibatch_by_padded_size, **settings)
2020-09-03 18:30:41 +03:00
@registry.batchers("spacy.batch_by_words.v1")
def configure_minibatch_by_words(
    *,
    size: Sizing,
    tolerance: float,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Create a batcher that uses the "minibatch by words" strategy.

    size (int or Iterable[int]): The target number of words per batch.
        Can be a single integer, or a sequence, allowing for variable batch
        sizes.
    tolerance (float): What percentage of the size to allow batches to exceed.
    discard_oversize (bool): Whether to discard sequences that by themselves
        exceed the tolerated size.
    get_length (Callable or None): Function to get the length of a sequence
        item. The `len` function is used by default.
    RETURNS (Callable): `minibatch_by_words` with the settings above bound as
        keyword arguments.
    """
    # Avoid displacing optional values from the underlying function.
    optionals = {"get_length": get_length} if get_length is not None else {}
    return partial(
        minibatch_by_words,
        size=size,
        # The configured tolerance must be forwarded explicitly; otherwise
        # the underlying function silently falls back to its own default.
        tolerance=tolerance,
        discard_oversize=discard_oversize,
        **optionals
    )
2020-09-03 18:30:41 +03:00
@registry.batchers("spacy.batch_by_sequence.v1")
def configure_minibatch(
    size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher that produces batches with a given number of items.

    size (int or Iterable[int]): The target number of items per batch.
        Can be a single integer, or a sequence, allowing for variable batch
        sizes.
    get_length (Callable or None): Forwarded to `minibatch` as a keyword
        argument only when it is not None.
    RETURNS (Callable): `minibatch` with the settings above bound as keyword
        arguments.
    """
    if get_length is None:
        return partial(minibatch, size=size)
    # NOTE(review): `minibatch` is defined in ..util and not visible here;
    # confirm that it actually accepts a `get_length` keyword.
    return partial(minibatch, size=size, get_length=get_length)
def minibatch_by_padded_size(
    seqs: Iterable[ItemT],
    size: Sizing,
    buffer: int = 256,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterable[List[ItemT]]:
    """Minibatch a sequence by the size of padded batches that would result,
    with sequences binned by length within a window.

    The padded size is defined as the maximum length of sequences within the
    batch multiplied by the number of sequences in the batch.

    seqs (Iterable): The sequences to minibatch.
    size (int or Iterable[int]): The largest padded size to batch sequences
        into. If an iterable is given, one value is drawn from it for every
        buffer-sized window of sequences.
    buffer (int): The number of sequences to accumulate before sorting by
        length. A larger buffer will result in more even sizing, but if the
        buffer is very large, the iteration order will be less random, which
        can result in suboptimal training.
    discard_oversize (bool): Whether to discard sequences that are by
        themselves longer than the largest padded batch size.
    get_length (Callable): Function to get the length of a sequence item.
        The `len` function is used by default.
    YIELDS (List): Batches of items taken from `seqs`.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # iter() makes re-iterable containers (lists, tuples, ...) usable
        # with next(); an existing iterator is returned unchanged.
        size_ = iter(size)
    for outer_batch in minibatch(seqs, size=buffer):
        outer_batch = list(outer_batch)
        target_size = next(size_)
        for indices in _batch_by_length(outer_batch, target_size, get_length):
            subbatch = [outer_batch[i] for i in indices]
            # Measure with the same length function that was used to form
            # the sub-batches, so both steps agree on what "length" means.
            longest = max(get_length(seq) for seq in subbatch)
            padded_size = longest * len(subbatch)
            # Batches that exactly fill the budget are valid; only those
            # strictly exceeding it (a lone oversize sequence) are dropped.
            if discard_oversize and padded_size > target_size:
                continue
            yield subbatch
def minibatch_by_words(
    seqs: Iterable[ItemT],
    size: Sizing,
    tolerance: float = 0.2,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterable[List[ItemT]]:
    """Create minibatches of roughly a given number of words. If any examples
    are longer than the specified batch length, they will appear in a batch by
    themselves, or be discarded if discard_oversize=True.

    seqs (Iterable[Sequence]): The sequences to minibatch.
    size (int or Iterable[int]): The target number of words per batch.
        Can be a single integer, or any iterable of integers (list, tuple,
        generator, ...), allowing for variable batch sizes. A new value is
        drawn each time a batch is closed.
    tolerance (float): What percentage of the size to allow batches to exceed.
    discard_oversize (bool): Whether to discard sequences that by themselves
        exceed the tolerated size.
    get_length (Callable): Function to get the length of a sequence item.
        The `len` function is used by default.
    YIELDS (List): Batches of items taken from `seqs`.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # iter() accepts every iterable (list, tuple, range, generator) and
        # returns an existing iterator unchanged, so next() always works.
        size_ = iter(size)
    target_size = next(size_)
    tol_size = target_size * tolerance
    batch = []
    overflow = []
    batch_size = 0
    overflow_size = 0
    for seq in seqs:
        n_words = get_length(seq)
        # If the current example exceeds the maximum batch size, it is
        # returned separately, but only if discard_oversize=False.
        if n_words > target_size + tol_size:
            if not discard_oversize:
                yield [seq]
        # Add the example to the current batch if there's no overflow yet
        # and it still fits.
        elif overflow_size == 0 and (batch_size + n_words) <= target_size:
            batch.append(seq)
            batch_size += n_words
        # Add the example to the overflow buffer if it fits in the
        # tolerance margin.
        elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
            overflow.append(seq)
            overflow_size += n_words
        # Yield the previous batch and start a new one. The new one gets
        # the overflow examples.
        else:
            if batch:
                yield batch
            target_size = next(size_)
            tol_size = target_size * tolerance
            batch = overflow
            batch_size = overflow_size
            overflow = []
            overflow_size = 0
            # This example still fits.
            if (batch_size + n_words) <= target_size:
                batch.append(seq)
                batch_size += n_words
            # This example fits in overflow.
            elif (batch_size + n_words) <= (target_size + tol_size):
                overflow.append(seq)
                overflow_size += n_words
            # This example does not fit with the previous overflow: start
            # another new batch.
            else:
                if batch:
                    yield batch
                target_size = next(size_)
                tol_size = target_size * tolerance
                batch = [seq]
                batch_size = n_words
    # Whatever is left in the overflow buffer goes out with the last batch.
    batch.extend(overflow)
    if batch:
        yield batch
def _batch_by_length(
seqs: Sequence[Any], max_words: int, get_length=len
) -> List[List[Any]]:
"""Given a list of sequences, return a batched list of indices into the
list, where the batches are grouped by length, in descending order.
Batches may be at most max_words in size, defined as max sequence length * size.
"""
# Use negative index so we can get sort by position ascending.
lengths_indices = [(get_length(seq), i) for i, seq in enumerate(seqs)]
lengths_indices.sort()
batches = []
batch = []
for length, i in lengths_indices:
if not batch:
batch.append(i)
elif length * (len(batch) + 1) <= max_words:
batch.append(i)
else:
batches.append(batch)
batch = [i]
if batch:
batches.append(batch)
# Check lengths match
assert sum(len(b) for b in batches) == len(seqs)
batches = [list(sorted(batch)) for batch in batches]
batches.reverse()
return batches