mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Clean up TextCategorizer slightly
This commit is contained in:
		
							parent
							
								
									d13b9373bf
								
							
						
					
					
						commit
						6b0008afc6
					
				| 
						 | 
				
			
			@ -946,7 +946,7 @@ class TextCategorizer(Pipe):
 | 
			
		|||
        not_missing = self.model.ops.asarray(not_missing)
 | 
			
		||||
        d_scores = (scores-truths) / scores.shape[0]
 | 
			
		||||
        d_scores *= not_missing
 | 
			
		||||
        mean_square_error = ((scores-truths)**2).sum(axis=1).mean()
 | 
			
		||||
        mean_square_error = (d_scores**2).sum(axis=1).mean()
 | 
			
		||||
        return float(mean_square_error), d_scores
 | 
			
		||||
 | 
			
		||||
    def add_label(self, label):
 | 
			
		||||
| 
						 | 
				
			
			@ -968,11 +968,6 @@ class TextCategorizer(Pipe):
 | 
			
		|||
 | 
			
		||||
    def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
 | 
			
		||||
                       **kwargs):
 | 
			
		||||
        if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
 | 
			
		||||
            token_vector_width = pipeline[0].model.nO
 | 
			
		||||
        else:
 | 
			
		||||
            token_vector_width = 64
 | 
			
		||||
 | 
			
		||||
        if self.model is True:
 | 
			
		||||
            self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
 | 
			
		||||
            self.model = self.Model(len(self.labels), **self.cfg)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user