mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-04 03:43:09 +03:00
* Add validation for argmaxing in _ml.pyx
This commit is contained in:
parent
5d933eec8e
commit
12dd4f745a
|
@ -29,22 +29,25 @@ cdef int arg_max(const weight_t* scores, const int n_classes) nogil:
|
||||||
cdef int arg_max_if_true(const weight_t* scores, const int* is_valid,
|
cdef int arg_max_if_true(const weight_t* scores, const int* is_valid,
|
||||||
const int n_classes) nogil:
|
const int n_classes) nogil:
|
||||||
cdef int i
|
cdef int i
|
||||||
cdef int best = 0
|
cdef int best = -1
|
||||||
cdef weight_t mode = -900000
|
cdef weight_t mode = 0
|
||||||
for i in range(n_classes):
|
for i in range(n_classes):
|
||||||
if is_valid[i] and scores[i] > mode:
|
if is_valid[i] and (best == -1 or scores[i] > mode):
|
||||||
mode = scores[i]
|
mode = scores[i]
|
||||||
best = i
|
best = i
|
||||||
return best
|
return best
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
cdef int arg_max_if_zero(const weight_t* scores, const int* costs,
|
cdef int arg_max_if_zero(const weight_t* scores, const int* costs,
|
||||||
const int n_classes) nogil:
|
const int n_classes) nogil:
|
||||||
cdef int i
|
cdef int i
|
||||||
cdef int best = 0
|
cdef int best = -1
|
||||||
cdef weight_t mode = -900000
|
cdef weight_t mode = 0
|
||||||
for i in range(n_classes):
|
for i in range(n_classes):
|
||||||
if costs[i] == 0 and scores[i] > mode:
|
if costs[i] == 0 and (best == -1 or scores[i] > mode):
|
||||||
mode = scores[i]
|
mode = scores[i]
|
||||||
best = i
|
best = i
|
||||||
return best
|
return best
|
||||||
|
@ -63,13 +66,18 @@ cdef class Model:
|
||||||
self._model.load(self.model_loc, freq_thresh=0)
|
self._model.load(self.model_loc, freq_thresh=0)
|
||||||
|
|
||||||
def predict(self, Example eg):
|
def predict(self, Example eg):
|
||||||
|
assert self.n_classes == eg.c.nr_class
|
||||||
memset(eg.c.scores, 0, sizeof(weight_t) * eg.c.nr_class)
|
memset(eg.c.scores, 0, sizeof(weight_t) * eg.c.nr_class)
|
||||||
self.set_scores(eg.c.scores, eg.c.atoms)
|
self.set_scores(eg.c.scores, eg.c.atoms)
|
||||||
eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
|
eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
|
||||||
|
if eg.c.guess == -1:
|
||||||
|
raise ValidationError("No valid classes during prediction")
|
||||||
|
|
||||||
def train(self, Example eg):
|
def train(self, Example eg):
|
||||||
self.predict(eg)
|
self.predict(eg)
|
||||||
eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes)
|
eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes)
|
||||||
|
if eg.c.best == -1:
|
||||||
|
raise ValidationError("No zero-cost classes during training.")
|
||||||
eg.c.cost = eg.c.costs[eg.c.guess]
|
eg.c.cost = eg.c.costs[eg.c.guess]
|
||||||
self.update(eg.c.atoms, eg.c.guess, eg.c.best, eg.c.cost)
|
self.update(eg.c.atoms, eg.c.guess, eg.c.best, eg.c.cost)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user