Move argmax impls to new _parser_utils Cython module (#11410)

This commit is contained in:
Madeesh Kannan 2022-09-06 13:56:10 +02:00 committed by GitHub
parent 582232bb77
commit 7eaea5de04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 28 additions and 20 deletions

View File

@ -51,6 +51,7 @@ MOD_NAMES = [
"spacy.pipeline._parser_internals.stateclass", "spacy.pipeline._parser_internals.stateclass",
"spacy.pipeline._parser_internals.transition_system", "spacy.pipeline._parser_internals.transition_system",
"spacy.pipeline._parser_internals._beam_utils", "spacy.pipeline._parser_internals._beam_utils",
"spacy.pipeline._parser_internals._parser_utils",
"spacy.tokenizer", "spacy.tokenizer",
"spacy.training.align", "spacy.training.align",
"spacy.training.gold_io", "spacy.training.gold_io",

View File

@ -15,6 +15,7 @@ from thinc.types import Ints1d, Ints2d
from ..errors import Errors from ..errors import Errors
from ..pipeline._parser_internals import _beam_utils from ..pipeline._parser_internals import _beam_utils
from ..pipeline._parser_internals.batch import GreedyBatch from ..pipeline._parser_internals.batch import GreedyBatch
from ..pipeline._parser_internals._parser_utils cimport arg_max
from ..pipeline._parser_internals.transition_system cimport c_transition_batch, c_apply_actions from ..pipeline._parser_internals.transition_system cimport c_transition_batch, c_apply_actions
from ..pipeline._parser_internals.transition_system cimport TransitionSystem from ..pipeline._parser_internals.transition_system cimport TransitionSystem
from ..pipeline._parser_internals.stateclass cimport StateC, StateClass from ..pipeline._parser_internals.stateclass cimport StateC, StateClass
@ -514,7 +515,7 @@ cdef void _predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC**
saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1) saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1)
for j in range(n.hiddens): for j in range(n.hiddens):
index = i * n.hiddens * n.pieces + j * n.pieces index = i * n.hiddens * n.pieces + j * n.pieces
which = _arg_max(&A.unmaxed[index], n.pieces) which = arg_max(&A.unmaxed[index], n.pieces)
A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which] A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
if W.hidden_weights == NULL: if W.hidden_weights == NULL:
memcpy(scores, A.hiddens, n.states * n.classes * sizeof(float)) memcpy(scores, A.hiddens, n.states * n.classes * sizeof(float))
@ -560,14 +561,3 @@ cdef void _sum_state_features(CBlas cblas, float* output,
saxpy(cblas)(O, one, <const float*>feature, 1, &output[b*O], 1) saxpy(cblas)(O, one, <const float*>feature, 1, &output[b*O], 1)
token_ids += F token_ids += F
cdef inline int _arg_max(const float* scores, const int n_classes) nogil:
if n_classes == 2:
return 0 if scores[0] > scores[1] else 1
cdef int i
cdef int best = 0
cdef float mode = scores[0]
for i in range(1, n_classes):
if scores[i] > mode:
mode = scores[i]
best = i
return best

View File

@ -0,0 +1,2 @@
cdef int arg_max(const float* scores, const int n_classes) nogil
cdef int arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil

View File

@ -0,0 +1,22 @@
# cython: infer_types=True
cdef inline int arg_max(const float* scores, const int n_classes) nogil:
if n_classes == 2:
return 0 if scores[0] > scores[1] else 1
cdef int i
cdef int best = 0
cdef float mode = scores[0]
for i in range(1, n_classes):
if scores[i] > mode:
mode = scores[i]
best = i
return best
cdef inline int arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil:
cdef int best = -1
for i in range(n):
if is_valid[i] >= 1:
if best == -1 or scores[i] > scores[best]:
best = i
return best

View File

@ -12,6 +12,7 @@ from ...typedefs cimport weight_t, attr_t
from ...tokens.doc cimport Doc from ...tokens.doc cimport Doc
from ...structs cimport TokenC from ...structs cimport TokenC
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._parser_utils cimport arg_max_if_valid
from ...errors import Errors from ...errors import Errors
from ... import util from ... import util
@ -320,11 +321,3 @@ cdef void c_transition_batch(TransitionSystem moves, StateC** states, const floa
states[i].history.push_back(guess) states[i].history.push_back(guess)
free(is_valid) free(is_valid)
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
cdef int best = -1
for i in range(n):
if is_valid[i] >= 1:
if best == -1 or scores[i] > scores[best]:
best = i
return best