diff --git a/spacy/training/batchers.py b/spacy/training/batchers.py
index 73678c7fc..9dd26d765 100644
--- a/spacy/training/batchers.py
+++ b/spacy/training/batchers.py
@@ -2,12 +2,13 @@ from typing import Union, Iterable, Sequence, TypeVar, List, Callable, Iterator
 from typing import Optional, Any
 from functools import partial
 import itertools
-from thinc.schedules import Schedule, constant as constant_schedule
+from thinc.schedules import Schedule
 
 from ..util import registry, minibatch
 
-Sizing = Union[Sequence[int], int, Schedule[int]]
+SizingSchedule = Union[Iterable[int], int, Schedule]
+Sizing = Union[Iterable[int], int]
 ItemT = TypeVar("ItemT")
 BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
 
@@ -15,7 +16,7 @@ BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
 @registry.batchers("spacy.batch_by_padded.v1")
 def configure_minibatch_by_padded_size(
     *,
-    size: Sizing,
+    size: SizingSchedule,
     buffer: int,
     discard_oversize: bool,
     get_length: Optional[Callable[[ItemT], int]] = None
@@ -25,7 +26,7 @@ def configure_minibatch_by_padded_size(
     The padded size is defined as the maximum length of sequences within the
     batch multiplied by the number of sequences in the batch.
 
-    size (int or Sequence[int]): The largest padded size to batch sequences into.
+    size (int or Iterable[int]): The largest padded size to batch sequences into.
         Can be a single integer, or a sequence, allowing for variable batch sizes.
     buffer (int): The number of sequences to accumulate before sorting by length.
         A larger buffer will result in more even sizing, but if the buffer is
@@ -40,7 +41,7 @@ def configure_minibatch_by_padded_size(
     optionals = {"get_length": get_length} if get_length is not None else {}
     return partial(
         minibatch_by_padded_size,
-        size=size,
+        size=_schedule_to_sizing(size),
         buffer=buffer,
         discard_oversize=discard_oversize,
         **optionals
@@ -50,14 +51,14 @@ def configure_minibatch_by_padded_size(
 @registry.batchers("spacy.batch_by_words.v1")
 def configure_minibatch_by_words(
     *,
-    size: Sizing,
+    size: SizingSchedule,
     tolerance: float,
     discard_oversize: bool,
     get_length: Optional[Callable[[ItemT], int]] = None
 ) -> BatcherT:
     """Create a batcher that uses the "minibatch by words" strategy.
 
-    size (int or Sequence[int]): The target number of words per batch.
+    size (int or Iterable[int]): The target number of words per batch.
         Can be a single integer, or a sequence, allowing for variable batch sizes.
     tolerance (float): What percentage of the size to allow batches to exceed.
     discard_oversize (bool): Whether to discard sequences that by themselves
@@ -68,7 +69,7 @@ def configure_minibatch_by_words(
     optionals = {"get_length": get_length} if get_length is not None else {}
     return partial(
         minibatch_by_words,
-        size=size,
+        size=_schedule_to_sizing(size),
         tolerance=tolerance,
         discard_oversize=discard_oversize,
         **optionals
@@ -77,15 +78,15 @@ def configure_minibatch_by_words(
 
 @registry.batchers("spacy.batch_by_sequence.v1")
 def configure_minibatch(
-    size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
+    size: SizingSchedule, get_length: Optional[Callable[[ItemT], int]] = None
 ) -> BatcherT:
     """Create a batcher that creates batches of the specified size.
 
-    size (int or Sequence[int]): The target number of items per batch.
+    size (int or Iterable[int]): The target number of items per batch.
         Can be a single integer, or a sequence, allowing for variable batch sizes.
     """
     optionals = {"get_length": get_length} if get_length is not None else {}
-    return partial(minibatch, size=size, **optionals)
+    return partial(minibatch, size=_schedule_to_sizing(size), **optionals)
 
 
 def minibatch_by_padded_size(
@@ -101,7 +102,7 @@ def minibatch_by_padded_size(
     The padded size is defined as the maximum length of sequences within the
     batch multiplied by the number of sequences in the batch.
 
-    size (int or Sequence[int]): The largest padded size to batch sequences into.
+    size (int or Iterable[int]): The largest padded size to batch sequences into.
     buffer (int): The number of sequences to accumulate before sorting by length.
         A larger buffer will result in more even sizing, but if the buffer is
         very large, the iteration order will be less random, which can result
@@ -112,13 +113,12 @@ def minibatch_by_padded_size(
         The `len` function is used by default.
     """
     if isinstance(size, int):
-        size_ = constant_schedule(size)
+        size_ = itertools.repeat(size)  # type: Iterator[int]
     else:
-        assert isinstance(size, Schedule)
-        size_ = size
-    for step, outer_batch in enumerate(minibatch(seqs, size=buffer)):
+        size_ = iter(size)
+    for outer_batch in minibatch(seqs, size=buffer):
         outer_batch = list(outer_batch)
-        target_size = size_(step)
+        target_size = next(size_)
         for indices in _batch_by_length(outer_batch, target_size, get_length):
             subbatch = [outer_batch[i] for i in indices]
             padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
@@ -149,12 +149,10 @@ def minibatch_by_words(
         item. The `len` function is used by default.
     """
     if isinstance(size, int):
-        size_ = constant_schedule(size)
+        size_ = itertools.repeat(size)  # type: Iterator[int]
     else:
-        assert isinstance(size, Schedule)
-        size_ = size
-    step = 0
-    target_size = size_(step)
+        size_ = iter(size)
+    target_size = next(size_)
     tol_size = target_size * tolerance
     batch = []
     overflow = []
@@ -179,8 +177,7 @@ def minibatch_by_words(
         else:
             if batch:
                 yield batch
-            step += 1
-            target_size = size_(step)
+            target_size = next(size_)
             tol_size = target_size * tolerance
             batch = overflow
             batch_size = overflow_size
@@ -198,8 +195,7 @@ def minibatch_by_words(
         else:
             if batch:
                 yield batch
-            step += 1
-            target_size = size_(step)
+            target_size = next(size_)
             tol_size = target_size * tolerance
             batch = [seq]
             batch_size = n_words
@@ -236,3 +232,9 @@ def _batch_by_length(
     batches = [list(sorted(batch)) for batch in batches]
     batches.reverse()
     return batches
+
+
+def _schedule_to_sizing(size: SizingSchedule) -> Sizing:
+    if isinstance(size, Schedule):
+        return size.to_generator()
+    return size