From f5c4e0b751c9757c6ac9cc69c3b44035c26d1dda Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 7 Aug 2020 18:51:02 +0200 Subject: [PATCH] Add docstrings for batchers --- spacy/gold/batchers.py | 83 +++++++++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 14 deletions(-) diff --git a/spacy/gold/batchers.py b/spacy/gold/batchers.py index 57c6b4b3a..c15b88502 100644 --- a/spacy/gold/batchers.py +++ b/spacy/gold/batchers.py @@ -19,6 +19,22 @@ def configure_minibatch_by_padded_size( discard_oversize: bool, get_length: Optional[Callable[[ItemT], int]] = None ) -> BatcherT: + """Create a batcher that uses the `batch_by_padded_size` strategy. + + The padded size is defined as the maximum length of sequences within the + batch multiplied by the number of sequences in the batch. + + size (int or Iterable[int]): The largest padded size to batch sequences into. + Can be a single integer, or a sequence, allowing for variable batch sizes. + buffer (int): The number of sequences to accumulate before sorting by length. + A larger buffer will result in more even sizing, but if the buffer is + very large, the iteration order will be less random, which can result + in suboptimal training. + discard_oversize (bool): Whether to discard sequences that are by themselves + longer than the largest padded batch size. + get_length (Callable or None): Function to get the length of a sequence item. + The `len` function is used by default. + """ # Avoid displacing optional values from the underlying function. optionals = {"get_length": get_length} if get_length is not None else {} return partial( @@ -38,6 +54,16 @@ def configure_minibatch_by_words( discard_oversize: bool, get_length: Optional[Callable[[ItemT], int]] = None ) -> BatcherT: + """Create a batcher that uses the "minibatch by words" strategy. + + size (int or Iterable[int]): The target number of words per batch. + Can be a single integer, or a sequence, allowing for variable batch sizes. + tolerance (float): What percentage of the size to allow batches to exceed. + discard_oversize (bool): Whether to discard sequences that by themselves + exceed the tolerated size. + get_length (Callable or None): Function to get the length of a sequence + item. The `len` function is used by default. + """ optionals = {"get_length": get_length} if get_length is not None else {} return partial( minibatch_by_words, size=size, discard_oversize=discard_oversize, **optionals @@ -48,22 +74,43 @@ def configure_minibatch_by_words( def configure_minibatch( size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None ) -> BatcherT: + """Create a batcher that creates batches of the specified size. + + size (int or Iterable[int]): The target number of items per batch. + Can be a single integer, or a sequence, allowing for variable batch sizes. + """ optionals = {"get_length": get_length} if get_length is not None else {} return partial(minibatch, size=size, **optionals) def minibatch_by_padded_size( - docs: Iterator["Doc"], + seqs: Iterable[ItemT], size: Sizing, buffer: int = 256, discard_oversize: bool = False, get_length: Callable = len, -) -> Iterator[Iterator["Doc"]]: +) -> Iterable[List[ItemT]]: + """Minibatch a sequence by the size of padded batches that would result, + with sequences binned by length within a window. + + The padded size is defined as the maximum length of sequences within the + batch multiplied by the number of sequences in the batch. + + size (int): The largest padded size to batch sequences into. + buffer (int): The number of sequences to accumulate before sorting by length. + A larger buffer will result in more even sizing, but if the buffer is + very large, the iteration order will be less random, which can result + in suboptimal training. + discard_oversize (bool): Whether to discard sequences that are by themselves + longer than the largest padded batch size. + get_length (Callable or None): Function to get the length of a sequence item. + The `len` function is used by default. + """ if isinstance(size, int): size_ = itertools.repeat(size) else: size_ = size - for outer_batch in minibatch(docs, size=buffer): + for outer_batch in minibatch(seqs, size=buffer): outer_batch = list(outer_batch) target_size = next(size_) for indices in _batch_by_length(outer_batch, target_size, get_length): @@ -76,12 +123,20 @@ def minibatch_by_padded_size( def minibatch_by_words( - docs, size, tolerance=0.2, discard_oversize=False, get_length=len -): + seqs: Iterable[ItemT], size: Sizing, tolerance=0.2, discard_oversize=False, get_length=len +) -> Iterable[List[ItemT]]: """Create minibatches of roughly a given number of words. If any examples are longer than the specified batch length, they will appear in a batch by themselves, or be discarded if discard_oversize=True. - The argument 'docs' can be a list of strings, Docs or Examples. + + seqs (Iterable[Sequence]): The sequences to minibatch. + size (int or Iterable[int]): The target number of words per batch. + Can be a single integer, or a sequence, allowing for variable batch sizes. + tolerance (float): What percentage of the size to allow batches to exceed. + discard_oversize (bool): Whether to discard sequences that by themselves + exceed the tolerated size. + get_length (Callable or None): Function to get the length of a sequence + item. The `len` function is used by default. """ if isinstance(size, int): size_ = itertools.repeat(size) @@ -95,20 +150,20 @@ def minibatch_by_words( overflow = [] batch_size = 0 overflow_size = 0 - for doc in docs: - n_words = get_length(doc) + for seq in seqs: + n_words = get_length(seq) # if the current example exceeds the maximum batch size, it is returned separately # but only if discard_oversize=False. if n_words > target_size + tol_size: if not discard_oversize: - yield [doc] + yield [seq] # add the example to the current batch if there's no overflow yet and it still fits elif overflow_size == 0 and (batch_size + n_words) <= target_size: - batch.append(doc) + batch.append(seq) batch_size += n_words # add the example to the overflow buffer if it fits in the tolerance margin elif (batch_size + overflow_size + n_words) <= (target_size + tol_size): - overflow.append(doc) + overflow.append(seq) overflow_size += n_words # yield the previous batch and start a new one. The new one gets the overflow examples. else: @@ -122,11 +177,11 @@ def minibatch_by_words( overflow_size = 0 # this example still fits if (batch_size + n_words) <= target_size: - batch.append(doc) + batch.append(seq) batch_size += n_words # this example fits in overflow elif (batch_size + n_words) <= (target_size + tol_size): - overflow.append(doc) + overflow.append(seq) overflow_size += n_words # this example does not fit with the previous overflow: start another new batch else: @@ -134,7 +189,7 @@ def minibatch_by_words( yield batch target_size = next(size_) tol_size = target_size * tolerance - batch = [doc] + batch = [seq] batch_size = n_words batch.extend(overflow) if batch: