mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	* register extract_ngrams layer * fix import * bump spacy-legacy to 3.0.6 * revert bump (wrong PR)
		
			
				
	
	
		
			35 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			35 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from thinc.api import Model
 | |
| 
 | |
| from ..util import registry
 | |
| from ..attrs import LOWER
 | |
| 
 | |
| 
 | |
@registry.layers("spacy.extract_ngrams.v1")
def extract_ngrams(ngram_size: int, attr: int = LOWER) -> Model:
    """Build a layer that extracts unique ngram-hash features from docs.

    ngram_size: the largest n-gram order to extract (inclusive; unigrams
        are always included).
    attr: token attribute ID whose values form the unigrams
        (defaults to LOWER).

    The returned model stores both settings in its ``attrs`` dict and
    delegates to ``forward``.
    """
    layer = Model("extract_ngrams", forward)
    layer.attrs.update({"ngram_size": ngram_size, "attr": attr})
    return layer
 | |
| 
 | |
| 
 | |
def forward(model: Model, docs, is_train: bool):
    """Compute per-doc unique ngram keys and their counts.

    For each doc: read the configured token attribute as unigrams, append
    the 2..ngram_size gram hashes, then deduplicate with counts. The
    per-doc results are concatenated flat, with ``lengths`` recording how
    many unique keys each doc contributed.

    Returns ``((keys, vals, lengths), backprop)``; ``vals`` are the counts
    cast to float32, ``lengths`` is int32. ``backprop`` is a no-op since
    hash features carry no gradient.
    """
    xp = model.ops.xp
    # Hoist the invariant attr lookups out of the per-doc loop.
    max_order = model.attrs["ngram_size"]
    token_attr = model.attrs["attr"]

    keys_per_doc = []
    vals_per_doc = []
    for doc in docs:
        unigrams = model.ops.asarray(doc.to_array([token_attr]))
        grams = [unigrams]
        grams.extend(model.ops.ngrams(n, unigrams) for n in range(2, max_order + 1))
        uniq, counts = xp.unique(xp.concatenate(grams), return_counts=True)
        keys_per_doc.append(uniq)
        vals_per_doc.append(counts)

    lengths = model.ops.asarray([arr.shape[0] for arr in keys_per_doc], dtype="int32")
    flat_keys = xp.concatenate(keys_per_doc)
    flat_vals = model.ops.asarray(xp.concatenate(vals_per_doc), dtype="f")

    def backprop(dY):
        # Nothing to propagate: the extraction is non-parametric.
        return []

    return (flat_keys, flat_vals, lengths), backprop
 |