mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Add low_data mode in textcat
This commit is contained in:
		
							parent
							
								
									ead78c7b9b
								
							
						
					
					
						commit
						a3b69bcb3d
					
				
							
								
								
									
										30
									
								
								spacy/_ml.py
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								spacy/_ml.py
									
									
									
									
									
								
							|  | @ -510,9 +510,23 @@ def foreach(layer, drop_factor=1.0): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def build_text_classifier(nr_class, width=64, **cfg): | def build_text_classifier(nr_class, width=64, **cfg): | ||||||
|     nr_vector = cfg.get('nr_vector', 200) |     nr_vector = cfg.get('nr_vector', 5000) | ||||||
|     with Model.define_operators({'>>': chain, '+': add, '|': concatenate, |     with Model.define_operators({'>>': chain, '+': add, '|': concatenate, | ||||||
|                                  '**': clone}): |                                  '**': clone}): | ||||||
|  |         if cfg.get('low_data'): | ||||||
|  |             model = ( | ||||||
|  |                 SpacyVectors | ||||||
|  |                 >> flatten_add_lengths | ||||||
|  |                 >> with_getitem(0, LN(Affine(width, 300))) | ||||||
|  |                 >> ParametricAttention(width) | ||||||
|  |                 >> Pooling(sum_pool) | ||||||
|  |                 >> Residual(ReLu(width, width)) ** 2 | ||||||
|  |                 >> zero_init(Affine(nr_class, width, drop_factor=0.0)) | ||||||
|  |                 >> logistic | ||||||
|  |             ) | ||||||
|  |             return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|         lower = HashEmbed(width, nr_vector, column=1) |         lower = HashEmbed(width, nr_vector, column=1) | ||||||
|         prefix = HashEmbed(width//2, nr_vector, column=2) |         prefix = HashEmbed(width//2, nr_vector, column=2) | ||||||
|         suffix = HashEmbed(width//2, nr_vector, column=3) |         suffix = HashEmbed(width//2, nr_vector, column=3) | ||||||
|  | @ -523,7 +537,7 @@ def build_text_classifier(nr_class, width=64, **cfg): | ||||||
|             >> with_flatten( |             >> with_flatten( | ||||||
|                 uniqued( |                 uniqued( | ||||||
|                     (lower | prefix | suffix | shape) |                     (lower | prefix | suffix | shape) | ||||||
|                     >> LN(Maxout(width, 64+32+32+32)), |                     >> LN(Maxout(width, width+(width//2)*3)), | ||||||
|                     column=0 |                     column=0 | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|  | @ -537,14 +551,16 @@ def build_text_classifier(nr_class, width=64, **cfg): | ||||||
|         cnn_model = ( |         cnn_model = ( | ||||||
|             # TODO Make concatenate support lists |             # TODO Make concatenate support lists | ||||||
|             concatenate_lists(trained_vectors, static_vectors)  |             concatenate_lists(trained_vectors, static_vectors)  | ||||||
|             >> flatten_add_lengths |             >> with_flatten( | ||||||
|             >> with_getitem(0, |                 LN(Maxout(width, width*2)) | ||||||
|                 SELU(width, width*2) |                 >> Residual( | ||||||
|                 >> (ExtractWindow(nW=1) >> SELU(width, width*3)) ** 2 |                     (ExtractWindow(nW=1) >> zero_init(Maxout(width, width*3))) | ||||||
|  |                 ) ** 2, pad=2 | ||||||
|             ) |             ) | ||||||
|  |             >> flatten_add_lengths | ||||||
|             >> ParametricAttention(width) |             >> ParametricAttention(width) | ||||||
|             >> Pooling(sum_pool) |             >> Pooling(sum_pool) | ||||||
|             >> SELU(width, width) ** 2 |             >> Residual(zero_init(Maxout(width, width))) | ||||||
|             >> zero_init(Affine(nr_class, width, drop_factor=0.0)) |             >> zero_init(Affine(nr_class, width, drop_factor=0.0)) | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -638,12 +638,13 @@ class TextCategorizer(BaseThincComponent): | ||||||
|         return mean_square_error, d_scores |         return mean_square_error, d_scores | ||||||
| 
 | 
 | ||||||
|     def begin_training(self, gold_tuples=tuple(), pipeline=None): |     def begin_training(self, gold_tuples=tuple(), pipeline=None): | ||||||
|         if pipeline: |         if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer': | ||||||
|             token_vector_width = pipeline[0].model.nO |             token_vector_width = pipeline[0].model.nO | ||||||
|         else: |         else: | ||||||
|             token_vector_width = 64 |             token_vector_width = 64 | ||||||
|         if self.model is True: |         if self.model is True: | ||||||
|             self.model = self.Model(len(self.labels), token_vector_width) |             self.model = self.Model(len(self.labels), token_vector_width, | ||||||
|  |                                     **self.cfg) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef class EntityRecognizer(LinearParser): | cdef class EntityRecognizer(LinearParser): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user