mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	* Set scores to 0 before prediction
This commit is contained in:
		
							parent
							
								
									2be517ba6d
								
							
						
					
					
						commit
						8f068dc6fe
					
				|  | @ -2,6 +2,8 @@ | ||||||
| from __future__ import unicode_literals | from __future__ import unicode_literals | ||||||
| from __future__ import division | from __future__ import division | ||||||
| 
 | 
 | ||||||
|  | from libc.string cimport memset | ||||||
|  | 
 | ||||||
| from os import path | from os import path | ||||||
| import os | import os | ||||||
| import shutil | import shutil | ||||||
|  | @ -61,6 +63,7 @@ cdef class Model: | ||||||
|             self._model.load(self.model_loc, freq_thresh=0) |             self._model.load(self.model_loc, freq_thresh=0) | ||||||
| 
 | 
 | ||||||
|     def predict(self, Example eg): |     def predict(self, Example eg): | ||||||
|  |         memset(eg.c.scores, 0, sizeof(weight_t) * eg.c.nr_class) | ||||||
|         self.set_scores(eg.c.scores, eg.c.atoms) |         self.set_scores(eg.c.scores, eg.c.atoms) | ||||||
|         eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes) |         eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user