diff --git a/spacy/_ml.pyx b/spacy/_ml.pyx index 5bfa05573..706396fd1 100644 --- a/spacy/_ml.pyx +++ b/spacy/_ml.pyx @@ -29,22 +29,25 @@ cdef int arg_max(const weight_t* scores, const int n_classes) nogil: cdef int arg_max_if_true(const weight_t* scores, const int* is_valid, const int n_classes) nogil: cdef int i - cdef int best = 0 - cdef weight_t mode = -900000 + cdef int best = -1 + cdef weight_t mode = 0 for i in range(n_classes): - if is_valid[i] and scores[i] > mode: + if is_valid[i] and (best == -1 or scores[i] > mode): mode = scores[i] best = i return best +class ValidationError(Exception): + pass + cdef int arg_max_if_zero(const weight_t* scores, const int* costs, const int n_classes) nogil: cdef int i - cdef int best = 0 - cdef weight_t mode = -900000 + cdef int best = -1 + cdef weight_t mode = 0 for i in range(n_classes): - if costs[i] == 0 and scores[i] > mode: + if costs[i] == 0 and (best == -1 or scores[i] > mode): mode = scores[i] best = i return best @@ -63,13 +66,18 @@ cdef class Model: self._model.load(self.model_loc, freq_thresh=0) def predict(self, Example eg): + assert self.n_classes == eg.c.nr_class memset(eg.c.scores, 0, sizeof(weight_t) * eg.c.nr_class) self.set_scores(eg.c.scores, eg.c.atoms) eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes) + if eg.c.guess == -1: + raise ValidationError("No valid classes during prediction") def train(self, Example eg): self.predict(eg) eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes) + if eg.c.best == -1: + raise ValidationError("No zero-cost classes during training.") eg.c.cost = eg.c.costs[eg.c.guess] self.update(eg.c.atoms, eg.c.guess, eg.c.best, eg.c.cost)