mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	extending algorithm to deal better with edge cases
This commit is contained in:
		
							parent
							
								
									f2e162fc60
								
							
						
					
					
						commit
						aa6271b16c
					
				| 
						 | 
				
			
			@ -11,13 +11,29 @@ from spacy.util import minibatch_by_words
 | 
			
		|||
    [
 | 
			
		||||
        ([400, 400, 199], [3]),
 | 
			
		||||
        ([400, 400, 199, 3], [4]),
 | 
			
		||||
        ([400, 400, 199, 3, 1], [5]),
 | 
			
		||||
        ([400, 400, 199, 3, 200], [3, 2]),
 | 
			
		||||
 | 
			
		||||
        ([400, 400, 199, 3, 1], [5]),
 | 
			
		||||
        ([400, 400, 199, 3, 1, 1500], [5]),    # 1500 will be discarded
 | 
			
		||||
        ([400, 400, 199, 3, 1, 200], [3, 3]),
 | 
			
		||||
        ([400, 400, 199, 3, 1, 999], [3, 3]),
 | 
			
		||||
        ([400, 400, 199, 3, 1, 999, 999], [3, 2, 1, 1]),
 | 
			
		||||
 | 
			
		||||
        ([1, 2, 999], [3]),
 | 
			
		||||
        ([1, 2, 999, 1], [4]),
 | 
			
		||||
        ([1, 200, 999, 1], [2, 2]),
 | 
			
		||||
        ([1, 999, 200, 1], [2, 2]),
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
def test_util_minibatch(doc_sizes, expected_batches):
 | 
			
		||||
    docs = [get_doc(doc_size) for doc_size in doc_sizes]
 | 
			
		||||
    examples = [Example(doc=doc) for doc in docs]
 | 
			
		||||
    batches = list(minibatch_by_words(examples=examples, size=1000))
 | 
			
		||||
    tol = 0.2
 | 
			
		||||
    batch_size = 1000
 | 
			
		||||
    batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=True))
 | 
			
		||||
    assert [len(batch) for batch in batches] == expected_batches
 | 
			
		||||
 | 
			
		||||
    max_size = batch_size + batch_size * tol
 | 
			
		||||
    for batch in batches:
 | 
			
		||||
        assert sum([len(example.doc) for example in batch]) < max_size
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -671,24 +671,24 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
 | 
			
		|||
    tol_size = target_size * tolerance
 | 
			
		||||
    batch = []
 | 
			
		||||
    overflow = []
 | 
			
		||||
    current_size = 0
 | 
			
		||||
    batch_size = 0
 | 
			
		||||
    overflow_size = 0
 | 
			
		||||
 | 
			
		||||
    for example in examples:
 | 
			
		||||
        n_words = count_words(example.doc)
 | 
			
		||||
        # if the current example exceeds the batch size, it is returned separately
 | 
			
		||||
        # if the current example exceeds the maximum batch size, it is returned separately
 | 
			
		||||
        # but only if discard_oversize=False.
 | 
			
		||||
        if n_words > target_size + tol_size:
 | 
			
		||||
            if not discard_oversize:
 | 
			
		||||
                yield [example]
 | 
			
		||||
 | 
			
		||||
        # add the example to the current batch if there's no overflow yet and it still fits
 | 
			
		||||
        elif overflow_size == 0 and (current_size + n_words) < target_size:
 | 
			
		||||
        elif overflow_size == 0 and (batch_size + n_words) <= target_size:
 | 
			
		||||
            batch.append(example)
 | 
			
		||||
            current_size += n_words
 | 
			
		||||
            batch_size += n_words
 | 
			
		||||
 | 
			
		||||
        # add the example to the overflow buffer if it fits in the tolerance margin
 | 
			
		||||
        elif (current_size + overflow_size + n_words) < (target_size + tol_size):
 | 
			
		||||
        elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
 | 
			
		||||
            overflow.append(example)
 | 
			
		||||
            overflow_size += n_words
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -697,14 +697,29 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
 | 
			
		|||
            yield batch
 | 
			
		||||
            target_size = next(size_)
 | 
			
		||||
            tol_size = target_size * tolerance
 | 
			
		||||
            # In theory it may happen that the current example + overflow examples now exceed the new
 | 
			
		||||
            # target_size, but that seems like an unimportant edge case if batch sizes are variable?
 | 
			
		||||
            batch = overflow
 | 
			
		||||
            batch.append(example)
 | 
			
		||||
            current_size = overflow_size + n_words
 | 
			
		||||
            batch_size = overflow_size
 | 
			
		||||
            overflow = []
 | 
			
		||||
            overflow_size = 0
 | 
			
		||||
 | 
			
		||||
            # this example still fits
 | 
			
		||||
            if (batch_size + n_words) <= target_size:
 | 
			
		||||
                batch.append(example)
 | 
			
		||||
                batch_size += n_words
 | 
			
		||||
 | 
			
		||||
            # this example fits in overflow
 | 
			
		||||
            elif (batch_size + n_words) <= (target_size + tol_size):
 | 
			
		||||
                overflow.append(example)
 | 
			
		||||
                overflow_size += n_words
 | 
			
		||||
 | 
			
		||||
            # this example does not fit with the previous overflow: start another new batch
 | 
			
		||||
            else:
 | 
			
		||||
                yield batch
 | 
			
		||||
                target_size = next(size_)
 | 
			
		||||
                tol_size = target_size * tolerance
 | 
			
		||||
                batch = [example]
 | 
			
		||||
                batch_size = n_words
 | 
			
		||||
 | 
			
		||||
    # yield the final batch
 | 
			
		||||
    if batch:
 | 
			
		||||
        batch.extend(overflow)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user