mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Add low_data mode in textcat
This commit is contained in:
		
							parent
							
								
									ead78c7b9b
								
							
						
					
					
						commit
						a3b69bcb3d
					
				
							
								
								
									
										30
									
								
								spacy/_ml.py
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								spacy/_ml.py
									
									
									
									
									
								
							|  | @ -510,9 +510,23 @@ def foreach(layer, drop_factor=1.0): | |||
| 
 | ||||
| 
 | ||||
| def build_text_classifier(nr_class, width=64, **cfg): | ||||
|     nr_vector = cfg.get('nr_vector', 200) | ||||
|     nr_vector = cfg.get('nr_vector', 5000) | ||||
|     with Model.define_operators({'>>': chain, '+': add, '|': concatenate, | ||||
|                                  '**': clone}): | ||||
|         if cfg.get('low_data'): | ||||
|             model = ( | ||||
|                 SpacyVectors | ||||
|                 >> flatten_add_lengths | ||||
|                 >> with_getitem(0, LN(Affine(width, 300))) | ||||
|                 >> ParametricAttention(width) | ||||
|                 >> Pooling(sum_pool) | ||||
|                 >> Residual(ReLu(width, width)) ** 2 | ||||
|                 >> zero_init(Affine(nr_class, width, drop_factor=0.0)) | ||||
|                 >> logistic | ||||
|             ) | ||||
|             return model | ||||
| 
 | ||||
| 
 | ||||
|         lower = HashEmbed(width, nr_vector, column=1) | ||||
|         prefix = HashEmbed(width//2, nr_vector, column=2) | ||||
|         suffix = HashEmbed(width//2, nr_vector, column=3) | ||||
|  | @ -523,7 +537,7 @@ def build_text_classifier(nr_class, width=64, **cfg): | |||
|             >> with_flatten( | ||||
|                 uniqued( | ||||
|                     (lower | prefix | suffix | shape) | ||||
|                     >> LN(Maxout(width, 64+32+32+32)), | ||||
|                     >> LN(Maxout(width, width+(width//2)*3)), | ||||
|                     column=0 | ||||
|                 ) | ||||
|             ) | ||||
|  | @ -537,14 +551,16 @@ def build_text_classifier(nr_class, width=64, **cfg): | |||
|         cnn_model = ( | ||||
|             # TODO Make concatenate support lists | ||||
|             concatenate_lists(trained_vectors, static_vectors)  | ||||
|             >> flatten_add_lengths | ||||
|             >> with_getitem(0, | ||||
|                 SELU(width, width*2) | ||||
|                 >> (ExtractWindow(nW=1) >> SELU(width, width*3)) ** 2 | ||||
|             >> with_flatten( | ||||
|                 LN(Maxout(width, width*2)) | ||||
|                 >> Residual( | ||||
|                     (ExtractWindow(nW=1) >> zero_init(Maxout(width, width*3))) | ||||
|                 ) ** 2, pad=2 | ||||
|             ) | ||||
|             >> flatten_add_lengths | ||||
|             >> ParametricAttention(width) | ||||
|             >> Pooling(sum_pool) | ||||
|             >> SELU(width, width) ** 2 | ||||
|             >> Residual(zero_init(Maxout(width, width))) | ||||
|             >> zero_init(Affine(nr_class, width, drop_factor=0.0)) | ||||
|         ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -638,12 +638,13 @@ class TextCategorizer(BaseThincComponent): | |||
|         return mean_square_error, d_scores | ||||
| 
 | ||||
|     def begin_training(self, gold_tuples=tuple(), pipeline=None): | ||||
|         if pipeline: | ||||
|         if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer': | ||||
|             token_vector_width = pipeline[0].model.nO | ||||
|         else: | ||||
|             token_vector_width = 64 | ||||
|         if self.model is True: | ||||
|             self.model = self.Model(len(self.labels), token_vector_width) | ||||
|             self.model = self.Model(len(self.labels), token_vector_width, | ||||
|                                     **self.cfg) | ||||
| 
 | ||||
| 
 | ||||
| cdef class EntityRecognizer(LinearParser): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user