from typing import Union, Iterator, Iterable, Sequence, TypeVar, List, Callable
from typing import Optional, Any
from functools import partial
import itertools

from ..util import registry, minibatch


# A batch size is either one fixed integer or an iterable producing the size
# to use for each successive batch (e.g. a schedule).
Sizing = Union[Iterable[int], int]
ItemT = TypeVar("ItemT")
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]


@registry.batchers("batch_by_padded.v1")
def configure_minibatch_by_padded_size(
    *,
    size: Sizing,
    buffer: int,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher that groups items by padded size.

    Exposed through the `registry.batchers` decorator under the name
    "batch_by_padded.v1".

    size: Target padded size per batch (fixed int or iterable of ints).
    buffer: Number of items pulled in at once before sorting them by length.
    discard_oversize: Whether to drop batches exceeding the target size.
    get_length: Optional function returning an item's length. When None,
        the default of `minibatch_by_padded_size` is used.
    RETURNS: `minibatch_by_padded_size` with the settings bound via `partial`.
    """
    # Avoid displacing optional values from the underlying function.
    optionals = {"get_length": get_length} if get_length is not None else {}
    return partial(
        minibatch_by_padded_size,
        size=size,
        buffer=buffer,
        discard_oversize=discard_oversize,
        **optionals
    )


@registry.batchers("batch_by_words.v1")
def configure_minibatch_by_words(
    *,
    size: Sizing,
    tolerance: float,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher that groups items by total word count.

    Exposed through the `registry.batchers` decorator under the name
    "batch_by_words.v1".

    size: Target number of words per batch (fixed int or iterable of ints).
    tolerance: Fraction of `size` by which a batch may overshoot the target.
    discard_oversize: Whether to drop single items that exceed the
        tolerated size on their own.
    get_length: Optional function returning an item's length. When None,
        the default of `minibatch_by_words` is used.
    RETURNS: `minibatch_by_words` with the settings bound via `partial`.
    """
    # Avoid displacing optional values from the underlying function.
    optionals = {"get_length": get_length} if get_length is not None else {}
    # `tolerance` has to be forwarded explicitly; otherwise the configured
    # value is silently ignored in favour of the function's default.
    return partial(
        minibatch_by_words,
        size=size,
        tolerance=tolerance,
        discard_oversize=discard_oversize,
        **optionals
    )


@registry.batchers("batch_by_sequence.v1")
def configure_minibatch(
    size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher around `minibatch` with a bound `size`.

    Exposed through the `registry.batchers` decorator under the name
    "batch_by_sequence.v1".

    size: Batch size (fixed int or iterable of ints).
    get_length: Optional length function. It is only forwarded when it is
        not None.
        NOTE(review): `minibatch` is defined outside this file; confirm that
        it accepts a `get_length` keyword argument.
    RETURNS: `minibatch` with the settings bound via `partial`.
    """
    optionals = {"get_length": get_length} if get_length is not None else {}
    return partial(minibatch, size=size, **optionals)


def minibatch_by_padded_size(
    docs: Iterator["Doc"],
    size: Sizing,
    buffer: int = 256,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterator[Iterator["Doc"]]:
    """Create minibatches whose padded size stays within a target.

    The padded size of a batch is the length of its longest item multiplied
    by the number of items. Items are consumed in chunks of `buffer`, each
    chunk is partitioned by `_batch_by_length`, and the resulting sub-batches
    are yielded as lists.

    docs: The items to batch.
    size: Target padded size, either a fixed int or an iterable of ints. One
        value is consumed per buffered chunk. If the iterable runs out while
        items remain, `next()` raises StopIteration, which Python reports as
        a RuntimeError inside a generator.
    buffer: Number of items per buffered chunk.
    discard_oversize: If True, sub-batches whose padded size exceeds the
        target are skipped. Given how `_batch_by_length` builds batches,
        these are single items that are too long on their own.
    get_length: Function returning the length of an item.
    YIELDS: Lists of items.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # iter() makes lists/tuples usable with next(); iterators pass through.
        size_ = iter(size)
    for outer_batch in minibatch(docs, size=buffer):
        outer_batch = list(outer_batch)
        target_size = next(size_)
        for indices in _batch_by_length(outer_batch, target_size, get_length):
            subbatch = [outer_batch[i] for i in indices]
            # Measure with get_length so that the size check agrees with the
            # grouping done by _batch_by_length (and works for items that do
            # not implement __len__).
            longest = max(get_length(seq) for seq in subbatch)
            padded_size = longest * len(subbatch)
            # Strictly greater: a batch that exactly fills the budget is
            # valid (_batch_by_length allows `<= max_words`) and is kept.
            if discard_oversize and padded_size > target_size:
                continue
            yield subbatch


def minibatch_by_words(
    docs: Iterable[Any],
    size: Sizing,
    tolerance: float = 0.2,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterator[List[Any]]:
    """Create minibatches of roughly a given number of words.

    If any examples are longer than the specified batch length, they will
    appear in a batch by themselves, or be discarded if
    discard_oversize=True.

    docs: The items to batch. Can be a list of strings, Docs or Examples.
    size: Target number of words per batch, either a fixed int or an
        iterable of ints. A new value is consumed each time a batch is
        started. If the iterable runs out while items remain, `next()`
        raises StopIteration, which Python reports as a RuntimeError inside
        a generator.
    tolerance: Fraction of the target size by which a batch may overshoot.
    discard_oversize: Whether to drop items that on their own exceed
        target size + tolerance, instead of yielding them alone.
    get_length: Function returning the number of words in an item.
    YIELDS: Lists of items.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # Handles lists, tuples and other iterables; iterators pass through.
        size_ = iter(size)
    target_size = next(size_)
    tol_size = target_size * tolerance
    batch = []
    overflow = []
    batch_size = 0
    overflow_size = 0
    for doc in docs:
        n_words = get_length(doc)
        # if the current example exceeds the maximum batch size, it is
        # returned separately, but only if discard_oversize=False.
        if n_words > target_size + tol_size:
            if not discard_oversize:
                yield [doc]
        # add the example to the current batch if there's no overflow yet
        # and it still fits
        elif overflow_size == 0 and (batch_size + n_words) <= target_size:
            batch.append(doc)
            batch_size += n_words
        # add the example to the overflow buffer if it fits in the
        # tolerance margin
        elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
            overflow.append(doc)
            overflow_size += n_words
        # yield the previous batch and start a new one. The new one gets the
        # overflow examples.
        else:
            if batch:
                yield batch
            target_size = next(size_)
            tol_size = target_size * tolerance
            batch = overflow
            batch_size = overflow_size
            overflow = []
            overflow_size = 0
            # this example still fits
            if (batch_size + n_words) <= target_size:
                batch.append(doc)
                batch_size += n_words
            # this example fits in overflow
            elif (batch_size + n_words) <= (target_size + tol_size):
                overflow.append(doc)
                overflow_size += n_words
            # this example does not fit with the previous overflow: start
            # another new batch
            else:
                if batch:
                    yield batch
                target_size = next(size_)
                tol_size = target_size * tolerance
                batch = [doc]
                batch_size = n_words
    # Flush whatever is left, including examples parked in the overflow.
    batch.extend(overflow)
    if batch:
        yield batch


def _batch_by_length(
    seqs: Sequence[Any], max_words: int, get_length=len
) -> List[List[Any]]:
    """Given a list of sequences, return a batched list of indices into the
    list, where the batches are grouped by length, in descending order.

    Batches may be at most max_words in size, defined as max sequence
    length * size. A sequence that is longer than max_words on its own still
    gets a batch to itself.

    seqs: The sequences to group.
    max_words: Maximum padded size of a batch.
    get_length: Function returning the length of a sequence.
    RETURNS: Lists of indices into `seqs`. Indices within a batch are in
        ascending order; batches are ordered from longest to shortest
        sequences.
    """
    # Sort by (length, original position): shortest first, ties broken by
    # position in the input.
    lengths_indices = [(get_length(seq), i) for i, seq in enumerate(seqs)]
    lengths_indices.sort()
    batches = []
    batch = []
    for length, i in lengths_indices:
        if not batch:
            batch.append(i)
        # Lengths are ascending, so `length` is the longest in the batch and
        # length * (count + 1) is the padded size after adding this item.
        elif length * (len(batch) + 1) <= max_words:
            batch.append(i)
        else:
            batches.append(batch)
            batch = [i]
    if batch:
        batches.append(batch)
    # Check lengths match
    assert sum(len(b) for b in batches) == len(seqs)
    # Restore input order inside each batch, then put the longest first.
    batches = [sorted(batch) for batch in batches]
    batches.reverse()
    return batches