from typing import Union, Iterable, Sequence, TypeVar, List, Callable, Optional, Any
from functools import partial
import itertools

from ..util import registry, minibatch


Sizing = Union[Iterable[int], int]
ItemT = TypeVar("ItemT")
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
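
# `Sizing` accepts either a fixed int or an iterable of ints, so batch sizes
# can follow a schedule rather than stay constant. A `BatcherT` maps a stream
# of items to a stream of batches (lists) of those items.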


@registry.batchers("spacy.batch_by_padded.v1")
def configure_minibatch_by_padded_size(
    *,
    size: Sizing,
    buffer: int,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Create a batcher that uses the `batch_by_padded_size` strategy.

    The padded size is defined as the maximum length of sequences within the
    batch multiplied by the number of sequences in the batch.

    size (int or Iterable[int]): The largest padded size to batch sequences into.
        Can be a single integer, or a sequence, allowing for variable batch sizes.
    buffer (int): The number of sequences to accumulate before sorting by length.
        A larger buffer will result in more even sizing, but if the buffer is
        very large, the iteration order will be less random, which can result
        in suboptimal training.
    discard_oversize (bool): Whether to discard sequences that are by themselves
        longer than the largest padded batch size.
    get_length (Callable or None): Function to get the length of a sequence item.
        The `len` function is used by default.
    """
    # Avoid displacing optional values from the underlying function.
    optionals = {"get_length": get_length} if get_length is not None else {}
    return partial(
        minibatch_by_padded_size,
        size=size,
        buffer=buffer,
        discard_oversize=discard_oversize,
        **optionals
    )
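
# A minimal usage sketch (hypothetical values; `docs` is any iterable of items
# with a length). The configured batcher is itself a callable:
#
#     batcher = configure_minibatch_by_padded_size(
#         size=2000, buffer=256, discard_oversize=False
#     )
#     for batch in batcher(docs):
#         # max item length * len(batch) stays at or below 2000, except for
#         # single items that are too long on their own.
#         ...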


@registry.batchers("spacy.batch_by_words.v1")
def configure_minibatch_by_words(
    *,
    size: Sizing,
    tolerance: float,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Create a batcher that uses the "minibatch by words" strategy.

    size (int or Iterable[int]): The target number of words per batch.
        Can be a single integer, or a sequence, allowing for variable batch sizes.
    tolerance (float): How far a batch may exceed the target size, as a
        fraction of it (e.g. 0.2 allows batches up to 20% over `size`).
    discard_oversize (bool): Whether to discard sequences that by themselves
        exceed the tolerated size.
    get_length (Callable or None): Function to get the length of a sequence
        item. The `len` function is used by default.
    """
    optionals = {"get_length": get_length} if get_length is not None else {}
    return partial(
        minibatch_by_words,
        size=size,
        tolerance=tolerance,
        discard_oversize=discard_oversize,
        **optionals
    )
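
# A minimal usage sketch (hypothetical values): with size=1000 and
# tolerance=0.2, a batch targets 1000 words but may grow to 1200 before a new
# one is started:
#
#     batcher = configure_minibatch_by_words(
#         size=1000, tolerance=0.2, discard_oversize=False
#     )
#     batches = list(batcher(docs))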


@registry.batchers("spacy.batch_by_sequence.v1")
def configure_minibatch(
    size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Create a batcher that creates batches of the specified size.

    size (int or Iterable[int]): The target number of items per batch.
        Can be a single integer, or a sequence, allowing for variable batch sizes.
    get_length (Callable or None): Function to get the length of a sequence
        item. The `len` function is used by default.
    """
    optionals = {"get_length": get_length} if get_length is not None else {}
    return partial(minibatch, size=size, **optionals)
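
# A minimal usage sketch (hypothetical values): this is plain size-based
# batching, so ten items with size=4 yield batches of 4, 4 and 2 items:
#
#     batcher = configure_minibatch(size=4)
#     batches = list(batcher(range(10)))  # [[0..3], [4..7], [8, 9]]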


def minibatch_by_padded_size(
    seqs: Iterable[ItemT],
    size: Sizing,
    buffer: int = 256,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterable[List[ItemT]]:
    """Minibatch a sequence by the size of padded batches that would result,
    with sequences binned by length within a window.

    The padded size is defined as the maximum length of sequences within the
    batch multiplied by the number of sequences in the batch.

    seqs (Iterable[ItemT]): The sequences to minibatch.
    size (int or Iterable[int]): The largest padded size to batch sequences
        into. Can be a single integer, or a sequence, allowing for variable
        batch sizes.
    buffer (int): The number of sequences to accumulate before sorting by length.
        A larger buffer will result in more even sizing, but if the buffer is
        very large, the iteration order will be less random, which can result
        in suboptimal training.
    discard_oversize (bool): Whether to discard sequences that are by themselves
        longer than the largest padded batch size.
    get_length (Callable or None): Function to get the length of a sequence item.
        The `len` function is used by default.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # iter() so that next() works for plain sequences as well as
        # generators; iter() returns an existing iterator unchanged.
        size_ = iter(size)
    for outer_batch in minibatch(seqs, size=buffer):
        outer_batch = list(outer_batch)
        target_size = next(size_)
        for indices in _batch_by_length(outer_batch, target_size, get_length):
            subbatch = [outer_batch[i] for i in indices]
            # Use get_length, not len, so custom length functions are
            # respected in the padded-size computation.
            padded_size = max(get_length(seq) for seq in subbatch) * len(subbatch)
            if discard_oversize and padded_size >= target_size:
                pass
            else:
                yield subbatch
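
# A small illustration (hypothetical data): strings stand in for sequences, so
# with the default get_length=len the padded size is the longest string times
# the batch size:
#
#     seqs = ["a", "bb", "ccc", "dddd"] * 8
#     for batch in minibatch_by_padded_size(seqs, size=12, buffer=32):
#         assert max(len(s) for s in batch) * len(batch) <= 12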


def minibatch_by_words(
    seqs: Iterable[ItemT],
    size: Sizing,
    tolerance: float = 0.2,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterable[List[ItemT]]:
    """Create minibatches of roughly a given number of words. If any examples
    are longer than the specified batch length, they will appear in a batch by
    themselves, or be discarded if discard_oversize=True.

    seqs (Iterable[Sequence]): The sequences to minibatch.
    size (int or Iterable[int]): The target number of words per batch.
        Can be a single integer, or a sequence, allowing for variable batch sizes.
    tolerance (float): How far a batch may exceed the target size, as a
        fraction of it (e.g. 0.2 allows batches up to 20% over `size`).
    discard_oversize (bool): Whether to discard sequences that by themselves
        exceed the tolerated size.
    get_length (Callable or None): Function to get the length of a sequence
        item. The `len` function is used by default.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # iter() also accepts plain sequences, and returns iterators unchanged.
        size_ = iter(size)
    target_size = next(size_)
    tol_size = target_size * tolerance
    batch = []
    overflow = []
    batch_size = 0
    overflow_size = 0
    for seq in seqs:
        n_words = get_length(seq)
        # if the current example exceeds the maximum batch size, it is returned separately
        # but only if discard_oversize=False.
        if n_words > target_size + tol_size:
            if not discard_oversize:
                yield [seq]
        # add the example to the current batch if there's no overflow yet and it still fits
        elif overflow_size == 0 and (batch_size + n_words) <= target_size:
            batch.append(seq)
            batch_size += n_words
        # add the example to the overflow buffer if it fits in the tolerance margin
        elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
            overflow.append(seq)
            overflow_size += n_words
        # yield the previous batch and start a new one. The new one gets the overflow examples.
        else:
            if batch:
                yield batch
            target_size = next(size_)
            tol_size = target_size * tolerance
            batch = overflow
            batch_size = overflow_size
            overflow = []
            overflow_size = 0
            # this example still fits
            if (batch_size + n_words) <= target_size:
                batch.append(seq)
                batch_size += n_words
            # this example fits in overflow
            elif (batch_size + n_words) <= (target_size + tol_size):
                overflow.append(seq)
                overflow_size += n_words
            # this example does not fit with the previous overflow: start another new batch
            else:
                if batch:
                    yield batch
                target_size = next(size_)
                tol_size = target_size * tolerance
                batch = [seq]
                batch_size = n_words
    batch.extend(overflow)
    if batch:
        yield batch
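
# A small illustration (hypothetical data): with a target of 5 words and the
# default 20% tolerance, a batch may hold up to 6 words before it is closed,
# and overflow items seed the next batch:
#
#     seqs = [["w"] * n for n in (2, 3, 1, 4)]
#     batches = list(minibatch_by_words(seqs, size=5))
#     # word counts per batch: [2, 3] and then [1, 4]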


def _batch_by_length(
    seqs: Sequence[Any], max_words: int, get_length: Callable = len
) -> List[List[int]]:
    """Given a list of sequences, return a batched list of indices into the
    list, where the batches are grouped by length, in descending order.

    Batches may be at most max_words in padded size, defined as the length of
    the longest sequence in the batch multiplied by the number of sequences.
    """
    # Sort ascending by length; ties keep original position order.
    lengths_indices = [(get_length(seq), i) for i, seq in enumerate(seqs)]
    lengths_indices.sort()
    batches = []
    batch = []
    for length, i in lengths_indices:
        if not batch:
            batch.append(i)
        elif length * (len(batch) + 1) <= max_words:
            batch.append(i)
        else:
            batches.append(batch)
            batch = [i]
    if batch:
        batches.append(batch)
    # Check that every sequence was assigned to a batch
    assert sum(len(b) for b in batches) == len(seqs)
    # Restore the original order within each batch, then reverse so groups of
    # the longest sequences come first.
    batches = [sorted(batch) for batch in batches]
    batches.reverse()
    return batches
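
# Worked example (hypothetical numbers): with max_words=8 and sequence lengths
# [1, 4, 3, 2], the (length, index) pairs sort to [(1, 0), (2, 3), (3, 2),
# (4, 1)], greedy grouping yields [[0, 3], [2, 1]], and after per-batch sorting
# and reversal the result is [[1, 2], [0, 3]]: the longest group comes first,
# each with padded size (max length * batch size) at most 8.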