mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	* Add validation for argmaxing in _ml.pyx
This commit is contained in:
		
							parent
							
								
									5d933eec8e
								
							
						
					
					
						commit
						12dd4f745a
					
				|  | @ -29,22 +29,25 @@ cdef int arg_max(const weight_t* scores, const int n_classes) nogil: | ||||||
| cdef int arg_max_if_true(const weight_t* scores, const int* is_valid, | cdef int arg_max_if_true(const weight_t* scores, const int* is_valid, | ||||||
|                          const int n_classes) nogil: |                          const int n_classes) nogil: | ||||||
|     cdef int i |     cdef int i | ||||||
|     cdef int best = 0 |     cdef int best = -1 | ||||||
|     cdef weight_t mode = -900000 |     cdef weight_t mode = 0 | ||||||
|     for i in range(n_classes): |     for i in range(n_classes): | ||||||
|         if is_valid[i] and scores[i] > mode: |         if is_valid[i] and (best == -1 or scores[i] > mode): | ||||||
|             mode = scores[i] |             mode = scores[i] | ||||||
|             best = i |             best = i | ||||||
|     return best |     return best | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class ValidationError(Exception): | ||||||
|  |     pass | ||||||
|  | 
 | ||||||
| cdef int arg_max_if_zero(const weight_t* scores, const int* costs, | cdef int arg_max_if_zero(const weight_t* scores, const int* costs, | ||||||
|                          const int n_classes) nogil: |                          const int n_classes) nogil: | ||||||
|     cdef int i |     cdef int i | ||||||
|     cdef int best = 0 |     cdef int best = -1 | ||||||
|     cdef weight_t mode = -900000 |     cdef weight_t mode = 0 | ||||||
|     for i in range(n_classes): |     for i in range(n_classes): | ||||||
|         if costs[i] == 0 and scores[i] > mode: |         if costs[i] == 0 and (best == -1 or scores[i] > mode): | ||||||
|             mode = scores[i] |             mode = scores[i] | ||||||
|             best = i |             best = i | ||||||
|     return best |     return best | ||||||
|  | @ -63,13 +66,18 @@ cdef class Model: | ||||||
|             self._model.load(self.model_loc, freq_thresh=0) |             self._model.load(self.model_loc, freq_thresh=0) | ||||||
| 
 | 
 | ||||||
|     def predict(self, Example eg): |     def predict(self, Example eg): | ||||||
|  |         assert self.n_classes == eg.c.nr_class | ||||||
|         memset(eg.c.scores, 0, sizeof(weight_t) * eg.c.nr_class) |         memset(eg.c.scores, 0, sizeof(weight_t) * eg.c.nr_class) | ||||||
|         self.set_scores(eg.c.scores, eg.c.atoms) |         self.set_scores(eg.c.scores, eg.c.atoms) | ||||||
|         eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes) |         eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes) | ||||||
|  |         if eg.c.guess == -1: | ||||||
|  |             raise ValidationError("No valid classes during prediction") | ||||||
| 
 | 
 | ||||||
|     def train(self, Example eg): |     def train(self, Example eg): | ||||||
|         self.predict(eg) |         self.predict(eg) | ||||||
|         eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes) |         eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes) | ||||||
|  |         if eg.c.best == -1: | ||||||
|  |             raise ValidationError("No zero-cost classes during training.") | ||||||
|         eg.c.cost = eg.c.costs[eg.c.guess] |         eg.c.cost = eg.c.costs[eg.c.guess] | ||||||
|         self.update(eg.c.atoms, eg.c.guess, eg.c.best, eg.c.cost) |         self.update(eg.c.atoms, eg.c.guess, eg.c.best, eg.c.cost) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user