Restore C CPU inference in the refactored parser (#10747)

* Bring back the C parsing model

The C parsing model is used for CPU inference and is still faster than the
forward pass of the Thinc model for inference on CPU.
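
A sketch of the dispatch this reintroduces (simplified from the
tb_framework.pyx diff below; all names appear in that diff):

    from thinc.api import NumpyOps

    # Greedy (beam width 1) CPU inference takes the C path; training and
    # beam search fall back to the Thinc forward pass.
    if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
        return _forward_greedy_cpu(model, moves, states, feats, seen_mask)
    else:
        return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec,
                                 feats, backprop_feats, seen_mask, is_train)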

* Use C sgemm provided by the Ops implementation
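
In numpy terms, the sgemm call in predict_states computes the hidden-to-output
product (a sketch of the semantics, not the actual call; shapes follow the
diff below, and the bias is added separately, row by row):

    # hiddens: (states, hiddens); hidden_weights: (classes, hiddens)
    # sgemm(False, True, ...) multiplies with the second operand transposed.
    scores = hiddens @ hidden_weights.T  # (states, classes)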

* Make tb_framework module Cython, merge in C forward implementation

* TransitionModel: raise in backprop returned from forward_cpu

* Re-enable greedy parse test

* Return transition scores when forward_cpu is used
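
In use (a sketch following Thinc's Model.__call__ convention, assuming docs is
a list of Doc objects and moves is the parser's TransitionSystem):

    # The greedy CPU path returns the per-step transition scores alongside
    # the final states; its backprop callback only raises E1038.
    (states, scores), backprop = model((docs, moves), is_train=False)
    backprop(None)  # raises ValueError: backprop needs is_train to be set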

* Apply suggestions from code review

Import `Model` from `thinc.api`

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Use relative imports in tb_framework

* Don't assume a default for beam_width

* We don't have a direct dependency on BLIS anymore

* Rename forwards to _forward_{fallback,greedy_cpu}

* Require thinc >=8.1.0,<8.2.0

* tb_framework: clean up imports

* Fix return type of _get_seen_mask

* Move up _forward_greedy_cpu

* Style fixes.

* Lower the thinc lower bound to 8.1.0.dev0

* Formatting fix

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
Daniël de Kok, 2022-05-30 12:16:28 +02:00, committed by GitHub
parent 65c770c368
commit aad38972cb
6 changed files with 252 additions and 36 deletions

setup.py

@@ -31,6 +31,7 @@ MOD_NAMES = [
    "spacy.vocab",
    "spacy.attrs",
    "spacy.kb",
    "spacy.ml.tb_framework",
    "spacy.morphology",
    "spacy.pipeline._edit_tree_internals.edit_trees",
    "spacy.pipeline.morphologizer",

spacy/errors.py

@@ -919,6 +919,7 @@ class Errors(metaclass=ErrorsWithCodes):
    E1035 = ("Token index {i} out of bounds ({length})")
    E1036 = ("Cannot index into NoneNode")
    E1037 = ("Invalid attribute value '{attr}'.")
    E1038 = ("Backprop is not supported when is_train is not set.")
    # Deprecated model shortcuts, only used in errors and warnings

spacy/ml/tb_framework.pxd (new file)

@@ -0,0 +1,28 @@
from libc.stdint cimport int8_t


cdef struct SizesC:
    int states
    int classes
    int hiddens
    int pieces
    int feats
    int embed_width
    int tokens


cdef struct WeightsC:
    const float* feat_weights
    const float* feat_bias
    const float* hidden_bias
    const float* hidden_weights
    const int8_t* seen_mask


cdef struct ActivationsC:
    int* token_ids
    float* unmaxed
    float* hiddens
    int* is_valid
    int _curr_size
    int _max_size
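
The ActivationsC buffers are per-step scratch space sized from SizesC; in
numpy terms, the reserved shapes are roughly (a sketch of the equivalent
allocations, not the C code itself):

    import numpy as np

    def alloc_activations_py(n):
        # Mirrors what alloc_activations/resize_activations reserve in C.
        return {
            "token_ids": np.zeros((n["states"], n["feats"]), dtype=np.int32),
            "unmaxed": np.zeros((n["states"], n["hiddens"], n["pieces"]), dtype=np.float32),
            "hiddens": np.zeros((n["states"], n["hiddens"]), dtype=np.float32),
            "is_valid": np.zeros((n["states"], n["classes"]), dtype=np.int32),
        }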

spacy/ml/tb_framework.pyx

@@ -1,15 +1,26 @@
# cython: infer_types=True, cdivision=True, boundscheck=False
from typing import List, Tuple, Any, Optional, cast
from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
from thinc.api import uniform_init, glorot_uniform_init, zero_init
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free, realloc
from libcpp.vector cimport vector
import numpy
cimport numpy as np
from thinc.api import Model, normal_init, chain, list2array, Linear
from thinc.api import uniform_init, glorot_uniform_init, zero_init
from thinc.api import NumpyOps
from thinc.backends.linalg cimport Vec, VecVec
from thinc.backends.cblas cimport CBlas
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d

from ..errors import Errors
from ..pipeline._parser_internals import _beam_utils
from ..pipeline._parser_internals.batch import GreedyBatch
from ..pipeline._parser_internals.transition_system cimport c_transition_batch, TransitionSystem
from ..pipeline._parser_internals.stateclass cimport StateC, StateClass
from ..tokens.doc import Doc
from ..util import registry

TransitionSystem = Any  # TODO
State = Any  # TODO
@@ -131,29 +142,82 @@ def init(
    # model = _lsuv_init(model)
    return model


def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
    nF = model.get_dim("nF")
    tok2vec = model.get_ref("tok2vec")
    lower_pad = model.get_param("lower_pad")
    lower_W = model.get_param("lower_W")
    lower_b = model.get_param("lower_b")
    upper_W = model.get_param("upper_W")
    upper_b = model.get_param("upper_b")
    nH = model.get_dim("nH")
    nP = model.get_dim("nP")
    nO = model.get_dim("nO")
    nI = model.get_dim("nI")
    beam_width = model.attrs["beam_width"]
    beam_density = model.attrs["beam_density"]
    lower_pad = model.get_param("lower_pad")
    tok2vec = model.get_ref("tok2vec")
    ops = model.ops
    docs, moves = docs_moves
    states = moves.init_batch(docs)
    tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
    tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
    feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
    seen_mask = _get_seen_mask(model)

    if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
        return _forward_greedy_cpu(model, moves, states, feats, seen_mask)
    else:
        return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train)


def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats,
                        np.ndarray[np.npy_bool, ndim=1] seen_mask):
    cdef vector[StateC *] c_states
    cdef StateClass state
    for state in states:
        if not state.is_final():
            c_states.push_back(state.c)
    weights = get_c_weights(model, <float *> feats.data, seen_mask)
    # Precomputed features have rows for each token, plus one for padding.
    cdef int n_tokens = feats.shape[0] - 1
    sizes = get_c_sizes(model, c_states.size(), n_tokens)
    cdef CBlas cblas = model.ops.cblas()
    scores = _parseC(cblas, moves, &c_states[0], weights, sizes)

    def backprop(dY):
        raise ValueError(Errors.E1038)

    return (states, scores), backprop


cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
                  WeightsC weights, SizesC sizes):
    cdef int i, j
    cdef vector[StateC *] unfinished
    cdef ActivationsC activations = alloc_activations(sizes)
    cdef np.ndarray step_scores

    scores = []
    while sizes.states >= 1:
        step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
        with nogil:
            predict_states(cblas, &activations, <float *> step_scores.data, states, &weights, sizes)
            # Validate actions, argmax, take action.
            c_transition_batch(moves, states, <const float *> step_scores.data, sizes.classes,
                               sizes.states)
            for i in range(sizes.states):
                if not states[i].is_final():
                    unfinished.push_back(states[i])
            for i in range(unfinished.size()):
                states[i] = unfinished[i]
            sizes.states = unfinished.size()
        scores.append(step_scores)
        unfinished.clear()
    free_activations(&activations)

    return scores
def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool):
    nF = model.get_dim("nF")
    lower_b = model.get_param("lower_b")
    upper_W = model.get_param("upper_W")
    upper_b = model.get_param("upper_b")
    nH = model.get_dim("nH")
    nP = model.get_dim("nP")
    beam_width = model.attrs["beam_width"]
    beam_density = model.attrs["beam_density"]
    ops = model.ops
    all_ids = []
    all_which = []
    all_statevecs = []
@@ -164,8 +228,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
        batch = _beam_utils.BeamBatch(
            moves, states, None, width=beam_width, density=beam_density
        )
    seen_mask = _get_seen_mask(model)
    arange = model.ops.xp.arange(nF)
    arange = ops.xp.arange(nF)
    while not batch.is_done:
        ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i")
        for i, state in enumerate(batch.get_unfinished_states()):
@@ -174,16 +237,16 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
        # to create the state vectors.
        preacts2f = feats[ids, arange].sum(axis=1)  # type: ignore
        preacts2f += lower_b
        preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
        preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
        assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
        statevecs, which = ops.maxout(preacts)
        # Multiply the state-vector by the scores weights and add the bias,
        # to get the logits.
        scores = model.ops.gemm(statevecs, upper_W, trans2=True)
        scores = ops.gemm(statevecs, upper_W, trans2=True)
        scores += upper_b
        scores[:, seen_mask] = model.ops.xp.nanmin(scores)
        scores[:, seen_mask] = ops.xp.nanmin(scores)
        # Transition the states, filtering out any that are finished.
        cpu_scores = model.ops.to_numpy(scores)
        cpu_scores = ops.to_numpy(scores)
        batch.advance(cpu_scores)
        all_scores.append(scores)
        if is_train:
@@ -193,10 +256,9 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
            all_which.append(which)

    def backprop_parser(d_states_d_scores):
        d_tokvecs = model.ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
        ids = model.ops.xp.vstack(all_ids)
        ids = ops.xp.vstack(all_ids)
        which = ops.xp.vstack(all_which)
        statevecs = model.ops.xp.vstack(all_statevecs)
        statevecs = ops.xp.vstack(all_statevecs)
        _, d_scores = d_states_d_scores
        if model.attrs.get("unseen_classes"):
            # If we have a negative gradient (i.e. the probability should
@@ -209,18 +271,18 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
        # Calculate the gradients for the parameters of the upper layer.
        # The weight gemm is (nS, nO) @ (nS, nH).T
        model.inc_grad("upper_b", d_scores.sum(axis=0))
        model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
        model.inc_grad("upper_W", ops.gemm(d_scores, statevecs, trans1=True))
        # Now calculate d_statevecs, by backproping through the upper linear layer.
        # This gemm is (nS, nO) @ (nO, nH)
        d_statevecs = model.ops.gemm(d_scores, upper_W)
        d_statevecs = ops.gemm(d_scores, upper_W)
        # Backprop through the maxout activation
        d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP)
        d_preacts2f = model.ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
        d_preacts = ops.backprop_maxout(d_statevecs, which, nP)
        d_preacts2f = ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
        model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
        # We don't need to backprop the summation, because we pass back the IDs instead
        d_state_features = backprop_feats((d_preacts2f, ids))
        d_tokvecs = model.ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
        model.ops.scatter_add(d_tokvecs, ids, d_state_features)
        d_tokvecs = ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
        ops.scatter_add(d_tokvecs, ids, d_state_features)
        model.inc_grad("lower_pad", d_tokvecs[-1])
        return (backprop_tok2vec(d_tokvecs[:-1]), None)
@@ -328,7 +390,7 @@ def _forward_reference(
    return (states, all_scores), backprop_parser


def _get_seen_mask(model: Model) -> Floats1d:
def _get_seen_mask(model: Model) -> numpy.array[bool, 1]:
    mask = model.ops.xp.zeros(model.get_dim("nO"), dtype="bool")
    for class_ in model.attrs.get("unseen_classes", set()):
        mask[class_] = True
@@ -449,3 +511,125 @@ def _lsuv_init(model: Model):
        else:
            break
    return model
cdef WeightsC get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
    cdef np.ndarray lower_b = model.get_param("lower_b")
    cdef np.ndarray upper_W = model.get_param("upper_W")
    cdef np.ndarray upper_b = model.get_param("upper_b")

    cdef WeightsC output
    output.feat_weights = feats
    output.feat_bias = <const float*>lower_b.data
    output.hidden_weights = <const float *> upper_W.data
    output.hidden_bias = <const float *> upper_b.data
    output.seen_mask = <const int8_t*> seen_mask.data

    return output


cdef SizesC get_c_sizes(model, int batch_size, int tokens) except *:
    cdef SizesC output
    output.states = batch_size
    output.classes = model.get_dim("nO")
    output.hiddens = model.get_dim("nH")
    output.pieces = model.get_dim("nP")
    output.feats = model.get_dim("nF")
    output.embed_width = model.get_dim("nI")
    output.tokens = tokens
    return output


cdef ActivationsC alloc_activations(SizesC n) nogil:
    cdef ActivationsC A
    memset(&A, 0, sizeof(A))
    resize_activations(&A, n)
    return A


cdef void free_activations(const ActivationsC* A) nogil:
    free(A.token_ids)
    free(A.unmaxed)
    free(A.hiddens)
    free(A.is_valid)


cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
    if n.states <= A._max_size:
        A._curr_size = n.states
        return
    if A._max_size == 0:
        A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
        A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
        A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
        A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
        A._max_size = n.states
    else:
        A.token_ids = <int*>realloc(A.token_ids,
            n.states * n.feats * sizeof(A.token_ids[0]))
        A.unmaxed = <float*>realloc(A.unmaxed,
            n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
        A.hiddens = <float*>realloc(A.hiddens,
            n.states * n.hiddens * sizeof(A.hiddens[0]))
        A.is_valid = <int*>realloc(A.is_valid,
            n.states * n.classes * sizeof(A.is_valid[0]))
        A._max_size = n.states
    A._curr_size = n.states


cdef void predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
    resize_activations(A, n)
    for i in range(n.states):
        states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
    memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
    sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
    for i in range(n.states):
        VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces],
            W.feat_bias, 1., n.hiddens * n.pieces)
        for j in range(n.hiddens):
            index = i * n.hiddens * n.pieces + j * n.pieces
            which = Vec.arg_max(&A.unmaxed[index], n.pieces)
            A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
    if W.hidden_weights == NULL:
        memcpy(scores, A.hiddens, n.states * n.classes * sizeof(float))
    else:
        # Compute hidden-to-output
        cblas.sgemm()(False, True, n.states, n.classes, n.hiddens,
                      1.0, <const float *>A.hiddens, n.hiddens,
                      <const float *>W.hidden_weights, n.hiddens,
                      0.0, scores, n.classes)
        # Add bias
        for i in range(n.states):
            VecVec.add_i(&scores[i*n.classes], W.hidden_bias, 1., n.classes)
    # Set unseen classes to minimum value
    i = 0
    min_ = scores[0]
    for i in range(1, n.states * n.classes):
        if scores[i] < min_:
            min_ = scores[i]
    for i in range(n.states):
        for j in range(n.classes):
            if W.seen_mask[j]:
                scores[i*n.classes+j] = min_


cdef void sum_state_features(CBlas cblas, float* output,
        const float* cached, const int* token_ids, SizesC n) nogil:
    cdef int idx, b, f, i
    cdef const float* feature
    cdef int B = n.states
    cdef int O = n.hiddens * n.pieces
    cdef int F = n.feats
    cdef int T = n.tokens

    padding = cached + (T * F * O)
    cdef int id_stride = F*O
    cdef float one = 1.
    for b in range(B):
        for f in range(F):
            if token_ids[f] < 0:
                feature = &padding[f*O]
            else:
                idx = token_ids[f] * id_stride + f*O
                feature = &cached[idx]
            cblas.saxpy()(O, one, <const float*>feature, 1, &output[b*O], 1)
        token_ids += F
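
Stripped of the manual memory management and the nogil block, _parseC above
is a plain greedy loop; in Python it amounts to roughly this (a sketch, with
predict and transition standing in for predict_states and c_transition_batch):

    def parse_greedy(predict, transition, states):
        scores = []
        while states:
            step_scores = predict(states)      # one score row per state
            transition(states, step_scores)    # argmax valid action, apply it
            states = [s for s in states if not s.is_final()]
            scores.append(step_scores)
        return scores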

spacy/pipeline/_parser_internals/transition_system.pxd

@@ -53,3 +53,6 @@ cdef class TransitionSystem:
    cdef int set_costs(self, int* is_valid, weight_t* costs,
                       const StateC* state, gold) except -1


cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
                             int nr_class, int batch_size) nogil
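
Per state, c_transition_batch implements the "validate actions, argmax, take
action" step from _parseC; conceptually it does the following (a Python
sketch; get_valid_mask and apply are hypothetical helper names):

    import numpy as np

    def transition_batch_py(moves, states, scores):
        for state, row in zip(states, scores):
            is_valid = moves.get_valid_mask(state)     # hypothetical helper
            masked = np.where(is_valid, row, -np.inf)  # invalid actions can't win
            best = int(masked.argmax())
            moves.apply(state, best)                   # hypothetical: advance the state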

spacy/tests/parser/test_ner.py

@@ -135,7 +135,6 @@ def test_ner_labels_added_implicitly_on_beam_parse():
    assert "D" in ner.labels


@pytest.mark.skip(reason="greedy_parse is deprecated")
def test_ner_labels_added_implicitly_on_greedy_parse():
    nlp = Language()
    ner = nlp.add_pipe("beam_ner")