From aad38972cbb5c94aed96c6a8fccf4c262f7fb745 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Mon, 30 May 2022 12:16:28 +0200
Subject: [PATCH] Restore C CPU inference in the refactored parser (#10747)

* Bring back the C parsing model

The C parsing model is used for CPU inference and is still faster than
the forward pass of the Thinc model on CPU.

* Use C sgemm provided by the Ops implementation

* Make tb_framework module Cython, merge in C forward implementation

* TransitionModel: raise in backprop returned from forward_cpu

* Re-enable greedy parse test

* Return transition scores when forward_cpu is used

* Apply suggestions from code review

Import `Model` from `thinc.api`

Co-authored-by: Sofie Van Landeghem

* Use relative imports in tb_framework

* Don't assume a default for beam_width

* We don't have a direct dependency on BLIS anymore

* Rename forwards to _forward_{fallback,greedy_cpu}

* Require thinc >=8.1.0,<8.2.0

* tb_framework: clean up imports

* Fix return type of _get_seen_mask

* Move up _forward_greedy_cpu

* Style fixes.

* Lower thinc lowerbound to 8.1.0.dev0

* Formatting fix

Co-authored-by: Adriane Boyd

Co-authored-by: Sofie Van Landeghem
Co-authored-by: Adriane Boyd
---
 setup.py                                      |   1 +
 spacy/errors.py                               |   1 +
 spacy/ml/tb_framework.pxd                     |  28 ++
 .../ml/{tb_framework.py => tb_framework.pyx}  | 254 +++++++++++++++---
 .../_parser_internals/transition_system.pxd   |   3 +
 spacy/tests/parser/test_add_label.py          |   1 -
 6 files changed, 252 insertions(+), 36 deletions(-)
 create mode 100644 spacy/ml/tb_framework.pxd
 rename spacy/ml/{tb_framework.py => tb_framework.pyx} (66%)

diff --git a/setup.py b/setup.py
index f6d010c76..17d68a1f4 100755
--- a/setup.py
+++ b/setup.py
@@ -31,6 +31,7 @@ MOD_NAMES = [
     "spacy.vocab",
     "spacy.attrs",
     "spacy.kb",
+    "spacy.ml.tb_framework",
     "spacy.morphology",
     "spacy.pipeline._edit_tree_internals.edit_trees",
     "spacy.pipeline.morphologizer",
diff --git a/spacy/errors.py b/spacy/errors.py
index c82ffe882..8bf1f03de 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -919,6 +919,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E1035 = ("Token index {i} out of bounds ({length})")
     E1036 = ("Cannot index into NoneNode")
     E1037 = ("Invalid attribute value '{attr}'.")
+    E1038 = ("Backprop is not supported when is_train is not set.")
 
 
 # Deprecated model shortcuts, only used in errors and warnings
diff --git a/spacy/ml/tb_framework.pxd b/spacy/ml/tb_framework.pxd
new file mode 100644
index 000000000..965508519
--- /dev/null
+++ b/spacy/ml/tb_framework.pxd
@@ -0,0 +1,28 @@
+from libc.stdint cimport int8_t
+
+
+cdef struct SizesC:
+    int states
+    int classes
+    int hiddens
+    int pieces
+    int feats
+    int embed_width
+    int tokens
+
+
+cdef struct WeightsC:
+    const float* feat_weights
+    const float* feat_bias
+    const float* hidden_bias
+    const float* hidden_weights
+    const int8_t* seen_mask
+
+
+cdef struct ActivationsC:
+    int* token_ids
+    float* unmaxed
+    float* hiddens
+    int* is_valid
+    int _curr_size
+    int _max_size
diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.pyx
similarity index 66%
rename from spacy/ml/tb_framework.py
rename to spacy/ml/tb_framework.pyx
index 7aa1a9324..e98d59a8a 100644
--- a/spacy/ml/tb_framework.py
+++ b/spacy/ml/tb_framework.pyx
@@ -1,15 +1,26 @@
+# cython: infer_types=True, cdivision=True, boundscheck=False
 from typing import List, Tuple, Any, Optional, cast
-from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
-from thinc.api import uniform_init, glorot_uniform_init, zero_init
-from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
+from libc.string cimport memset, memcpy
+from libc.stdlib cimport calloc, free, realloc
+from libcpp.vector cimport vector
 import numpy
+cimport numpy as np
+from thinc.api import Model, normal_init, chain, list2array, Linear
+from thinc.api import uniform_init, glorot_uniform_init, zero_init
+from thinc.api import NumpyOps
+from thinc.backends.linalg cimport Vec, VecVec
+from thinc.backends.cblas cimport CBlas
+from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
+
+from ..errors import Errors
 from ..pipeline._parser_internals import _beam_utils
 from ..pipeline._parser_internals.batch import GreedyBatch
+from ..pipeline._parser_internals.transition_system cimport c_transition_batch, TransitionSystem
+from ..pipeline._parser_internals.stateclass cimport StateC, StateClass
 from ..tokens.doc import Doc
 from ..util import registry
 
-TransitionSystem = Any  # TODO
 State = Any  # TODO
@@ -131,29 +142,82 @@ def init(
     # model = _lsuv_init(model)
     return model
 
-
 def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
-    nF = model.get_dim("nF")
-    tok2vec = model.get_ref("tok2vec")
-    lower_pad = model.get_param("lower_pad")
-    lower_W = model.get_param("lower_W")
-    lower_b = model.get_param("lower_b")
-    upper_W = model.get_param("upper_W")
-    upper_b = model.get_param("upper_b")
-    nH = model.get_dim("nH")
-    nP = model.get_dim("nP")
-    nO = model.get_dim("nO")
-    nI = model.get_dim("nI")
-    beam_width = model.attrs["beam_width"]
-    beam_density = model.attrs["beam_density"]
+    lower_pad = model.get_param("lower_pad")
+    tok2vec = model.get_ref("tok2vec")
+    beam_width = model.attrs["beam_width"]
 
-    ops = model.ops
     docs, moves = docs_moves
     states = moves.init_batch(docs)
     tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
     tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
     feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
+    seen_mask = _get_seen_mask(model)
+
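+    # The C implementation below only covers greedy decoding with
+    # NumPy-backed ops, i.e. CPU inference; training, beam search and
+    # GPU execution take the Thinc fallback path.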
+    if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
+        return _forward_greedy_cpu(model, moves, states, feats, seen_mask)
+    else:
+        return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train)
+
+def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats,
+                        np.ndarray[np.npy_bool, ndim=1] seen_mask):
+    cdef vector[StateC *] c_states
+    cdef StateClass state
+    for state in states:
+        if not state.is_final():
+            c_states.push_back(state.c)
+    weights = get_c_weights(model, <float*>feats.data, seen_mask)
+    # Precomputed features have rows for each token, plus one for padding.
+    cdef int n_tokens = feats.shape[0] - 1
+    sizes = get_c_sizes(model, c_states.size(), n_tokens)
+    cdef CBlas cblas = model.ops.cblas()
+    scores = _parseC(cblas, moves, &c_states[0], weights, sizes)
+
+    def backprop(dY):
+        raise ValueError(Errors.E1038)
+
+    return (states, scores), backprop
+
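+# Greedy decoding loop in C: score every unfinished state in one batch,
+# let the transition system validate and apply the argmax action, then
+# compact the surviving states. Per-step score matrices are collected so
+# that transition scores can still be returned to the caller.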
+cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
+                  WeightsC weights, SizesC sizes):
+    cdef int i, j
+    cdef vector[StateC *] unfinished
+    cdef ActivationsC activations = alloc_activations(sizes)
+    cdef np.ndarray step_scores
+
+    scores = []
+    while sizes.states >= 1:
+        step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
+        with nogil:
+            predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
+            # Validate actions, argmax, take action.
+            c_transition_batch(moves, states, <const float*>step_scores.data, sizes.classes,
+                               sizes.states)
+            for i in range(sizes.states):
+                if not states[i].is_final():
+                    unfinished.push_back(states[i])
+            for i in range(unfinished.size()):
+                states[i] = unfinished[i]
+            sizes.states = unfinished.size()
+        scores.append(step_scores)
+        unfinished.clear()
+    free_activations(&activations)
+
+    return scores
+
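+# Thinc fallback path, used for training, for beam search (beam_width > 1)
+# and on non-NumPy (e.g. GPU) ops. It computes the same maxout -> linear
+# scores with Ops primitives and is the only path whose backprop callback
+# can actually be called.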
+def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool):
+    nF = model.get_dim("nF")
+    lower_b = model.get_param("lower_b")
+    upper_W = model.get_param("upper_W")
+    upper_b = model.get_param("upper_b")
+    nH = model.get_dim("nH")
+    nP = model.get_dim("nP")
+
+    beam_width = model.attrs["beam_width"]
+    beam_density = model.attrs["beam_density"]
+
+    ops = model.ops
+
     all_ids = []
     all_which = []
     all_statevecs = []
@@ -164,8 +228,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
     batch = _beam_utils.BeamBatch(
         moves, states, None, width=beam_width, density=beam_density
     )
-    seen_mask = _get_seen_mask(model)
-    arange = model.ops.xp.arange(nF)
+    arange = ops.xp.arange(nF)
     while not batch.is_done:
         ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i")
         for i, state in enumerate(batch.get_unfinished_states()):
@@ -174,16 +237,16 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
         # to create the state vectors.
         preacts2f = feats[ids, arange].sum(axis=1)  # type: ignore
         preacts2f += lower_b
-        preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
+        preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
         assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
         statevecs, which = ops.maxout(preacts)
         # Multiply the state-vector by the scores weights and add the bias,
         # to get the logits.
-        scores = model.ops.gemm(statevecs, upper_W, trans2=True)
+        scores = ops.gemm(statevecs, upper_W, trans2=True)
         scores += upper_b
-        scores[:, seen_mask] = model.ops.xp.nanmin(scores)
+        scores[:, seen_mask] = ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
-        cpu_scores = model.ops.to_numpy(scores)
+        cpu_scores = ops.to_numpy(scores)
         batch.advance(cpu_scores)
         all_scores.append(scores)
         if is_train:
@@ -193,10 +256,9 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
         all_which.append(which)
 
     def backprop_parser(d_states_d_scores):
-        d_tokvecs = model.ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
-        ids = model.ops.xp.vstack(all_ids)
+        ids = ops.xp.vstack(all_ids)
         which = ops.xp.vstack(all_which)
-        statevecs = model.ops.xp.vstack(all_statevecs)
+        statevecs = ops.xp.vstack(all_statevecs)
         _, d_scores = d_states_d_scores
         if model.attrs.get("unseen_classes"):
             # If we have a negative gradient (i.e. the probability should
@@ -209,18 +271,18 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
         # Calculate the gradients for the parameters of the upper layer.
         # The weight gemm is (nS, nO) @ (nS, nH).T
         model.inc_grad("upper_b", d_scores.sum(axis=0))
-        model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
+        model.inc_grad("upper_W", ops.gemm(d_scores, statevecs, trans1=True))
         # Now calculate d_statevecs, by backpropping through the upper linear layer.
         # This gemm is (nS, nO) @ (nO, nH)
-        d_statevecs = model.ops.gemm(d_scores, upper_W)
+        d_statevecs = ops.gemm(d_scores, upper_W)
         # Backprop through the maxout activation
-        d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP)
-        d_preacts2f = model.ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
+        d_preacts = ops.backprop_maxout(d_statevecs, which, nP)
+        d_preacts2f = ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
         model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
         # We don't need to backprop the summation, because we pass back the IDs instead
         d_state_features = backprop_feats((d_preacts2f, ids))
-        d_tokvecs = model.ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
-        model.ops.scatter_add(d_tokvecs, ids, d_state_features)
+        d_tokvecs = ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
+        ops.scatter_add(d_tokvecs, ids, d_state_features)
         model.inc_grad("lower_pad", d_tokvecs[-1])
         return (backprop_tok2vec(d_tokvecs[:-1]), None)
 
@@ -328,7 +390,7 @@ def _forward_reference(
     return (states, all_scores), backprop_parser
 
 
-def _get_seen_mask(model: Model) -> Floats1d:
+def _get_seen_mask(model: Model) -> numpy.array[bool, 1]:
     mask = model.ops.xp.zeros(model.get_dim("nO"), dtype="bool")
     for class_ in model.attrs.get("unseen_classes", set()):
         mask[class_] = True
@@ -449,3 +511,125 @@ def _lsuv_init(model: Model):
         else:
             break
     return model
+
+
+cdef WeightsC get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
+    cdef np.ndarray lower_b = model.get_param("lower_b")
+    cdef np.ndarray upper_W = model.get_param("upper_W")
+    cdef np.ndarray upper_b = model.get_param("upper_b")
+
+    cdef WeightsC output
+    output.feat_weights = feats
+    output.feat_bias = <const float*>lower_b.data
+    output.hidden_weights = <const float*>upper_W.data
+    output.hidden_bias = <const float*>upper_b.data
+    output.seen_mask = <const int8_t*>seen_mask.data
+
+    return output
+
+
+cdef SizesC get_c_sizes(model, int batch_size, int tokens) except *:
+    cdef SizesC output
+    output.states = batch_size
+    output.classes = model.get_dim("nO")
+    output.hiddens = model.get_dim("nH")
+    output.pieces = model.get_dim("nP")
+    output.feats = model.get_dim("nF")
+    output.embed_width = model.get_dim("nI")
+    output.tokens = tokens
+    return output
+
+
+cdef ActivationsC alloc_activations(SizesC n) nogil:
+    cdef ActivationsC A
+    memset(&A, 0, sizeof(A))
+    resize_activations(&A, n)
+    return A
+
+
+cdef void free_activations(const ActivationsC* A) nogil:
+    free(A.token_ids)
+    free(A.unmaxed)
+    free(A.hiddens)
+    free(A.is_valid)
+
+
+cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
+    if n.states <= A._max_size:
+        A._curr_size = n.states
+        return
+    if A._max_size == 0:
+        A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
+        A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
+        A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
+        A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
+        A._max_size = n.states
+    else:
+        A.token_ids = <int*>realloc(A.token_ids,
+            n.states * n.feats * sizeof(A.token_ids[0]))
+        A.unmaxed = <float*>realloc(A.unmaxed,
+            n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
+        A.hiddens = <float*>realloc(A.hiddens,
+            n.states * n.hiddens * sizeof(A.hiddens[0]))
+        A.is_valid = <int*>realloc(A.is_valid,
+            n.states * n.classes * sizeof(A.is_valid[0]))
+        A._max_size = n.states
+    A._curr_size = n.states
+
+
+cdef void predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
+    resize_activations(A, n)
+    for i in range(n.states):
+        states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
+    memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
+    sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
+    for i in range(n.states):
+        VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces],
+            W.feat_bias, 1., n.hiddens * n.pieces)
+        for j in range(n.hiddens):
+            index = i * n.hiddens * n.pieces + j * n.pieces
+            which = Vec.arg_max(&A.unmaxed[index], n.pieces)
+            A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
+    if W.hidden_weights == NULL:
+        memcpy(scores, A.hiddens, n.states * n.classes * sizeof(float))
+    else:
+        # Compute hidden-to-output
+        cblas.sgemm()(False, True, n.states, n.classes, n.hiddens,
+                      1.0, A.hiddens, n.hiddens,
+                      W.hidden_weights, n.hiddens,
+                      0.0, scores, n.classes)
+        # Add bias
+        for i in range(n.states):
+            VecVec.add_i(&scores[i*n.classes], W.hidden_bias, 1., n.classes)
+    # Set unseen classes to minimum value
+    i = 0
+    min_ = scores[0]
+    for i in range(1, n.states * n.classes):
+        if scores[i] < min_:
+            min_ = scores[i]
+    for i in range(n.states):
+        for j in range(n.classes):
+            if W.seen_mask[j]:
+                scores[i*n.classes+j] = min_
+
+
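+# feat_weights caches the hidden-layer contribution of every (token,
+# feature slot) pair, so scoring a state takes F saxpy calls into the
+# unmaxed buffer rather than a fresh matrix multiply. The row appended
+# after the T token rows acts as the padding vector, selected whenever a
+# context token id is < 0.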
+cdef void sum_state_features(CBlas cblas, float* output,
+        const float* cached, const int* token_ids, SizesC n) nogil:
+    cdef int idx, b, f, i
+    cdef const float* feature
+    cdef int B = n.states
+    cdef int O = n.hiddens * n.pieces
+    cdef int F = n.feats
+    cdef int T = n.tokens
+    padding = cached + (T * F * O)
+    cdef int id_stride = F*O
+    cdef float one = 1.
+    for b in range(B):
+        for f in range(F):
+            if token_ids[f] < 0:
+                feature = &padding[f*O]
+            else:
+                idx = token_ids[f] * id_stride + f*O
+                feature = &cached[idx]
+            cblas.saxpy()(O, one, feature, 1, &output[b*O], 1)
+        token_ids += F
diff --git a/spacy/pipeline/_parser_internals/transition_system.pxd b/spacy/pipeline/_parser_internals/transition_system.pxd
index 52ebd2b8e..d2bc0f781 100644
--- a/spacy/pipeline/_parser_internals/transition_system.pxd
+++ b/spacy/pipeline/_parser_internals/transition_system.pxd
@@ -53,3 +53,6 @@ cdef class TransitionSystem:
 
     cdef int set_costs(self, int* is_valid, weight_t* costs,
                        const StateC* state, gold) except -1
+
+cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
+                             int nr_class, int batch_size) nogil
diff --git a/spacy/tests/parser/test_add_label.py b/spacy/tests/parser/test_add_label.py
index 4c775a913..540b00f89 100644
--- a/spacy/tests/parser/test_add_label.py
+++ b/spacy/tests/parser/test_add_label.py
@@ -135,7 +135,6 @@ def test_ner_labels_added_implicitly_on_beam_parse():
     assert "D" in ner.labels
 
 
-@pytest.mark.skip(reason="greedy_parse is deprecated")
 def test_ner_labels_added_implicitly_on_greedy_parse():
     nlp = Language()
     ner = nlp.add_pipe("beam_ner")