# cython: infer_types=True


# Return the index of the largest value among the first `n_classes` entries
# of `scores`.
#
# Tie-breaking:
#   * n_classes == 2: index 0 wins only when scores[0] is strictly greater
#     than scores[1]; on a tie (or any failed comparison) index 1 is returned.
#   * otherwise: the first index holding the maximum is returned, because a
#     later entry replaces the current winner only when strictly greater.
#
# No bounds checking is done; `scores[0]` (and `scores[1]` in the two-class
# case) is read as-is.
cdef inline int arg_max(const float* scores, const int n_classes) nogil:
    cdef int idx
    cdef int winner = 0

    # Two-class fast path: a single comparison, no loop.
    if n_classes == 2:
        if scores[0] > scores[1]:
            return 0
        return 1

    # General case: linear scan that keeps the index of the running maximum.
    for idx in range(1, n_classes):
        if scores[idx] > scores[winner]:
            winner = idx
    return winner


# Return the index of the largest `scores` entry among the first `n` positions
# whose `is_valid` flag is >= 1, or -1 when no position qualifies.
#
# On ties the first qualifying index is kept, since a later entry replaces the
# current winner only when strictly greater.
cdef inline int arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil:
    cdef int idx
    cdef int winner = -1

    for idx in range(n):
        # Skip positions that are not flagged as valid.
        if is_valid[idx] < 1:
            continue
        # The first valid entry always becomes the winner; later ones must
        # beat it strictly.
        if winner < 0 or scores[idx] > scores[winner]:
            winner = idx
    return winner