Mirror of https://github.com/explosion/spaCy.git, synced 2025-10-31 07:57:35 +03:00.
Merge pull request #5895 from explosion/docs/batchers — "Draft docstrings for batchers".
This change is contained in commit fd20f84927.
				|  | @ -19,6 +19,22 @@ def configure_minibatch_by_padded_size( | ||||||
|     discard_oversize: bool, |     discard_oversize: bool, | ||||||
|     get_length: Optional[Callable[[ItemT], int]] = None |     get_length: Optional[Callable[[ItemT], int]] = None | ||||||
| ) -> BatcherT: | ) -> BatcherT: | ||||||
|  |     """Create a batcher that uses the `batch_by_padded_size` strategy. | ||||||
|  |      | ||||||
|  |     The padded size is defined as the maximum length of sequences within the | ||||||
|  |     batch multiplied by the number of sequences in the batch. | ||||||
|  | 
 | ||||||
|  |     size (int or Iterable[int]): The largest padded size to batch sequences into. | ||||||
|  |         Can be a single integer, or a sequence, allowing for variable batch sizes. | ||||||
|  |     buffer (int): The number of sequences to accumulate before sorting by length. | ||||||
|  |         A larger buffer will result in more even sizing, but if the buffer is | ||||||
|  |         very large, the iteration order will be less random, which can result | ||||||
|  |         in suboptimal training. | ||||||
|  |     discard_oversize (bool): Whether to discard sequences that are by themselves | ||||||
|  |         longer than the largest padded batch size. | ||||||
|  |     get_length (Callable or None): Function to get the length of a sequence item. | ||||||
|  |         The `len` function is used by default. | ||||||
|  |     """ | ||||||
|     # Avoid displacing optional values from the underlying function. |     # Avoid displacing optional values from the underlying function. | ||||||
|     optionals = {"get_length": get_length} if get_length is not None else {} |     optionals = {"get_length": get_length} if get_length is not None else {} | ||||||
|     return partial( |     return partial( | ||||||
|  | @ -38,6 +54,16 @@ def configure_minibatch_by_words( | ||||||
|     discard_oversize: bool, |     discard_oversize: bool, | ||||||
|     get_length: Optional[Callable[[ItemT], int]] = None |     get_length: Optional[Callable[[ItemT], int]] = None | ||||||
| ) -> BatcherT: | ) -> BatcherT: | ||||||
|  |     """Create a batcher that uses the "minibatch by words" strategy. | ||||||
|  | 
 | ||||||
|  |     size (int or Iterable[int]): The target number of words per batch. | ||||||
|  |         Can be a single integer, or a sequence, allowing for variable batch sizes. | ||||||
|  |     tolerance (float): What percentage of the size to allow batches to exceed. | ||||||
|  |     discard_oversize (bool): Whether to discard sequences that by themselves | ||||||
|  |         exceed the tolerated size. | ||||||
|  |     get_length (Callable or None): Function to get the length of a sequence | ||||||
|  |         item. The `len` function is used by default. | ||||||
|  |     """ | ||||||
|     optionals = {"get_length": get_length} if get_length is not None else {} |     optionals = {"get_length": get_length} if get_length is not None else {} | ||||||
|     return partial( |     return partial( | ||||||
|         minibatch_by_words, size=size, discard_oversize=discard_oversize, **optionals |         minibatch_by_words, size=size, discard_oversize=discard_oversize, **optionals | ||||||
|  | @ -48,22 +74,43 @@ def configure_minibatch_by_words( | ||||||
def configure_minibatch(
    size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Create a batcher that creates batches of the specified size.

    size (int or Iterable[int]): The target number of items per batch.
        Can be a single integer, or a sequence, allowing for variable batch sizes.
    get_length (Callable or None): Function to get the length of a sequence
        item. The `len` function is used by default.
    """
    # Avoid displacing optional values from the underlying function.
    optionals = {"get_length": get_length} if get_length is not None else {}
    return partial(minibatch, size=size, **optionals)
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def minibatch_by_padded_size( | def minibatch_by_padded_size( | ||||||
|     docs: Iterator["Doc"], |     seqs: Iterable[ItemT], | ||||||
|     size: Sizing, |     size: Sizing, | ||||||
|     buffer: int = 256, |     buffer: int = 256, | ||||||
|     discard_oversize: bool = False, |     discard_oversize: bool = False, | ||||||
|     get_length: Callable = len, |     get_length: Callable = len, | ||||||
| ) -> Iterator[Iterator["Doc"]]: | ) -> Iterable[List[ItemT]]: | ||||||
|  |     """Minibatch a sequence by the size of padded batches that would result, | ||||||
|  |     with sequences binned by length within a window. | ||||||
|  |      | ||||||
|  |     The padded size is defined as the maximum length of sequences within the | ||||||
|  |     batch multiplied by the number of sequences in the batch. | ||||||
|  | 
 | ||||||
|  |     size (int): The largest padded size to batch sequences into. | ||||||
|  |     buffer (int): The number of sequences to accumulate before sorting by length. | ||||||
|  |         A larger buffer will result in more even sizing, but if the buffer is | ||||||
|  |         very large, the iteration order will be less random, which can result | ||||||
|  |         in suboptimal training. | ||||||
|  |     discard_oversize (bool): Whether to discard sequences that are by themselves | ||||||
|  |         longer than the largest padded batch size. | ||||||
|  |     get_length (Callable or None): Function to get the length of a sequence item. | ||||||
|  |         The `len` function is used by default. | ||||||
|  |     """ | ||||||
|     if isinstance(size, int): |     if isinstance(size, int): | ||||||
|         size_ = itertools.repeat(size) |         size_ = itertools.repeat(size) | ||||||
|     else: |     else: | ||||||
|         size_ = size |         size_ = size | ||||||
|     for outer_batch in minibatch(docs, size=buffer): |     for outer_batch in minibatch(seqs, size=buffer): | ||||||
|         outer_batch = list(outer_batch) |         outer_batch = list(outer_batch) | ||||||
|         target_size = next(size_) |         target_size = next(size_) | ||||||
|         for indices in _batch_by_length(outer_batch, target_size, get_length): |         for indices in _batch_by_length(outer_batch, target_size, get_length): | ||||||
|  | @ -76,12 +123,20 @@ def minibatch_by_padded_size( | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def minibatch_by_words( | def minibatch_by_words( | ||||||
|     docs, size, tolerance=0.2, discard_oversize=False, get_length=len |     seqs: Iterable[ItemT], size: Sizing, tolerance=0.2, discard_oversize=False, get_length=len | ||||||
| ): | ) -> Iterable[List[ItemT]]: | ||||||
|     """Create minibatches of roughly a given number of words. If any examples |     """Create minibatches of roughly a given number of words. If any examples | ||||||
|     are longer than the specified batch length, they will appear in a batch by |     are longer than the specified batch length, they will appear in a batch by | ||||||
|     themselves, or be discarded if discard_oversize=True. |     themselves, or be discarded if discard_oversize=True. | ||||||
|     The argument 'docs' can be a list of strings, Docs or Examples. | 
 | ||||||
|  |     seqs (Iterable[Sequence]): The sequences to minibatch. | ||||||
|  |     size (int or Iterable[int]): The target number of words per batch. | ||||||
|  |         Can be a single integer, or a sequence, allowing for variable batch sizes. | ||||||
|  |     tolerance (float): What percentage of the size to allow batches to exceed. | ||||||
|  |     discard_oversize (bool): Whether to discard sequences that by themselves | ||||||
|  |         exceed the tolerated size. | ||||||
|  |     get_length (Callable or None): Function to get the length of a sequence | ||||||
|  |         item. The `len` function is used by default. | ||||||
|     """ |     """ | ||||||
|     if isinstance(size, int): |     if isinstance(size, int): | ||||||
|         size_ = itertools.repeat(size) |         size_ = itertools.repeat(size) | ||||||
|  | @ -95,20 +150,20 @@ def minibatch_by_words( | ||||||
|     overflow = [] |     overflow = [] | ||||||
|     batch_size = 0 |     batch_size = 0 | ||||||
|     overflow_size = 0 |     overflow_size = 0 | ||||||
|     for doc in docs: |     for seq in seqs: | ||||||
|         n_words = get_length(doc) |         n_words = get_length(seq) | ||||||
|         # if the current example exceeds the maximum batch size, it is returned separately |         # if the current example exceeds the maximum batch size, it is returned separately | ||||||
|         # but only if discard_oversize=False. |         # but only if discard_oversize=False. | ||||||
|         if n_words > target_size + tol_size: |         if n_words > target_size + tol_size: | ||||||
|             if not discard_oversize: |             if not discard_oversize: | ||||||
|                 yield [doc] |                 yield [seq] | ||||||
|         # add the example to the current batch if there's no overflow yet and it still fits |         # add the example to the current batch if there's no overflow yet and it still fits | ||||||
|         elif overflow_size == 0 and (batch_size + n_words) <= target_size: |         elif overflow_size == 0 and (batch_size + n_words) <= target_size: | ||||||
|             batch.append(doc) |             batch.append(seq) | ||||||
|             batch_size += n_words |             batch_size += n_words | ||||||
|         # add the example to the overflow buffer if it fits in the tolerance margin |         # add the example to the overflow buffer if it fits in the tolerance margin | ||||||
|         elif (batch_size + overflow_size + n_words) <= (target_size + tol_size): |         elif (batch_size + overflow_size + n_words) <= (target_size + tol_size): | ||||||
|             overflow.append(doc) |             overflow.append(seq) | ||||||
|             overflow_size += n_words |             overflow_size += n_words | ||||||
|         # yield the previous batch and start a new one. The new one gets the overflow examples. |         # yield the previous batch and start a new one. The new one gets the overflow examples. | ||||||
|         else: |         else: | ||||||
|  | @ -122,11 +177,11 @@ def minibatch_by_words( | ||||||
|             overflow_size = 0 |             overflow_size = 0 | ||||||
|             # this example still fits |             # this example still fits | ||||||
|             if (batch_size + n_words) <= target_size: |             if (batch_size + n_words) <= target_size: | ||||||
|                 batch.append(doc) |                 batch.append(seq) | ||||||
|                 batch_size += n_words |                 batch_size += n_words | ||||||
|             # this example fits in overflow |             # this example fits in overflow | ||||||
|             elif (batch_size + n_words) <= (target_size + tol_size): |             elif (batch_size + n_words) <= (target_size + tol_size): | ||||||
|                 overflow.append(doc) |                 overflow.append(seq) | ||||||
|                 overflow_size += n_words |                 overflow_size += n_words | ||||||
|             # this example does not fit with the previous overflow: start another new batch |             # this example does not fit with the previous overflow: start another new batch | ||||||
|             else: |             else: | ||||||
|  | @ -134,7 +189,7 @@ def minibatch_by_words( | ||||||
|                     yield batch |                     yield batch | ||||||
|                 target_size = next(size_) |                 target_size = next(size_) | ||||||
|                 tol_size = target_size * tolerance |                 tol_size = target_size * tolerance | ||||||
|                 batch = [doc] |                 batch = [seq] | ||||||
|                 batch_size = n_words |                 batch_size = n_words | ||||||
|     batch.extend(overflow) |     batch.extend(overflow) | ||||||
|     if batch: |     if batch: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user