From 7eaea5de04af183a4e0f3d93f535c28527ed0ab0 Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Tue, 6 Sep 2022 13:56:10 +0200 Subject: [PATCH] Move `argmax` impls to new `_parser_utils` Cython module (#11410) --- setup.py | 1 + spacy/ml/tb_framework.pyx | 14 ++---------- .../_parser_internals/_parser_utils.pxd | 2 ++ .../_parser_internals/_parser_utils.pyx | 22 +++++++++++++++++++ .../_parser_internals/transition_system.pyx | 9 +------- 5 files changed, 28 insertions(+), 20 deletions(-) create mode 100644 spacy/pipeline/_parser_internals/_parser_utils.pxd create mode 100644 spacy/pipeline/_parser_internals/_parser_utils.pyx diff --git a/setup.py b/setup.py index 6466f1c32..316a58f47 100755 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ MOD_NAMES = [ "spacy.pipeline._parser_internals.stateclass", "spacy.pipeline._parser_internals.transition_system", "spacy.pipeline._parser_internals._beam_utils", + "spacy.pipeline._parser_internals._parser_utils", "spacy.tokenizer", "spacy.training.align", "spacy.training.gold_io", diff --git a/spacy/ml/tb_framework.pyx b/spacy/ml/tb_framework.pyx index 421f825c7..d6cd5be2b 100644 --- a/spacy/ml/tb_framework.pyx +++ b/spacy/ml/tb_framework.pyx @@ -15,6 +15,7 @@ from thinc.types import Ints1d, Ints2d from ..errors import Errors from ..pipeline._parser_internals import _beam_utils from ..pipeline._parser_internals.batch import GreedyBatch +from ..pipeline._parser_internals._parser_utils cimport arg_max from ..pipeline._parser_internals.transition_system cimport c_transition_batch, c_apply_actions from ..pipeline._parser_internals.transition_system cimport TransitionSystem from ..pipeline._parser_internals.stateclass cimport StateC, StateClass @@ -514,7 +515,7 @@ cdef void _predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1) for j in range(n.hiddens): index = i * n.hiddens * n.pieces + j * n.pieces - which = _arg_max(&A.unmaxed[index], n.pieces) + which = arg_max(&A.unmaxed[index], n.pieces) A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which] if W.hidden_weights == NULL: memcpy(scores, A.hiddens, n.states * n.classes * sizeof(float)) @@ -560,14 +561,3 @@ cdef void _sum_state_features(CBlas cblas, float* output, saxpy(cblas)(O, one, feature, 1, &output[b*O], 1) token_ids += F -cdef inline int _arg_max(const float* scores, const int n_classes) nogil: - if n_classes == 2: - return 0 if scores[0] > scores[1] else 1 - cdef int i - cdef int best = 0 - cdef float mode = scores[0] - for i in range(1, n_classes): - if scores[i] > mode: - mode = scores[i] - best = i - return best \ No newline at end of file diff --git a/spacy/pipeline/_parser_internals/_parser_utils.pxd b/spacy/pipeline/_parser_internals/_parser_utils.pxd new file mode 100644 index 000000000..7fee05bad --- /dev/null +++ b/spacy/pipeline/_parser_internals/_parser_utils.pxd @@ -0,0 +1,2 @@ +cdef int arg_max(const float* scores, const int n_classes) nogil +cdef int arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil diff --git a/spacy/pipeline/_parser_internals/_parser_utils.pyx b/spacy/pipeline/_parser_internals/_parser_utils.pyx new file mode 100644 index 000000000..582756bf5 --- /dev/null +++ b/spacy/pipeline/_parser_internals/_parser_utils.pyx @@ -0,0 +1,22 @@ +# cython: infer_types=True + +cdef inline int arg_max(const float* scores, const int n_classes) nogil: + if n_classes == 2: + return 0 if scores[0] > scores[1] else 1 + cdef int i + cdef int best = 0 + cdef float mode = scores[0] + for i in range(1, n_classes): + if scores[i] > mode: + mode = scores[i] + best = i + return best + + +cdef inline int arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil: + cdef int best = -1 + for i in range(n): + if is_valid[i] >= 1: + if best == -1 or scores[i] > scores[best]: + best = i + return best diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx index dd18606c1..89f9e8ae8 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pyx +++ b/spacy/pipeline/_parser_internals/transition_system.pyx @@ -12,6 +12,7 @@ from ...typedefs cimport weight_t, attr_t from ...tokens.doc cimport Doc from ...structs cimport TokenC from .stateclass cimport StateClass +from ._parser_utils cimport arg_max_if_valid from ...errors import Errors from ... import util @@ -320,11 +321,3 @@ cdef void c_transition_batch(TransitionSystem moves, StateC** states, const floa states[i].history.push_back(guess) free(is_valid) - -cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil: - cdef int best = -1 - for i in range(n): - if is_valid[i] >= 1: - if best == -1 or scores[i] > scores[best]: - best = i - return best