mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-29 11:26:28 +03:00
23 lines
633 B
Cython
23 lines
633 B
Cython
|
# cython: infer_types=True
|
||
|
|
||
|
cdef inline int arg_max(const float* scores, const int n_classes) nogil:
    # Return the index of the largest value in scores[0:n_classes].
    #
    # Ties resolve to the lowest index: a later entry only replaces the
    # current pick when it is strictly greater. Returns -1 when
    # n_classes < 1, because there is no entry to choose (the same
    # "nothing chosen" value arg_max_if_valid returns).
    if n_classes < 1:
        # Guard: without this, scores[0] below would be read from an
        # empty buffer.
        return -1
    if n_classes == 2:
        # Binary fast path. Written to agree exactly with the general loop:
        # index 1 wins only if it is strictly greater than index 0, so ties
        # (and any comparison involving NaN) resolve to index 0.
        return 1 if scores[1] > scores[0] else 0
    cdef int i
    cdef int best = 0
    cdef float best_score = scores[0]
    for i in range(1, n_classes):
        # Strict '>' keeps the earliest index on ties.
        if scores[i] > best_score:
            best_score = scores[i]
            best = i
    return best
|
||
|
|
||
|
|
||
|
cdef inline int arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil:
    # Pick the highest-scoring index among the first n entries whose
    # is_valid flag is >= 1. Returns -1 if none of them is valid.
    # Ties go to the lowest valid index, since only a strictly greater
    # score displaces the current choice.
    cdef int i
    cdef int chosen = -1
    for i in range(n):
        if is_valid[i] < 1:
            # Entry not flagged as valid: never a candidate.
            continue
        if chosen == -1 or scores[i] > scores[chosen]:
            chosen = i
    return chosen
|