from typing import Union, Iterable, Sequence, TypeVar, List, Callable
from typing import Optional, Any
from functools import partial
import itertools

from ..util import registry, minibatch


Sizing = Union[Iterable[int], int]
ItemT = TypeVar("ItemT")
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]


@registry.batchers("spacy.batch_by_padded.v1")
def configure_minibatch_by_padded_size(
    *,
    size: Sizing,
    buffer: int,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Create a batcher that uses the `batch_by_padded_size` strategy.

    The padded size is defined as the maximum length of sequences within the
    batch multiplied by the number of sequences in the batch.

    size (int or Iterable[int]): The largest padded size to batch sequences into.
        Can be a single integer, or a sequence, allowing for variable batch sizes.
    buffer (int): The number of sequences to accumulate before sorting by length.
        A larger buffer will result in more even sizing, but if the buffer is
        very large, the iteration order will be less random, which can result
        in suboptimal training.
    discard_oversize (bool): Whether to discard sequences that are by themselves
        longer than the largest padded batch size.
    get_length (Callable or None): Function to get the length of a sequence item.
        The `len` function is used by default.
    """
    # Avoid displacing optional values from the underlying function.
    optionals = {"get_length": get_length} if get_length is not None else {}
    return partial(
        minibatch_by_padded_size,
        size=size,
        buffer=buffer,
        discard_oversize=discard_oversize,
        **optionals
    )
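

# Usage sketch (illustrative, not part of the original module): registered
# batchers are usually referenced from the training config rather than called
# directly. A config block along these lines would resolve to this factory;
# the parameter values are examples, not defaults.
#
#     [training.batcher]
#     @batchers = "spacy.batch_by_padded.v1"
#     size = 3000
#     buffer = 256
#     discard_oversize = true

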
@registry.batchers("spacy.batch_by_words.v1")
 | 
						|
def configure_minibatch_by_words(
 | 
						|
    *,
 | 
						|
    size: Sizing,
 | 
						|
    tolerance: float,
 | 
						|
    discard_oversize: bool,
 | 
						|
    get_length: Optional[Callable[[ItemT], int]] = None
 | 
						|
) -> BatcherT:
 | 
						|
    """Create a batcher that uses the "minibatch by words" strategy.
 | 
						|
 | 
						|
    size (int or Iterable[int]): The target number of words per batch.
 | 
						|
        Can be a single integer, or a sequence, allowing for variable batch sizes.
 | 
						|
    tolerance (float): What percentage of the size to allow batches to exceed.
 | 
						|
    discard_oversize (bool): Whether to discard sequences that by themselves
 | 
						|
        exceed the tolerated size.
 | 
						|
    get_length (Callable or None): Function to get the length of a sequence
 | 
						|
        item. The `len` function is used by default.
 | 
						|
    """
 | 
						|
    optionals = {"get_length": get_length} if get_length is not None else {}
 | 
						|
    return partial(
 | 
						|
        minibatch_by_words, size=size, discard_oversize=discard_oversize, **optionals
 | 
						|
    )
 | 
						|
 | 
						|
 | 
						|
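

# Usage sketch (illustrative, not part of the original module): with this
# strategy, `size` is commonly a schedule rather than a fixed integer, so the
# word budget grows over training. A config block might look like:
#
#     [training.batcher]
#     @batchers = "spacy.batch_by_words.v1"
#     discard_oversize = false
#     tolerance = 0.2
#
#     [training.batcher.size]
#     @schedules = "compounding.v1"
#     start = 100
#     stop = 1000
#     compound = 1.001

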
@registry.batchers("spacy.batch_by_sequence.v1")
 | 
						|
def configure_minibatch(
 | 
						|
    size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
 | 
						|
) -> BatcherT:
 | 
						|
    """Create a batcher that creates batches of the specified size.
 | 
						|
 | 
						|
    size (int or Iterable[int]): The target number of items per batch.
 | 
						|
        Can be a single integer, or a sequence, allowing for variable batch sizes.
 | 
						|
    """
 | 
						|
    optionals = {"get_length": get_length} if get_length is not None else {}
 | 
						|
    return partial(minibatch, size=size, **optionals)
 | 
						|
 | 
						|
 | 
						|
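

# Usage sketch (illustrative, not part of the original module): the factory
# returns a plain callable over any iterable of items.
#
#     >>> batcher = configure_minibatch(size=2)
#     >>> list(batcher(["a", "b", "c"]))
#     [['a', 'b'], ['c']]

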
def minibatch_by_padded_size(
    seqs: Iterable[ItemT],
    size: Sizing,
    buffer: int = 256,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterable[List[ItemT]]:
    """Minibatch a sequence by the size of padded batches that would result,
    with sequences binned by length within a window.

    The padded size is defined as the maximum length of sequences within the
    batch multiplied by the number of sequences in the batch.

    size (int or Iterable[int]): The largest padded size to batch sequences into.
        Can be a single integer, or a sequence, allowing for variable batch sizes.
    buffer (int): The number of sequences to accumulate before sorting by length.
        A larger buffer will result in more even sizing, but if the buffer is
        very large, the iteration order will be less random, which can result
        in suboptimal training.
    discard_oversize (bool): Whether to discard sequences that are by themselves
        longer than the largest padded batch size.
    get_length (Callable or None): Function to get the length of a sequence item.
        The `len` function is used by default.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        size_ = iter(size)
    for outer_batch in minibatch(seqs, size=buffer):
        outer_batch = list(outer_batch)
        target_size = next(size_)
        for indices in _batch_by_length(outer_batch, target_size, get_length):
            subbatch = [outer_batch[i] for i in indices]
            # Compute the padded size with the configured length function, so
            # custom length functions are respected here as well.
            padded_size = max(get_length(seq) for seq in subbatch) * len(subbatch)
            if discard_oversize and padded_size >= target_size:
                continue
            yield subbatch
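

# Worked sketch (illustrative, not part of the original module): with a
# target padded size of 8, four strings of lengths 1, 2, 3 and 4 are binned
# by length, and the longest bin is yielded first.
#
#     >>> list(minibatch_by_padded_size(["a", "bb", "ccc", "dddd"], size=8))
#     [['ccc', 'dddd'], ['a', 'bb']]

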
def minibatch_by_words(
    seqs: Iterable[ItemT],
    size: Sizing,
    tolerance: float = 0.2,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterable[List[ItemT]]:
    """Create minibatches of roughly a given number of words. If any examples
    are longer than the specified batch length, they will appear in a batch by
    themselves, or be discarded if discard_oversize=True.

    seqs (Iterable[Sequence]): The sequences to minibatch.
    size (int or Iterable[int]): The target number of words per batch.
        Can be a single integer, or a sequence, allowing for variable batch sizes.
    tolerance (float): What percentage of the size to allow batches to exceed.
    discard_oversize (bool): Whether to discard sequences that by themselves
        exceed the tolerated size.
    get_length (Callable or None): Function to get the length of a sequence
        item. The `len` function is used by default.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        size_ = iter(size)
    target_size = next(size_)
    tol_size = target_size * tolerance
    batch = []
    overflow = []
    batch_size = 0
    overflow_size = 0
    for seq in seqs:
        n_words = get_length(seq)
        # If the current example exceeds the maximum batch size, it is
        # returned separately, but only if discard_oversize=False.
        if n_words > target_size + tol_size:
            if not discard_oversize:
                yield [seq]
        # Add the example to the current batch if there's no overflow yet and
        # it still fits.
        elif overflow_size == 0 and (batch_size + n_words) <= target_size:
            batch.append(seq)
            batch_size += n_words
        # Add the example to the overflow buffer if it fits in the tolerance
        # margin.
        elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
            overflow.append(seq)
            overflow_size += n_words
        # Yield the previous batch and start a new one. The new one gets the
        # overflow examples.
        else:
            if batch:
                yield batch
            target_size = next(size_)
            tol_size = target_size * tolerance
            batch = overflow
            batch_size = overflow_size
            overflow = []
            overflow_size = 0
            # This example still fits.
            if (batch_size + n_words) <= target_size:
                batch.append(seq)
                batch_size += n_words
            # This example fits in overflow.
            elif (batch_size + n_words) <= (target_size + tol_size):
                overflow.append(seq)
                overflow_size += n_words
            # This example does not fit with the previous overflow: start
            # another new batch.
            else:
                if batch:
                    yield batch
                target_size = next(size_)
                tol_size = target_size * tolerance
                batch = [seq]
                batch_size = n_words
    batch.extend(overflow)
    if batch:
        yield batch
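

# Worked sketch (illustrative, not part of the original module): with
# size=10 and the default tolerance of 0.2, a batch may hold up to 12 words
# in total, so three 4-word sequences fit in a single batch (the third lands
# in the tolerance margin: 8 + 4 = 12 <= 10 + 2).
#
#     >>> docs = [["w"] * 4, ["w"] * 4, ["w"] * 4]
#     >>> [len(b) for b in minibatch_by_words(docs, size=10)]
#     [3]

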
def _batch_by_length(
    seqs: Sequence[Any], max_words: int, get_length: Callable = len
) -> List[List[Any]]:
    """Given a list of sequences, return a batched list of indices into the
    list, where the batches are grouped by length, in descending order.

    Batches may be at most max_words in size, defined as the maximum sequence
    length in the batch multiplied by the number of sequences in the batch.
    """
    # Sort by length, with the original position as the tiebreaker, so
    # equal-length sequences keep their relative order.
    lengths_indices = [(get_length(seq), i) for i, seq in enumerate(seqs)]
    lengths_indices.sort()
    batches = []
    batch = []
    for length, i in lengths_indices:
        if not batch:
            batch.append(i)
        elif length * (len(batch) + 1) <= max_words:
            batch.append(i)
        else:
            batches.append(batch)
            batch = [i]
    if batch:
        batches.append(batch)
    # Check that every sequence was assigned to exactly one batch.
    assert sum(len(b) for b in batches) == len(seqs)
    batches = [sorted(batch) for batch in batches]
    batches.reverse()
    return batches
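

# Worked sketch (illustrative, not part of the original module): indices are
# grouped so that max-length-in-batch * batch-size stays within max_words,
# with the longest group first and indices ascending within each group.
#
#     >>> _batch_by_length(["a", "bb", "ccc", "dddd"], max_words=8)
#     [[2, 3], [0, 1]]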