* Add validation for argmaxing in _ml.pyx

This commit is contained in:
Matthew Honnibal 2015-07-03 09:18:33 +02:00
parent 5d933eec8e
commit 12dd4f745a

View File

@ -29,22 +29,25 @@ cdef int arg_max(const weight_t* scores, const int n_classes) nogil:
cdef int arg_max_if_true(const weight_t* scores, const int* is_valid, cdef int arg_max_if_true(const weight_t* scores, const int* is_valid,
const int n_classes) nogil: const int n_classes) nogil:
cdef int i cdef int i
cdef int best = 0 cdef int best = -1
cdef weight_t mode = -900000 cdef weight_t mode = 0
for i in range(n_classes): for i in range(n_classes):
if is_valid[i] and scores[i] > mode: if is_valid[i] and (best == -1 or scores[i] > mode):
mode = scores[i] mode = scores[i]
best = i best = i
return best return best
class ValidationError(Exception):
pass
cdef int arg_max_if_zero(const weight_t* scores, const int* costs, cdef int arg_max_if_zero(const weight_t* scores, const int* costs,
const int n_classes) nogil: const int n_classes) nogil:
cdef int i cdef int i
cdef int best = 0 cdef int best = -1
cdef weight_t mode = -900000 cdef weight_t mode = 0
for i in range(n_classes): for i in range(n_classes):
if costs[i] == 0 and scores[i] > mode: if costs[i] == 0 and (best == -1 or scores[i] > mode):
mode = scores[i] mode = scores[i]
best = i best = i
return best return best
@ -63,13 +66,18 @@ cdef class Model:
self._model.load(self.model_loc, freq_thresh=0) self._model.load(self.model_loc, freq_thresh=0)
def predict(self, Example eg): def predict(self, Example eg):
assert self.n_classes == eg.c.nr_class
memset(eg.c.scores, 0, sizeof(weight_t) * eg.c.nr_class) memset(eg.c.scores, 0, sizeof(weight_t) * eg.c.nr_class)
self.set_scores(eg.c.scores, eg.c.atoms) self.set_scores(eg.c.scores, eg.c.atoms)
eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes) eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
if eg.c.guess == -1:
raise ValidationError("No valid classes during prediction")
def train(self, Example eg): def train(self, Example eg):
self.predict(eg) self.predict(eg)
eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes) eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes)
if eg.c.best == -1:
raise ValidationError("No zero-cost classes during training.")
eg.c.cost = eg.c.costs[eg.c.guess] eg.c.cost = eg.c.costs[eg.c.guess]
self.update(eg.c.atoms, eg.c.guess, eg.c.best, eg.c.cost) self.update(eg.c.atoms, eg.c.guess, eg.c.best, eg.c.cost)