mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 12:20:20 +03:00
Restore C CPU inference in the refactored parser (#10747)
* Bring back the C parsing model The C parsing model is used for CPU inference and is still faster for CPU inference than the forward pass of the Thinc model. * Use C sgemm provided by the Ops implementation * Make tb_framework module Cython, merge in C forward implementation * TransitionModel: raise in backprop returned from forward_cpu * Re-enable greedy parse test * Return transition scores when forward_cpu is used * Apply suggestions from code review Import `Model` from `thinc.api` Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Use relative imports in tb_framework * Don't assume a default for beam_width * We don't have a direct dependency on BLIS anymore * Rename forwards to _forward_{fallback,greedy_cpu} * Require thinc >=8.1.0,<8.2.0 * tb_framework: clean up imports * Fix return type of _get_seen_mask * Move up _forward_greedy_cpu * Style fixes. * Lower thinc lowerbound to 8.1.0.dev0 * Formatting fix Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
65c770c368
commit
aad38972cb
1
setup.py
1
setup.py
|
@ -31,6 +31,7 @@ MOD_NAMES = [
|
|||
"spacy.vocab",
|
||||
"spacy.attrs",
|
||||
"spacy.kb",
|
||||
"spacy.ml.tb_framework",
|
||||
"spacy.morphology",
|
||||
"spacy.pipeline._edit_tree_internals.edit_trees",
|
||||
"spacy.pipeline.morphologizer",
|
||||
|
|
|
@ -919,6 +919,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E1035 = ("Token index {i} out of bounds ({length})")
|
||||
E1036 = ("Cannot index into NoneNode")
|
||||
E1037 = ("Invalid attribute value '{attr}'.")
|
||||
E1038 = ("Backprop is not supported when is_train is not set.")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
28
spacy/ml/tb_framework.pxd
Normal file
28
spacy/ml/tb_framework.pxd
Normal file
|
@ -0,0 +1,28 @@
|
|||
from libc.stdint cimport int8_t
|
||||
|
||||
|
||||
cdef struct SizesC:
|
||||
int states
|
||||
int classes
|
||||
int hiddens
|
||||
int pieces
|
||||
int feats
|
||||
int embed_width
|
||||
int tokens
|
||||
|
||||
|
||||
cdef struct WeightsC:
|
||||
const float* feat_weights
|
||||
const float* feat_bias
|
||||
const float* hidden_bias
|
||||
const float* hidden_weights
|
||||
const int8_t* seen_mask
|
||||
|
||||
|
||||
cdef struct ActivationsC:
|
||||
int* token_ids
|
||||
float* unmaxed
|
||||
float* hiddens
|
||||
int* is_valid
|
||||
int _curr_size
|
||||
int _max_size
|
|
@ -1,15 +1,26 @@
|
|||
# cython: infer_types=True, cdivision=True, boundscheck=False
|
||||
from typing import List, Tuple, Any, Optional, cast
|
||||
from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
|
||||
from thinc.api import uniform_init, glorot_uniform_init, zero_init
|
||||
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
|
||||
from libc.string cimport memset, memcpy
|
||||
from libc.stdlib cimport calloc, free, realloc
|
||||
from libcpp.vector cimport vector
|
||||
import numpy
|
||||
cimport numpy as np
|
||||
from thinc.api import Model, normal_init, chain, list2array, Linear
|
||||
from thinc.api import uniform_init, glorot_uniform_init, zero_init
|
||||
from thinc.api import NumpyOps
|
||||
from thinc.backends.linalg cimport Vec, VecVec
|
||||
from thinc.backends.cblas cimport CBlas
|
||||
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
|
||||
|
||||
from ..errors import Errors
|
||||
from ..pipeline._parser_internals import _beam_utils
|
||||
from ..pipeline._parser_internals.batch import GreedyBatch
|
||||
from ..pipeline._parser_internals.transition_system cimport c_transition_batch, TransitionSystem
|
||||
from ..pipeline._parser_internals.stateclass cimport StateC, StateClass
|
||||
from ..tokens.doc import Doc
|
||||
from ..util import registry
|
||||
|
||||
|
||||
TransitionSystem = Any # TODO
|
||||
State = Any # TODO
|
||||
|
||||
|
||||
|
@ -131,29 +142,82 @@ def init(
|
|||
# model = _lsuv_init(model)
|
||||
return model
|
||||
|
||||
|
||||
def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
|
||||
nF = model.get_dim("nF")
|
||||
tok2vec = model.get_ref("tok2vec")
|
||||
lower_pad = model.get_param("lower_pad")
|
||||
lower_W = model.get_param("lower_W")
|
||||
lower_b = model.get_param("lower_b")
|
||||
upper_W = model.get_param("upper_W")
|
||||
upper_b = model.get_param("upper_b")
|
||||
nH = model.get_dim("nH")
|
||||
nP = model.get_dim("nP")
|
||||
nO = model.get_dim("nO")
|
||||
nI = model.get_dim("nI")
|
||||
|
||||
beam_width = model.attrs["beam_width"]
|
||||
beam_density = model.attrs["beam_density"]
|
||||
lower_pad = model.get_param("lower_pad")
|
||||
tok2vec = model.get_ref("tok2vec")
|
||||
|
||||
ops = model.ops
|
||||
docs, moves = docs_moves
|
||||
states = moves.init_batch(docs)
|
||||
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
|
||||
tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
|
||||
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
|
||||
seen_mask = _get_seen_mask(model)
|
||||
|
||||
if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
|
||||
return _forward_greedy_cpu(model, moves, states, feats, seen_mask)
|
||||
else:
|
||||
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train)
|
||||
|
||||
def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats,
|
||||
np.ndarray[np.npy_bool, ndim=1] seen_mask):
|
||||
cdef vector[StateC *] c_states
|
||||
cdef StateClass state
|
||||
for state in states:
|
||||
if not state.is_final():
|
||||
c_states.push_back(state.c)
|
||||
weights = get_c_weights(model, <float *> feats.data, seen_mask)
|
||||
# Precomputed features have rows for each token, plus one for padding.
|
||||
cdef int n_tokens = feats.shape[0] - 1
|
||||
sizes = get_c_sizes(model, c_states.size(), n_tokens)
|
||||
cdef CBlas cblas = model.ops.cblas()
|
||||
scores = _parseC(cblas, moves, &c_states[0], weights, sizes)
|
||||
|
||||
def backprop(dY):
|
||||
raise ValueError(Errors.E1038)
|
||||
|
||||
return (states, scores), backprop
|
||||
|
||||
cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
|
||||
WeightsC weights, SizesC sizes):
|
||||
cdef int i, j
|
||||
cdef vector[StateC *] unfinished
|
||||
cdef ActivationsC activations = alloc_activations(sizes)
|
||||
cdef np.ndarray step_scores
|
||||
|
||||
scores = []
|
||||
while sizes.states >= 1:
|
||||
step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
|
||||
with nogil:
|
||||
predict_states(cblas, &activations, <float *> step_scores.data, states, &weights, sizes)
|
||||
# Validate actions, argmax, take action.
|
||||
c_transition_batch(moves, states, <const float *> step_scores.data, sizes.classes,
|
||||
sizes.states)
|
||||
for i in range(sizes.states):
|
||||
if not states[i].is_final():
|
||||
unfinished.push_back(states[i])
|
||||
for i in range(unfinished.size()):
|
||||
states[i] = unfinished[i]
|
||||
sizes.states = unfinished.size()
|
||||
scores.append(step_scores)
|
||||
unfinished.clear()
|
||||
free_activations(&activations)
|
||||
|
||||
return scores
|
||||
|
||||
def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool):
|
||||
nF = model.get_dim("nF")
|
||||
lower_b = model.get_param("lower_b")
|
||||
upper_W = model.get_param("upper_W")
|
||||
upper_b = model.get_param("upper_b")
|
||||
nH = model.get_dim("nH")
|
||||
nP = model.get_dim("nP")
|
||||
|
||||
beam_width = model.attrs["beam_width"]
|
||||
beam_density = model.attrs["beam_density"]
|
||||
|
||||
ops = model.ops
|
||||
|
||||
all_ids = []
|
||||
all_which = []
|
||||
all_statevecs = []
|
||||
|
@ -164,8 +228,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
|
|||
batch = _beam_utils.BeamBatch(
|
||||
moves, states, None, width=beam_width, density=beam_density
|
||||
)
|
||||
seen_mask = _get_seen_mask(model)
|
||||
arange = model.ops.xp.arange(nF)
|
||||
arange = ops.xp.arange(nF)
|
||||
while not batch.is_done:
|
||||
ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i")
|
||||
for i, state in enumerate(batch.get_unfinished_states()):
|
||||
|
@ -174,16 +237,16 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
|
|||
# to create the state vectors.
|
||||
preacts2f = feats[ids, arange].sum(axis=1) # type: ignore
|
||||
preacts2f += lower_b
|
||||
preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
|
||||
preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
|
||||
assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
|
||||
statevecs, which = ops.maxout(preacts)
|
||||
# Multiply the state-vector by the scores weights and add the bias,
|
||||
# to get the logits.
|
||||
scores = model.ops.gemm(statevecs, upper_W, trans2=True)
|
||||
scores = ops.gemm(statevecs, upper_W, trans2=True)
|
||||
scores += upper_b
|
||||
scores[:, seen_mask] = model.ops.xp.nanmin(scores)
|
||||
scores[:, seen_mask] = ops.xp.nanmin(scores)
|
||||
# Transition the states, filtering out any that are finished.
|
||||
cpu_scores = model.ops.to_numpy(scores)
|
||||
cpu_scores = ops.to_numpy(scores)
|
||||
batch.advance(cpu_scores)
|
||||
all_scores.append(scores)
|
||||
if is_train:
|
||||
|
@ -193,10 +256,9 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
|
|||
all_which.append(which)
|
||||
|
||||
def backprop_parser(d_states_d_scores):
|
||||
d_tokvecs = model.ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
|
||||
ids = model.ops.xp.vstack(all_ids)
|
||||
ids = ops.xp.vstack(all_ids)
|
||||
which = ops.xp.vstack(all_which)
|
||||
statevecs = model.ops.xp.vstack(all_statevecs)
|
||||
statevecs = ops.xp.vstack(all_statevecs)
|
||||
_, d_scores = d_states_d_scores
|
||||
if model.attrs.get("unseen_classes"):
|
||||
# If we have a negative gradient (i.e. the probability should
|
||||
|
@ -209,18 +271,18 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
|
|||
# Calculate the gradients for the parameters of the upper layer.
|
||||
# The weight gemm is (nS, nO) @ (nS, nH).T
|
||||
model.inc_grad("upper_b", d_scores.sum(axis=0))
|
||||
model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
|
||||
model.inc_grad("upper_W", ops.gemm(d_scores, statevecs, trans1=True))
|
||||
# Now calculate d_statevecs, by backproping through the upper linear layer.
|
||||
# This gemm is (nS, nO) @ (nO, nH)
|
||||
d_statevecs = model.ops.gemm(d_scores, upper_W)
|
||||
d_statevecs = ops.gemm(d_scores, upper_W)
|
||||
# Backprop through the maxout activation
|
||||
d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP)
|
||||
d_preacts2f = model.ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
|
||||
d_preacts = ops.backprop_maxout(d_statevecs, which, nP)
|
||||
d_preacts2f = ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
|
||||
model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
|
||||
# We don't need to backprop the summation, because we pass back the IDs instead
|
||||
d_state_features = backprop_feats((d_preacts2f, ids))
|
||||
d_tokvecs = model.ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
|
||||
model.ops.scatter_add(d_tokvecs, ids, d_state_features)
|
||||
d_tokvecs = ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
|
||||
ops.scatter_add(d_tokvecs, ids, d_state_features)
|
||||
model.inc_grad("lower_pad", d_tokvecs[-1])
|
||||
return (backprop_tok2vec(d_tokvecs[:-1]), None)
|
||||
|
||||
|
@ -328,7 +390,7 @@ def _forward_reference(
|
|||
return (states, all_scores), backprop_parser
|
||||
|
||||
|
||||
def _get_seen_mask(model: Model) -> Floats1d:
|
||||
def _get_seen_mask(model: Model) -> numpy.array[bool, 1]:
|
||||
mask = model.ops.xp.zeros(model.get_dim("nO"), dtype="bool")
|
||||
for class_ in model.attrs.get("unseen_classes", set()):
|
||||
mask[class_] = True
|
||||
|
@ -449,3 +511,125 @@ def _lsuv_init(model: Model):
|
|||
else:
|
||||
break
|
||||
return model
|
||||
|
||||
|
||||
cdef WeightsC get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
|
||||
cdef np.ndarray lower_b = model.get_param("lower_b")
|
||||
cdef np.ndarray upper_W = model.get_param("upper_W")
|
||||
cdef np.ndarray upper_b = model.get_param("upper_b")
|
||||
|
||||
cdef WeightsC output
|
||||
output.feat_weights = feats
|
||||
output.feat_bias = <const float*>lower_b.data
|
||||
output.hidden_weights = <const float *> upper_W.data
|
||||
output.hidden_bias = <const float *> upper_b.data
|
||||
output.seen_mask = <const int8_t*> seen_mask.data
|
||||
|
||||
return output
|
||||
|
||||
|
||||
cdef SizesC get_c_sizes(model, int batch_size, int tokens) except *:
|
||||
cdef SizesC output
|
||||
output.states = batch_size
|
||||
output.classes = model.get_dim("nO")
|
||||
output.hiddens = model.get_dim("nH")
|
||||
output.pieces = model.get_dim("nP")
|
||||
output.feats = model.get_dim("nF")
|
||||
output.embed_width = model.get_dim("nI")
|
||||
output.tokens = tokens
|
||||
return output
|
||||
|
||||
|
||||
cdef ActivationsC alloc_activations(SizesC n) nogil:
|
||||
cdef ActivationsC A
|
||||
memset(&A, 0, sizeof(A))
|
||||
resize_activations(&A, n)
|
||||
return A
|
||||
|
||||
|
||||
cdef void free_activations(const ActivationsC* A) nogil:
|
||||
free(A.token_ids)
|
||||
free(A.unmaxed)
|
||||
free(A.hiddens)
|
||||
free(A.is_valid)
|
||||
|
||||
|
||||
cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
|
||||
if n.states <= A._max_size:
|
||||
A._curr_size = n.states
|
||||
return
|
||||
if A._max_size == 0:
|
||||
A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
|
||||
A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
|
||||
A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
|
||||
A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
|
||||
A._max_size = n.states
|
||||
else:
|
||||
A.token_ids = <int*>realloc(A.token_ids,
|
||||
n.states * n.feats * sizeof(A.token_ids[0]))
|
||||
A.unmaxed = <float*>realloc(A.unmaxed,
|
||||
n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
|
||||
A.hiddens = <float*>realloc(A.hiddens,
|
||||
n.states * n.hiddens * sizeof(A.hiddens[0]))
|
||||
A.is_valid = <int*>realloc(A.is_valid,
|
||||
n.states * n.classes * sizeof(A.is_valid[0]))
|
||||
A._max_size = n.states
|
||||
A._curr_size = n.states
|
||||
|
||||
|
||||
cdef void predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
|
||||
resize_activations(A, n)
|
||||
for i in range(n.states):
|
||||
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
|
||||
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
|
||||
sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
|
||||
for i in range(n.states):
|
||||
VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces],
|
||||
W.feat_bias, 1., n.hiddens * n.pieces)
|
||||
for j in range(n.hiddens):
|
||||
index = i * n.hiddens * n.pieces + j * n.pieces
|
||||
which = Vec.arg_max(&A.unmaxed[index], n.pieces)
|
||||
A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
|
||||
if W.hidden_weights == NULL:
|
||||
memcpy(scores, A.hiddens, n.states * n.classes * sizeof(float))
|
||||
else:
|
||||
# Compute hidden-to-output
|
||||
cblas.sgemm()(False, True, n.states, n.classes, n.hiddens,
|
||||
1.0, <const float *>A.hiddens, n.hiddens,
|
||||
<const float *>W.hidden_weights, n.hiddens,
|
||||
0.0, scores, n.classes)
|
||||
# Add bias
|
||||
for i in range(n.states):
|
||||
VecVec.add_i(&scores[i*n.classes], W.hidden_bias, 1., n.classes)
|
||||
# Set unseen classes to minimum value
|
||||
i = 0
|
||||
min_ = scores[0]
|
||||
for i in range(1, n.states * n.classes):
|
||||
if scores[i] < min_:
|
||||
min_ = scores[i]
|
||||
for i in range(n.states):
|
||||
for j in range(n.classes):
|
||||
if W.seen_mask[j]:
|
||||
scores[i*n.classes+j] = min_
|
||||
|
||||
|
||||
cdef void sum_state_features(CBlas cblas, float* output,
|
||||
const float* cached, const int* token_ids, SizesC n) nogil:
|
||||
cdef int idx, b, f, i
|
||||
cdef const float* feature
|
||||
cdef int B = n.states
|
||||
cdef int O = n.hiddens * n.pieces
|
||||
cdef int F = n.feats
|
||||
cdef int T = n.tokens
|
||||
padding = cached + (T * F * O)
|
||||
cdef int id_stride = F*O
|
||||
cdef float one = 1.
|
||||
for b in range(B):
|
||||
for f in range(F):
|
||||
if token_ids[f] < 0:
|
||||
feature = &padding[f*O]
|
||||
else:
|
||||
idx = token_ids[f] * id_stride + f*O
|
||||
feature = &cached[idx]
|
||||
cblas.saxpy()(O, one, <const float*>feature, 1, &output[b*O], 1)
|
||||
token_ids += F
|
|
@ -53,3 +53,6 @@ cdef class TransitionSystem:
|
|||
|
||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||
const StateC* state, gold) except -1
|
||||
|
||||
cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
|
||||
int nr_class, int batch_size) nogil
|
||||
|
|
|
@ -135,7 +135,6 @@ def test_ner_labels_added_implicitly_on_beam_parse():
|
|||
assert "D" in ner.labels
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="greedy_parse is deprecated")
|
||||
def test_ner_labels_added_implicitly_on_greedy_parse():
|
||||
nlp = Language()
|
||||
ner = nlp.add_pipe("beam_ner")
|
||||
|
|
Loading…
Reference in New Issue
Block a user