mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Merge remote-tracking branch 'upstream/develop' into feature/pretrain-config
This commit is contained in:
		
						commit
						6504b7f161
					
				
							
								
								
									
										59
									
								
								spacy/tests/test_util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								spacy/tests/test_util.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,59 @@
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					from spacy.gold import Example
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .util import get_random_doc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from spacy.util import minibatch_by_words
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
					    "doc_sizes, expected_batches",
					    [
					        ([400, 400, 199], [3]),
					        ([400, 400, 199, 3], [4]),
					        ([400, 400, 199, 3, 200], [3, 2]),
					        ([400, 400, 199, 3, 1], [5]),
					        ([400, 400, 199, 3, 1, 1500], [5]),    # 1500 will be discarded
					        ([400, 400, 199, 3, 1, 200], [3, 3]),
					        ([400, 400, 199, 3, 1, 999], [3, 3]),
					        ([400, 400, 199, 3, 1, 999, 999], [3, 2, 1, 1]),
					        ([1, 2, 999], [3]),
					        ([1, 2, 999, 1], [4]),
					        ([1, 200, 999, 1], [2, 2]),
					        ([1, 999, 200, 1], [2, 2]),
					    ],
					)
					def test_util_minibatch(doc_sizes, expected_batches):
					    """Check the batch lengths produced when oversized docs are discarded.

					    Each case pairs a list of document lengths (in words) with the expected
					    number of examples per yielded batch, using a target of 1000 words, a
					    tolerance of 0.2 and discard_oversize=True.
					    """
					    target_words = 1000
					    tolerance = 0.2
					    examples = [Example(doc=get_random_doc(n_words)) for n_words in doc_sizes]
					    batches = list(
					        minibatch_by_words(
					            examples=examples,
					            size=target_words,
					            tolerance=tolerance,
					            discard_oversize=True,
					        )
					    )
					    assert [len(batch) for batch in batches] == expected_batches

					    # Every batch must stay below the target size plus its tolerance margin.
					    word_limit = target_words + target_words * tolerance
					    for batch in batches:
					        n_batch_words = sum(len(example.doc) for example in batch)
					        assert n_batch_words < word_limit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
					    "doc_sizes, expected_batches",
					    [
					        ([400, 4000, 199], [1, 2]),
					        ([400, 400, 199, 3000, 200], [1, 4]),
					        ([400, 400, 199, 3, 1, 1500], [1, 5]),
					        ([400, 400, 199, 3000, 2000, 200, 200], [1, 1, 3, 2]),
					        ([1, 2, 9999], [1, 2]),
					        ([2000, 1, 2000, 1, 1, 1, 2000], [1, 1, 1, 4]),
					    ],
					)
					def test_util_minibatch_oversize(doc_sizes, expected_batches):
					    """Check that an oversized document comes back in a batch of its own.

					    Runs with discard_oversize=False, a target of 1000 words and a tolerance
					    of 0.2; only the number of examples per yielded batch is compared.
					    """
					    target_words = 1000
					    tolerance = 0.2
					    examples = [Example(doc=get_random_doc(n_words)) for n_words in doc_sizes]
					    batch_lengths = [
					        len(batch)
					        for batch in minibatch_by_words(
					            examples=examples,
					            size=target_words,
					            tolerance=tolerance,
					            discard_oversize=False,
					        )
					    ]
					    assert batch_lengths == expected_batches
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -92,6 +92,13 @@ def get_batch(batch_size):
 | 
				
			||||||
    return docs
 | 
					    return docs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_random_doc(n_words):
					    """Build a Doc with `n_words` tokens on a fresh, empty Vocab.

					    Despite the name, the content is deterministic: the token texts are the
					    decimal strings "0", "1", ..., str(n_words - 1).
					    """
					    # Numeric token texts make individual words easy to track in tests.
					    words = list(map(str, range(n_words)))
					    return Doc(Vocab(), words=words)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def apply_transition_sequence(parser, doc, sequence):
 | 
					def apply_transition_sequence(parser, doc, sequence):
 | 
				
			||||||
    """Perform a series of pre-specified transitions, to put the parser in a
 | 
					    """Perform a series of pre-specified transitions, to put the parser in a
 | 
				
			||||||
    desired state."""
 | 
					    desired state."""
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -656,41 +656,73 @@ def decaying(start, stop, decay):
 | 
				
			||||||
        curr -= decay
 | 
					        curr -= decay
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
    """Create minibatches of roughly a given number of words. If any examples
    are longer than the specified batch length, they will appear in a batch by
    themselves, or be discarded if discard_oversize=True.

    examples: Iterable of objects exposing a `.doc` attribute.
    size: Target number of words per batch. Either an int (used for every
        batch), a list of ints, or an iterator of ints. A new value is drawn
        each time a batch is rolled over; once a finite schedule runs out,
        the last value keeps being used. Must provide at least one value.
    count_words: Callable applied to `example.doc` to get its word count.
    tolerance: Fraction of the target size by which a batch may overshoot.
    discard_oversize: If True, examples larger than target + tolerance are
        dropped instead of being yielded in a batch of their own.
    YIELDS: Non-empty lists of examples.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    elif isinstance(size, list):
        size_ = iter(size)
    else:
        size_ = size

    target_size = next(size_)
    tol_size = target_size * tolerance
    batch = []
    overflow = []
    batch_size = 0
    overflow_size = 0

    for example in examples:
        n_words = count_words(example.doc)
        # if the current example exceeds the maximum batch size, it is returned separately
        # but only if discard_oversize=False.
        if n_words > target_size + tol_size:
            if not discard_oversize:
                yield [example]

        # add the example to the current batch if there's no overflow yet and it still fits
        elif overflow_size == 0 and (batch_size + n_words) <= target_size:
            batch.append(example)
            batch_size += n_words

        # add the example to the overflow buffer if it fits in the tolerance margin
        elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
            overflow.append(example)
            overflow_size += n_words

        # yield the previous batch and start a new one. The new one gets the overflow examples.
        else:
            # The batch can be empty here if the first example(s) went straight
            # into the overflow buffer: never emit an empty batch.
            if batch:
                yield batch
            # A bare next() on an exhausted schedule would raise StopIteration
            # inside this generator, which Python 3.7+ turns into a
            # RuntimeError. Fall back to the last known target size instead.
            target_size = next(size_, target_size)
            tol_size = target_size * tolerance
            batch = overflow
            batch_size = overflow_size
            overflow = []
            overflow_size = 0

            # this example still fits
            if (batch_size + n_words) <= target_size:
                batch.append(example)
                batch_size += n_words

            # this example fits in overflow
            elif (batch_size + n_words) <= (target_size + tol_size):
                overflow.append(example)
                overflow_size += n_words

            # this example does not fit with the previous overflow: start another new batch
            else:
                if batch:
                    yield batch
                target_size = next(size_, target_size)
                tol_size = target_size * tolerance
                batch = [example]
                batch_size = n_words

    # yield the final batch. Merge the overflow in *before* testing for
    # emptiness, so that examples parked in the overflow buffer while the
    # batch itself was empty are not silently lost.
    batch.extend(overflow)
    if batch:
        yield batch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user