spaCy/spacy/ml/parser_model.pyx

501 lines
18 KiB
Cython
Raw Normal View History

# cython: infer_types=True, cdivision=True, boundscheck=False
# cython: profile=False
cimport numpy as np
from libc.math cimport exp
from libc.stdlib cimport calloc, free, realloc
2023-12-08 22:24:09 +03:00
from libc.string cimport memcpy, memset
from thinc.backends.cblas cimport saxpy, sgemm
import numpy
import numpy.random
2023-12-08 22:24:09 +03:00
from thinc.api import CupyOps, Model, NumpyOps, get_ops
from .. import util
from ..errors import Errors
2023-12-08 22:24:09 +03:00
from ..pipeline._parser_internals.stateclass cimport StateClass
2023-12-18 22:02:15 +03:00
from ..typedefs cimport weight_t
cdef WeightsC get_c_weights(model) except *:
cdef WeightsC output
cdef precompute_hiddens state2vec = model.state2vec
output.feat_weights = state2vec.get_feat_weights()
output.feat_bias = <const float*>state2vec.bias.data
cdef np.ndarray vec2scores_W
cdef np.ndarray vec2scores_b
if model.vec2scores is None:
output.hidden_weights = NULL
output.hidden_bias = NULL
else:
vec2scores_W = model.vec2scores.get_param("W")
vec2scores_b = model.vec2scores.get_param("b")
output.hidden_weights = <const float*>vec2scores_W.data
output.hidden_bias = <const float*>vec2scores_b.data
cdef np.ndarray class_mask = model._class_mask
output.seen_classes = <const float*>class_mask.data
return output
cdef SizesC get_c_sizes(model, int batch_size) except *:
cdef SizesC output
output.states = batch_size
if model.vec2scores is None:
output.classes = model.state2vec.get_dim("nO")
else:
output.classes = model.vec2scores.get_dim("nO")
output.hiddens = model.state2vec.get_dim("nO")
output.pieces = model.state2vec.get_dim("nP")
output.feats = model.state2vec.get_dim("nF")
output.embed_width = model.tokvecs.shape[1]
return output
cdef ActivationsC alloc_activations(SizesC n) nogil:
cdef ActivationsC A
memset(&A, 0, sizeof(A))
resize_activations(&A, n)
return A
cdef void free_activations(const ActivationsC* A) nogil:
free(A.token_ids)
free(A.scores)
free(A.unmaxed)
free(A.hiddens)
free(A.is_valid)
cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
if n.states <= A._max_size:
A._curr_size = n.states
return
if A._max_size == 0:
A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
A.scores = <float*>calloc(n.states * n.classes, sizeof(A.scores[0]))
A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
A._max_size = n.states
else:
A.token_ids = <int*>realloc(A.token_ids,
2023-12-18 22:02:15 +03:00
n.states * n.feats * sizeof(A.token_ids[0]))
A.scores = <float*>realloc(A.scores,
2023-12-18 22:02:15 +03:00
n.states * n.classes * sizeof(A.scores[0]))
A.unmaxed = <float*>realloc(A.unmaxed,
2023-12-18 22:02:15 +03:00
n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
A.hiddens = <float*>realloc(A.hiddens,
2023-12-18 22:02:15 +03:00
n.states * n.hiddens * sizeof(A.hiddens[0]))
A.is_valid = <int*>realloc(A.is_valid,
2023-12-18 22:02:15 +03:00
n.states * n.classes * sizeof(A.is_valid[0]))
A._max_size = n.states
A._curr_size = n.states
cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
2023-12-18 22:02:15 +03:00
const WeightsC* W, SizesC n) nogil:
resize_activations(A, n)
for i in range(n.states):
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
memset(A.hiddens, 0, n.states * n.hiddens * sizeof(float))
2023-12-18 22:02:15 +03:00
sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n.states,
n.feats, n.hiddens * n.pieces)
for i in range(n.states):
2023-12-18 22:02:15 +03:00
saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1,
&A.unmaxed[i*n.hiddens*n.pieces], 1)
for j in range(n.hiddens):
index = i * n.hiddens * n.pieces + j * n.pieces
which = _arg_max(&A.unmaxed[index], n.pieces)
A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
memset(A.scores, 0, n.states * n.classes * sizeof(float))
if W.hidden_weights == NULL:
memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float))
else:
# Compute hidden-to-output
2023-12-18 22:02:15 +03:00
sgemm(cblas)(False, True, n.states, n.classes, n.hiddens, 1.0,
<const float *>A.hiddens, n.hiddens,
<const float *>W.hidden_weights, n.hiddens, 0.0,
A.scores, n.classes)
# Add bias
for i in range(n.states):
saxpy(cblas)(n.classes, 1., W.hidden_bias, 1, &A.scores[i*n.classes], 1)
# Set unseen classes to minimum value
i = 0
min_ = A.scores[0]
for i in range(1, n.states * n.classes):
if A.scores[i] < min_:
min_ = A.scores[i]
for i in range(n.states):
for j in range(n.classes):
if not W.seen_classes[j]:
A.scores[i*n.classes+j] = min_
2023-12-18 22:02:15 +03:00
cdef void sum_state_features(CBlas cblas, float* output, const float* cached,
const int* token_ids, int B, int F, int O) nogil:
cdef int idx, b, f
cdef const float* feature
padding = cached
cached += F * O
cdef int id_stride = F*O
cdef float one = 1.
for b in range(B):
for f in range(F):
if token_ids[f] < 0:
feature = &padding[f*O]
else:
idx = token_ids[f] * id_stride + f*O
feature = &cached[idx]
saxpy(cblas)(O, one, <const float*>feature, 1, &output[b*O], 1)
token_ids += F
2023-12-18 22:02:15 +03:00
cdef void cpu_log_loss(float* d_scores, const float* costs, const int* is_valid,
const float* scores, int O) nogil:
"""Do multi-label log loss"""
cdef double max_, gmax, Z, gZ
best = arg_max_if_gold(scores, costs, is_valid, O)
guess = _arg_max(scores, O)
if best == -1 or guess == -1:
# These shouldn't happen, but if they do, we want to make sure we don't
# cause an OOB access.
return
Z = 1e-10
gZ = 1e-10
max_ = scores[guess]
gmax = scores[best]
for i in range(O):
Z += exp(scores[i] - max_)
if costs[i] <= costs[best]:
gZ += exp(scores[i] - gmax)
for i in range(O):
if costs[i] <= costs[best]:
d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
else:
d_scores[i] = exp(scores[i]-max_) / Z
cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs,
2023-12-18 22:02:15 +03:00
const int* is_valid, int n) nogil:
# Find minimum cost
cdef float cost = 1
for i in range(n):
if is_valid[i] and costs[i] < cost:
cost = costs[i]
# Now find best-scoring with that cost
cdef int best = -1
for i in range(n):
if costs[i] <= cost and is_valid[i]:
if best == -1 or scores[i] > scores[best]:
best = i
return best
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
cdef int best = -1
for i in range(n):
if is_valid[i] >= 1:
if best == -1 or scores[i] > scores[best]:
best = i
return best
class ParserStepModel(Model):
def __init__(self, docs, layers, *, has_upper, unseen_classes=None, train=True,
2023-12-18 22:02:15 +03:00
dropout=0.1):
Model.__init__(self, name="parser_step_model", forward=step_forward)
self.attrs["has_upper"] = has_upper
self.attrs["dropout_rate"] = dropout
self.tokvecs, self.bp_tokvecs = layers[0](docs, is_train=train)
if layers[1].get_dim("nP") >= 2:
activation = "maxout"
elif has_upper:
activation = None
else:
activation = "relu"
self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
activation=activation, train=train)
if has_upper:
self.vec2scores = layers[-1]
else:
self.vec2scores = None
self.cuda_stream = util.get_cuda_stream(non_blocking=True)
self.backprops = []
self._class_mask = numpy.zeros((self.nO,), dtype='f')
self._class_mask.fill(1)
if unseen_classes is not None:
for class_ in unseen_classes:
self._class_mask[class_] = 0.
def clear_memory(self):
del self.tokvecs
del self.bp_tokvecs
del self.state2vec
del self.backprops
del self._class_mask
@property
def nO(self):
if self.attrs["has_upper"]:
return self.vec2scores.get_dim("nO")
else:
return self.state2vec.get_dim("nO")
def class_is_unseen(self, class_):
return self._class_mask[class_]
def mark_class_unseen(self, class_):
self._class_mask[class_] = 0
def mark_class_seen(self, class_):
self._class_mask[class_] = 1
def get_token_ids(self, states):
cdef StateClass state
states = [state for state in states if not state.is_final()]
cdef np.ndarray ids = numpy.zeros((len(states), self.state2vec.nF),
dtype='i', order='C')
ids.fill(-1)
c_ids = <int*>ids.data
for state in states:
state.c.set_context_tokens(c_ids, ids.shape[1])
c_ids += ids.shape[1]
return ids
def backprop_step(self, token_ids, d_vector, get_d_tokvecs):
if isinstance(self.state2vec.ops, CupyOps) \
2023-12-18 22:02:15 +03:00
and not isinstance(token_ids, self.state2vec.ops.xp.ndarray):
# Move token_ids and d_vector to GPU, asynchronously
self.backprops.append((
util.get_async(self.cuda_stream, token_ids),
util.get_async(self.cuda_stream, d_vector),
get_d_tokvecs
))
else:
self.backprops.append((token_ids, d_vector, get_d_tokvecs))
def finish_steps(self, golds):
# Add a padding vector to the d_tokvecs gradient, so that missing
# values don't affect the real gradient.
d_tokvecs = self.ops.alloc((self.tokvecs.shape[0]+1, self.tokvecs.shape[1]))
# Tells CUDA to block, so our async copies complete.
if self.cuda_stream is not None:
self.cuda_stream.synchronize()
for ids, d_vector, bp_vector in self.backprops:
d_state_features = bp_vector((d_vector, ids))
ids = ids.flatten()
d_state_features = d_state_features.reshape(
(ids.size, d_state_features.shape[2]))
2023-12-18 22:02:15 +03:00
self.ops.scatter_add(d_tokvecs, ids, d_state_features)
# Padded -- see update()
self.bp_tokvecs(d_tokvecs[:-1])
return d_tokvecs
2023-12-18 22:02:15 +03:00
NUMPY_OPS = NumpyOps()
2023-12-18 22:02:15 +03:00
def step_forward(model: ParserStepModel, states, is_train):
token_ids = model.get_token_ids(states)
vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
mask = None
if model.attrs["has_upper"]:
dropout_rate = model.attrs["dropout_rate"]
if is_train and dropout_rate > 0:
mask = NUMPY_OPS.get_dropout_mask(vector.shape, 0.1)
vector *= mask
scores, get_d_vector = model.vec2scores(vector, is_train)
else:
scores = NumpyOps().asarray(vector)
2023-12-18 22:02:15 +03:00
def get_d_vector(d_scores): return d_scores
# If the class is unseen, make sure its score is minimum
scores[:, model._class_mask == 0] = numpy.nanmin(scores)
def backprop_parser_step(d_scores):
# Zero vectors for unseen classes
d_scores *= model._class_mask
d_vector = get_d_vector(d_scores)
if mask is not None:
d_vector *= mask
model.backprop_step(token_ids, d_vector, get_d_tokvecs)
return None
return scores, backprop_parser_step
cdef class precompute_hiddens:
"""Allow a model to be "primed" by pre-computing input features in bulk.
This is used for the parser, where we want to take a batch of documents,
and compute vectors for each (token, position) pair. These vectors can then
be reused, especially for beam-search.
Let's say we're using 12 features for each state, e.g. word at start of
buffer, three words on stack, their children, etc. In the normal arc-eager
system, a document of length N is processed in 2*N states. This means we'll
create 2*N*12 feature vectors --- but if we pre-compute, we only need
N*12 vector computations. The saving for beam-search is much better:
if we have a beam of k, we'll normally make 2*N*12*K computations --
so we can save the factor k. This also gives a nice CPU/GPU division:
we can do all our hard maths up front, packed into large multiplications,
and do the hard-to-program parsing on the CPU.
"""
cdef readonly int nF, nO, nP
cdef bint _is_synchronized
cdef public object ops
cdef public object numpy_ops
cdef public object _cpu_ops
cdef np.ndarray _features
cdef np.ndarray _cached
cdef np.ndarray bias
cdef object _cuda_stream
cdef object _bp_hiddens
cdef object activation
def __init__(self, batch_size, tokvecs, lower_model, cuda_stream=None,
activation="maxout", train=False):
gpu_cached, bp_features = lower_model(tokvecs, train)
cdef np.ndarray cached
if not isinstance(gpu_cached, numpy.ndarray):
# Note the passing of cuda_stream here: it lets
# cupy make the copy asynchronously.
# We then have to block before first use.
cached = gpu_cached.get(stream=cuda_stream)
else:
cached = gpu_cached
if not isinstance(lower_model.get_param("b"), numpy.ndarray):
self.bias = lower_model.get_param("b").get(stream=cuda_stream)
else:
self.bias = lower_model.get_param("b")
self.nF = cached.shape[1]
if lower_model.has_dim("nP"):
self.nP = lower_model.get_dim("nP")
else:
self.nP = 1
self.nO = cached.shape[2]
self.ops = lower_model.ops
self.numpy_ops = NumpyOps()
self._cpu_ops = get_ops("cpu") if isinstance(self.ops, CupyOps) else self.ops
assert activation in (None, "relu", "maxout")
self.activation = activation
self._is_synchronized = False
self._cuda_stream = cuda_stream
self._cached = cached
self._bp_hiddens = bp_features
cdef const float* get_feat_weights(self) except NULL:
if not self._is_synchronized and self._cuda_stream is not None:
self._cuda_stream.synchronize()
self._is_synchronized = True
return <float*>self._cached.data
def has_dim(self, name):
if name == "nF":
return self.nF if self.nF is not None else True
elif name == "nP":
return self.nP if self.nP is not None else True
elif name == "nO":
return self.nO if self.nO is not None else True
else:
return False
def get_dim(self, name):
if name == "nF":
return self.nF
elif name == "nP":
return self.nP
elif name == "nO":
return self.nO
else:
raise ValueError(Errors.E1033.format(name=name))
def set_dim(self, name, value):
if name == "nF":
self.nF = value
elif name == "nP":
self.nP = value
elif name == "nO":
self.nO = value
else:
raise ValueError(Errors.E1033.format(name=name))
def __call__(self, X, bint is_train):
if is_train:
return self.begin_update(X)
else:
return self.predict(X), lambda X: X
def predict(self, X):
return self.begin_update(X)[0]
def begin_update(self, token_ids):
cdef np.ndarray state_vector = numpy.zeros(
(token_ids.shape[0], self.nO, self.nP), dtype='f')
# This is tricky, but (assuming GPU available);
# - Input to forward on CPU
# - Output from forward on CPU
# - Input to backward on GPU!
# - Output from backward on GPU
bp_hiddens = self._bp_hiddens
cdef CBlas cblas = self._cpu_ops.cblas()
feat_weights = self.get_feat_weights()
cdef int[:, ::1] ids = token_ids
sum_state_features(cblas, <float*>state_vector.data,
2023-12-18 22:02:15 +03:00
feat_weights, &ids[0, 0], token_ids.shape[0],
self.nF, self.nO*self.nP)
state_vector += self.bias
state_vector, bp_nonlinearity = self._nonlinearity(state_vector)
def backward(d_state_vector_ids):
d_state_vector, token_ids = d_state_vector_ids
d_state_vector = bp_nonlinearity(d_state_vector)
d_tokens = bp_hiddens((d_state_vector, token_ids))
return d_tokens
return state_vector, backward
def _nonlinearity(self, state_vector):
if self.activation == "maxout":
return self._maxout_nonlinearity(state_vector)
else:
return self._relu_nonlinearity(state_vector)
def _maxout_nonlinearity(self, state_vector):
state_vector, mask = self.numpy_ops.maxout(state_vector)
# We're outputting to CPU, but we need this variable on GPU for the
# backward pass.
mask = self.ops.asarray(mask)
def backprop_maxout(d_best):
return self.ops.backprop_maxout(d_best, mask, self.nP)
2023-12-18 22:02:15 +03:00
return state_vector, backprop_maxout
def _relu_nonlinearity(self, state_vector):
state_vector = state_vector.reshape((state_vector.shape[0], -1))
mask = state_vector >= 0.
state_vector *= mask
# We're outputting to CPU, but we need this variable on GPU for the
# backward pass.
mask = self.ops.asarray(mask)
def backprop_relu(d_best):
d_best *= mask
return d_best.reshape((d_best.shape + (1,)))
2023-12-18 22:02:15 +03:00
return state_vector, backprop_relu
cdef inline int _arg_max(const float* scores, const int n_classes) nogil:
if n_classes == 2:
return 0 if scores[0] > scores[1] else 1
cdef int i
cdef int best = 0
cdef float mode = scores[0]
for i in range(1, n_classes):
if scores[i] > mode:
mode = scores[i]
best = i
return best