mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 13:41:21 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			231 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			231 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import Union, Iterable, Sequence, TypeVar, List, Callable
 | |
| from typing import Optional, Any
 | |
| from functools import partial
 | |
| import itertools
 | |
| 
 | |
| from ..util import registry, minibatch
 | |
| 
 | |
| 
 | |
# A batch size specification: either a single int or an iterable of ints.
Sizing = Union[Iterable[int], int]
# Type variable standing for the kind of item that is being batched.
ItemT = TypeVar("ItemT")
# Signature of a batcher: takes an iterable of items and produces an iterable
# of lists (batches) of those items.
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
 | |
| 
 | |
| 
 | |
@registry.batchers("spacy.batch_by_padded.v1")
def configure_minibatch_by_padded_size(
    *,
    size: Sizing,
    buffer: int,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher based on the `minibatch_by_padded_size` strategy.

    A batch's padded size is the length of its longest sequence multiplied by
    the number of sequences it contains.

    size (int or Iterable[int]): The largest padded size to batch sequences into.
        Can be a single integer, or a sequence, allowing for variable batch sizes.
    buffer (int): The number of sequences to accumulate before sorting by length.
        A larger buffer will result in more even sizing, but if the buffer is
        very large, the iteration order will be less random, which can result
        in suboptimal training.
    discard_oversize (bool): Whether to discard sequences that are by themselves
        longer than the largest padded batch size.
    get_length (Callable or None): Function to get the length of a sequence item.
        The `len` function is used by default.
    RETURNS (Callable): `minibatch_by_padded_size` with the settings above
        pre-bound as keyword arguments.
    """
    settings = {
        "size": size,
        "buffer": buffer,
        "discard_oversize": discard_oversize,
    }
    # Forward get_length only when it was supplied, so that the default of the
    # underlying function is not displaced by an explicit None.
    if get_length is not None:
        settings["get_length"] = get_length
    return partial(minibatch_by_padded_size, **settings)
 | |
| 
 | |
| 
 | |
@registry.batchers("spacy.batch_by_words.v1")
def configure_minibatch_by_words(
    *,
    size: Sizing,
    tolerance: float,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher based on the "minibatch by words" strategy.

    size (int or Iterable[int]): The target number of words per batch.
        Can be a single integer, or a sequence, allowing for variable batch sizes.
    tolerance (float): What percentage of the size to allow batches to exceed.
    discard_oversize (bool): Whether to discard sequences that by themselves
        exceed the tolerated size.
    get_length (Callable or None): Function to get the length of a sequence
        item. The `len` function is used by default.
    RETURNS (Callable): `minibatch_by_words` with the settings above pre-bound
        as keyword arguments.
    """
    settings = {
        "size": size,
        "tolerance": tolerance,
        "discard_oversize": discard_oversize,
    }
    # Forward get_length only when it was supplied, so that the default of the
    # underlying function is not displaced by an explicit None.
    if get_length is not None:
        settings["get_length"] = get_length
    return partial(minibatch_by_words, **settings)
 | |
| 
 | |
| 
 | |
@registry.batchers("spacy.batch_by_sequence.v1")
def configure_minibatch(
    size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher that creates batches of the specified size.

    size (int or Iterable[int]): The target number of items per batch.
        Can be a single integer, or a sequence, allowing for variable batch sizes.
    get_length (Callable or None): Passed on to `minibatch` as a keyword
        argument, but only when it is not None.
        NOTE(review): confirm that `minibatch` accepts a `get_length` keyword.
    RETURNS (Callable): `minibatch` with the settings above pre-bound as
        keyword arguments.
    """
    settings = {"size": size}
    # Forward get_length only when it was supplied, so that an explicit None
    # never reaches the underlying function.
    if get_length is not None:
        settings["get_length"] = get_length
    return partial(minibatch, **settings)
 | |
| 
 | |
| 
 | |
def minibatch_by_padded_size(
    seqs: Iterable[ItemT],
    size: Sizing,
    buffer: int = 256,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterable[List[ItemT]]:
    """Minibatch a sequence by the size of padded batches that would result,
    with sequences binned by length within a window.

    The padded size is defined as the maximum length of sequences within the
    batch multiplied by the number of sequences in the batch.

    seqs (Iterable): The sequences to minibatch.
    size (int or Iterable[int]): The largest padded size to batch sequences
        into. Can be a single integer, or any iterable of integers (list,
        tuple, generator, ...). One value is consumed per buffer of sequences.
    buffer (int): The number of sequences to accumulate before sorting by length.
        A larger buffer will result in more even sizing, but if the buffer is
        very large, the iteration order will be less random, which can result
        in suboptimal training.
    discard_oversize (bool): Whether to discard sequences that are by themselves
        longer than the largest padded batch size.
    get_length (Callable): Function to get the length of a sequence item.
        The `len` function is used by default.
    YIELDS (List): The batches, each a list of items taken from `seqs`.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # iter() turns lists/tuples into iterators and is a no-op for objects
        # that already are iterators, so next() below works for all of them.
        size_ = iter(size)
    for outer_batch in minibatch(seqs, size=buffer):
        outer_batch = list(outer_batch)
        target_size = next(size_)
        for indices in _batch_by_length(outer_batch, target_size, get_length):
            subbatch = [outer_batch[i] for i in indices]
            # Measure with the same get_length that was used for grouping, so
            # that the oversize check agrees with how the batch was formed.
            longest = max(get_length(seq) for seq in subbatch)
            padded_size = longest * len(subbatch)
            # Strictly greater: batches packed exactly up to target_size are
            # valid (the grouping allows length * n <= target_size). Only a
            # sequence that is too long on its own can exceed the target.
            if discard_oversize and padded_size > target_size:
                continue
            yield subbatch
 | |
| 
 | |
| 
 | |
def minibatch_by_words(
    seqs: Iterable[ItemT],
    size: Sizing,
    tolerance: float = 0.2,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterable[List[ItemT]]:
    """Create minibatches of roughly a given number of words. If any examples
    are longer than the specified batch length, they will appear in a batch by
    themselves, or be discarded if discard_oversize=True.

    seqs (Iterable[Sequence]): The sequences to minibatch.
    size (int or Iterable[int]): The target number of words per batch.
        Can be a single integer, or any iterable of integers (list, tuple,
        generator, ...), allowing for variable batch sizes. One value is
        consumed up front and one more each time a new batch is started, so an
        iterable has to provide enough values for the whole input.
    tolerance (float): Fraction of the target size by which a batch may exceed
        it (e.g. 0.2 allows batches of up to 120% of the target size).
    discard_oversize (bool): Whether to discard sequences that by themselves
        exceed the tolerated size.
    get_length (Callable): Function to get the length of a sequence item.
        The `len` function is used by default.
    YIELDS (List): The batches, each a list of items taken from `seqs`.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # iter() turns lists/tuples into iterators and is a no-op for objects
        # that already are iterators, so next() below works for all of them.
        size_ = iter(size)
    target_size = next(size_)
    tol_size = target_size * tolerance
    # `batch` holds items that fit within target_size; `overflow` holds items
    # that only fit within the extra tolerance margin on top of it.
    batch = []
    overflow = []
    batch_size = 0
    overflow_size = 0
    for seq in seqs:
        n_words = get_length(seq)
        # if the current example exceeds the maximum batch size, it is returned separately
        # but only if discard_oversize=False.
        if n_words > target_size + tol_size:
            if not discard_oversize:
                yield [seq]
        # add the example to the current batch if there's no overflow yet and it still fits
        elif overflow_size == 0 and (batch_size + n_words) <= target_size:
            batch.append(seq)
            batch_size += n_words
        # add the example to the overflow buffer if it fits in the tolerance margin
        elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
            overflow.append(seq)
            overflow_size += n_words
        # yield the previous batch and start a new one. The new one gets the overflow examples.
        else:
            if batch:
                yield batch
            target_size = next(size_)
            tol_size = target_size * tolerance
            batch = overflow
            batch_size = overflow_size
            overflow = []
            overflow_size = 0
            # this example still fits
            if (batch_size + n_words) <= target_size:
                batch.append(seq)
                batch_size += n_words
            # this example fits in overflow
            elif (batch_size + n_words) <= (target_size + tol_size):
                overflow.append(seq)
                overflow_size += n_words
            # this example does not fit with the previous overflow: start another new batch
            else:
                if batch:
                    yield batch
                target_size = next(size_)
                tol_size = target_size * tolerance
                batch = [seq]
                batch_size = n_words
    # Whatever is left over (including pending overflow) forms the last batch.
    batch.extend(overflow)
    if batch:
        yield batch
 | |
| 
 | |
| 
 | |
def _batch_by_length(
    seqs: Sequence[Any], max_words: int, get_length=len
) -> List[List[Any]]:
    """Group the positions of `seqs` into batches of similar length.

    Sequences are visited from shortest to longest. A position joins the
    current batch as long as its length multiplied by the resulting number of
    batch members stays within max_words; otherwise a new batch is started.
    The first member of a batch is always accepted, so a single sequence
    longer than max_words ends up in a batch of its own.

    RETURNS (List[List[int]]): Batches of indices into `seqs`. The batch with
        the longest sequences comes first, and the indices within each batch
        are in ascending order.
    """
    # Pair each length with its position. Sorting the pairs orders them by
    # length, with ties resolved by position.
    by_length = sorted((get_length(seq), pos) for pos, seq in enumerate(seqs))
    grouped = []
    current = []
    for seq_len, pos in by_length:
        # seq_len is the longest length seen so far, so it determines the
        # padded size the batch would have after adding this position.
        if current and seq_len * (len(current) + 1) > max_words:
            grouped.append(current)
            current = []
        current.append(pos)
    if current:
        grouped.append(current)
    # Sanity check: every position must be in exactly one batch.
    assert sum(len(group) for group in grouped) == len(seqs)
    return [sorted(group) for group in reversed(grouped)]
 |