mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Add extra batch util
This commit is contained in:
		
							parent
							
								
									eb0798c421
								
							
						
					
					
						commit
						3a7f275c02
					
				|  | @ -722,6 +722,50 @@ def minibatch(items, size=8): | ||||||
|         yield list(batch) |         yield list(batch) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False): | ||||||
|  |     if isinstance(size, int): | ||||||
|  |         size_ = itertools.repeat(size) | ||||||
|  |     else: | ||||||
|  |         size_ = size | ||||||
|  |     for outer_batch in minibatch(docs, buffer): | ||||||
|  |         outer_batch = list(outer_batch) | ||||||
|  |         target_size = next(size_) | ||||||
|  |         for indices in _batch_by_length(outer_batch, target_size): | ||||||
|  |             subbatch = [outer_batch[i] for i in indices] | ||||||
|  |             padded_size = max(len(seq) for seq in subbatch) * len(subbatch) | ||||||
|  |             if discard_oversize and padded_size >= target_size: | ||||||
|  |                 pass | ||||||
|  |             else: | ||||||
|  |                 yield subbatch | ||||||
|  |   | ||||||
|  | 
 | ||||||
|  | def _batch_by_length(seqs, max_words): | ||||||
|  |     """Given a list of sequences, return a batched list of indices into the  | ||||||
|  |     list, where the batches are grouped by length, in descending order. | ||||||
|  |      | ||||||
|  |     Batches may be at most max_words in size, defined as max sequence length * size. | ||||||
|  |     """ | ||||||
|  |     # Use negative index so we can get sort by position ascending. | ||||||
|  |     lengths_indices = [(len(seq), i) for i, seq in enumerate(seqs)] | ||||||
|  |     lengths_indices.sort() | ||||||
|  |     batches = [] | ||||||
|  |     batch = [] | ||||||
|  |     for length, i in lengths_indices: | ||||||
|  |         if not batch: | ||||||
|  |             batch.append(i) | ||||||
|  |         elif length * (len(batch) + 1) <= max_words: | ||||||
|  |             batch.append(i) | ||||||
|  |         else: | ||||||
|  |             batches.append(batch) | ||||||
|  |             batch = [i] | ||||||
|  |     if batch: | ||||||
|  |         batches.append(batch) | ||||||
|  |     # Check lengths match | ||||||
|  |     assert sum(len(b) for b in batches) == len(seqs) | ||||||
|  |     batches = [list(sorted(batch)) for batch in batches] | ||||||
|  |     batches.reverse() | ||||||
|  |     return batches | ||||||
|  | 
 | ||||||
| def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): | def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): | ||||||
|     """Create minibatches of roughly a given number of words. If any examples |     """Create minibatches of roughly a given number of words. If any examples | ||||||
|     are longer than the specified batch length, they will appear in a batch by |     are longer than the specified batch length, they will appear in a batch by | ||||||
|  | @ -768,7 +812,8 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): | ||||||
| 
 | 
 | ||||||
|         # yield the previous batch and start a new one. The new one gets the overflow examples. |         # yield the previous batch and start a new one. The new one gets the overflow examples. | ||||||
|         else: |         else: | ||||||
|             yield batch |             if batch: | ||||||
|  |                 yield batch | ||||||
|             target_size = next(size_) |             target_size = next(size_) | ||||||
|             tol_size = target_size * tolerance |             tol_size = target_size * tolerance | ||||||
|             batch = overflow |             batch = overflow | ||||||
|  | @ -788,15 +833,15 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): | ||||||
| 
 | 
 | ||||||
|             # this example does not fit with the previous overflow: start another new batch |             # this example does not fit with the previous overflow: start another new batch | ||||||
|             else: |             else: | ||||||
|                 yield batch |                 if batch: | ||||||
|  |                     yield batch | ||||||
|                 target_size = next(size_) |                 target_size = next(size_) | ||||||
|                 tol_size = target_size * tolerance |                 tol_size = target_size * tolerance | ||||||
|                 batch = [doc] |                 batch = [doc] | ||||||
|                 batch_size = n_words |                 batch_size = n_words | ||||||
| 
 | 
 | ||||||
|     # yield the final batch |     batch.extend(overflow) | ||||||
|     if batch: |     if batch: | ||||||
|         batch.extend(overflow) |  | ||||||
|         yield batch |         yield batch | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user