2020-03-02 13:48:10 +03:00
|
|
|
# cython: infer_types=True, cdivision=True, boundscheck=False
|
2018-05-15 23:17:29 +03:00
|
|
|
cimport numpy as np
|
|
|
|
from libc.math cimport exp
|
|
|
|
from libc.stdlib cimport calloc, free, realloc
|
2023-06-14 18:48:41 +03:00
|
|
|
from libc.string cimport memcpy, memset
|
2022-06-16 12:42:34 +03:00
|
|
|
from thinc.backends.cblas cimport saxpy, sgemm
|
2023-06-14 18:48:41 +03:00
|
|
|
from thinc.backends.linalg cimport Vec, VecVec
|
2018-05-15 23:17:29 +03:00
|
|
|
|
2020-03-02 13:48:10 +03:00
|
|
|
import numpy
|
|
|
|
import numpy.random
|
2023-07-19 13:03:31 +03:00
|
|
|
from thinc.api import CupyOps, Model, NumpyOps
|
2020-03-02 13:48:10 +03:00
|
|
|
|
|
|
|
from .. import util
|
2022-05-02 14:38:46 +03:00
|
|
|
from ..errors import Errors
|
2023-06-14 18:48:41 +03:00
|
|
|
|
2020-07-31 00:30:54 +03:00
|
|
|
from ..pipeline._parser_internals.stateclass cimport StateClass
|
2023-07-19 13:03:31 +03:00
|
|
|
from ..typedefs cimport weight_t
|
2018-05-15 23:17:29 +03:00
|
|
|
|
|
|
|
|
|
|
|
cdef WeightsC get_c_weights(model) except *:
    """Gather raw C pointers to the weights of a ParserStepModel.

    The pointers are taken from the `.data` buffers of numpy arrays held by
    `model`, so the returned struct is only valid while those arrays live.
    `hidden_weights`/`hidden_bias` are NULL when the model has no upper layer.
    """
    cdef WeightsC weights
    cdef precompute_hiddens lower = model.state2vec
    cdef np.ndarray upper_W
    cdef np.ndarray upper_b
    cdef np.ndarray mask
    # Precomputed per-(token, feature) contributions of the lower layer.
    weights.feat_weights = lower.get_feat_weights()
    weights.feat_bias = <const float*>lower.bias.data
    upper = model.vec2scores
    if upper is not None:
        upper_W = upper.get_param("W")
        upper_b = upper.get_param("b")
        weights.hidden_weights = <const float*>upper_W.data
        weights.hidden_bias = <const float*>upper_b.data
    else:
        weights.hidden_weights = NULL
        weights.hidden_bias = NULL
    # Per-class mask: 1.0 for seen classes, 0.0 for unseen ones.
    mask = model._class_mask
    weights.seen_classes = <const float*>mask.data
    return weights
|
|
|
|
|
|
|
|
|
|
|
|
cdef SizesC get_c_sizes(model, int batch_size) except *:
    """Describe the layer dimensions of `model` for a batch of `batch_size`.

    `classes` is the output width of the upper layer (vec2scores), or of the
    lower layer (state2vec) when there is no upper layer.
    """
    cdef SizesC sizes
    lower = model.state2vec
    upper = model.vec2scores
    sizes.states = batch_size
    sizes.classes = (lower if upper is None else upper).get_dim("nO")
    sizes.hiddens = lower.get_dim("nO")
    sizes.pieces = lower.get_dim("nP")
    sizes.feats = lower.get_dim("nF")
    sizes.embed_width = model.tokvecs.shape[1]
    return sizes
|
|
|
|
|
|
|
|
|
2019-10-22 16:06:44 +03:00
|
|
|
cdef ActivationsC alloc_activations(SizesC n) nogil:
    """Create a fresh activation struct sized for `n`.

    Zeroing first leaves every buffer pointer NULL and `_max_size` at 0, so
    resize_activations takes its calloc path (when n.states > 0).
    """
    cdef ActivationsC acts
    memset(&acts, 0, sizeof(ActivationsC))
    resize_activations(&acts, n)
    return acts
|
|
|
|
|
|
|
|
|
|
|
|
cdef void free_activations(const ActivationsC* A) nogil:
    """Release every buffer owned by `A`.

    The pointers themselves are not reset (the struct is const here), so `A`
    must not be used again afterwards.
    """
    free(A.is_valid)
    free(A.hiddens)
    free(A.unmaxed)
    free(A.scores)
    free(A.token_ids)
|
|
|
|
|
|
|
|
|
2018-05-15 23:17:29 +03:00
|
|
|
cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
    """Make sure `A` has room for `n.states` states.

    Buffers only ever grow. Capacity is tracked by state count alone, so the
    other dimensions in `n` are assumed unchanged between calls.
    NOTE(review): calloc/realloc results are not checked for NULL.
    """
    if n.states > A._max_size:
        if A._max_size == 0:
            # First allocation: zero-initialised buffers.
            A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
            A.scores = <float*>calloc(n.states * n.classes, sizeof(A.scores[0]))
            A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
            A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
            A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
        else:
            # Growth: the added tail of each buffer is left uninitialised.
            A.token_ids = <int*>realloc(
                A.token_ids, n.states * n.feats * sizeof(A.token_ids[0])
            )
            A.scores = <float*>realloc(
                A.scores, n.states * n.classes * sizeof(A.scores[0])
            )
            A.unmaxed = <float*>realloc(
                A.unmaxed, n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0])
            )
            A.hiddens = <float*>realloc(
                A.hiddens, n.states * n.hiddens * sizeof(A.hiddens[0])
            )
            A.is_valid = <int*>realloc(
                A.is_valid, n.states * n.classes * sizeof(A.is_valid[0])
            )
        A._max_size = n.states
    A._curr_size = n.states
|
|
|
|
|
|
|
|
|
2023-07-19 13:03:31 +03:00
|
|
|
cdef void predict_states(
    CBlas cblas, ActivationsC* A, StateC** states, const WeightsC* W, SizesC n
) nogil:
    """Score every class for the first `n.states` entries of `states`.

    Fills A.token_ids, A.unmaxed, A.hiddens and A.scores. Scores of classes
    whose entry in W.seen_classes is zero are overwritten with the minimum
    score of the whole batch. An empty batch only resizes `A`.
    """
    cdef int i, j, index, which
    cdef float min_
    resize_activations(A, n)
    if n.states <= 0:
        # Nothing to score. Returning here also avoids reading A.scores[0]
        # below, which is NULL when no buffers have been allocated yet.
        return
    # Context token ids: n.feats per state.
    for i in range(n.states):
        states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
    memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
    memset(A.hiddens, 0, n.states * n.hiddens * sizeof(float))
    # Lower layer: sum the precomputed per-feature contributions.
    sum_state_features(
        cblas,
        A.unmaxed,
        W.feat_weights,
        A.token_ids,
        n.states,
        n.feats,
        n.hiddens * n.pieces
    )
    for i in range(n.states):
        VecVec.add_i(
            &A.unmaxed[i*n.hiddens*n.pieces],
            W.feat_bias, 1.,
            n.hiddens * n.pieces
        )
        # Maxout: keep the largest of the n.pieces values per hidden unit.
        for j in range(n.hiddens):
            index = i * n.hiddens * n.pieces + j * n.pieces
            which = Vec.arg_max(&A.unmaxed[index], n.pieces)
            A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
    memset(A.scores, 0, n.states * n.classes * sizeof(float))
    if W.hidden_weights == NULL:
        # No upper layer: hiddens are the scores. This relies on
        # n.classes == n.hiddens, which get_c_sizes arranges in that case.
        memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float))
    else:
        # Compute hidden-to-output
        sgemm(cblas)(
            False, True, n.states, n.classes, n.hiddens,
            1.0, <const float *>A.hiddens, n.hiddens,
            <const float *>W.hidden_weights, n.hiddens,
            0.0, A.scores, n.classes
        )
        # Add bias
        for i in range(n.states):
            VecVec.add_i(&A.scores[i*n.classes], W.hidden_bias, 1., n.classes)
    # Set unseen classes to minimum value
    min_ = A.scores[0]
    for i in range(1, n.states * n.classes):
        if A.scores[i] < min_:
            min_ = A.scores[i]
    for i in range(n.states):
        for j in range(n.classes):
            if not W.seen_classes[j]:
                A.scores[i*n.classes+j] = min_
|
2018-05-15 23:17:29 +03:00
|
|
|
|
💫 Replace ujson, msgpack and dill/pickle/cloudpickle with srsly (#3003)
Remove hacks and wrappers, keep code in sync across our libraries and move spaCy a few steps closer to only depending on packages with binary wheels 🎉
See here: https://github.com/explosion/srsly
Serialization is hard, especially across Python versions and multiple platforms. After dealing with many subtle bugs over the years (encodings, locales, large files) our libraries like spaCy and Prodigy have steadily grown a number of utility functions to wrap the multiple serialization formats we need to support (especially json, msgpack and pickle). These wrapping functions ended up duplicated across our codebases, so we wanted to put them in one place.
At the same time, we noticed that having a lot of small dependencies was making maintenance harder, and making installation slower. To solve this, we've made srsly standalone, by including the component packages directly within it. This way we can provide all the serialization utilities we need in a single binary wheel.
srsly currently includes forks of the following packages:
ujson
msgpack
msgpack-numpy
cloudpickle
* WIP: replace json/ujson with srsly
* Replace ujson in examples
Use regular json instead of srsly to make code easier to read and follow
* Update requirements
* Fix imports
* Fix typos
* Replace msgpack with srsly
* Fix warning
2018-12-03 03:28:22 +03:00
|
|
|
|
2023-07-19 13:03:31 +03:00
|
|
|
cdef void sum_state_features(
    CBlas cblas,
    float* output,
    const float* cached,
    const int* token_ids,
    int B,
    int F,
    int O
) nogil:
    """Accumulate precomputed feature vectors into `output` (B rows of O).

    `cached` starts with one block of F*O floats used for padding (token id
    < 0), followed by one F*O block per token id. For each of the B states,
    the O-wide slice for each of its F feature slots is added via saxpy.
    """
    cdef int b, f, tok
    cdef int row = F * O
    cdef const float* pad_rows = cached
    cdef const float* tok_rows = &cached[row]
    cdef const int* ids_b
    cdef const float* src
    cdef float alpha = 1.
    for b in range(B):
        ids_b = &token_ids[b * F]
        for f in range(F):
            tok = ids_b[f]
            if tok >= 0:
                src = &tok_rows[tok * row + f * O]
            else:
                src = &pad_rows[f * O]
            saxpy(cblas)(O, alpha, <const float*>src, 1, &output[b * O], 1)
|
|
|
|
|
|
|
|
|
2023-07-19 13:03:31 +03:00
|
|
|
cdef void cpu_log_loss(
    float* d_scores,
    const float* costs,
    const int* is_valid,
    const float* scores,
    int O
) nogil:
    """Write the gradient of a multi-label log loss into `d_scores`.

    The gold set is every class whose cost is <= the cost of the class chosen
    by arg_max_if_gold. Note that this comparison does not consult
    `is_valid`. The gradient is softmax(scores) minus the softmax restricted
    to the gold set; both denominators carry a 1e-10 epsilon. If no gold
    class or no guess is found, `d_scores` is left untouched.
    """
    cdef int best, guess, i
    cdef double top, gold_top, Z, gZ, p
    cdef float best_cost
    best = arg_max_if_gold(scores, costs, is_valid, O)
    guess = Vec.arg_max(scores, O)
    if best == -1 or guess == -1:
        # Should not happen; bail out rather than index with -1.
        return
    best_cost = costs[best]
    top = scores[guess]
    gold_top = scores[best]
    Z = 1e-10
    gZ = 1e-10
    for i in range(O):
        Z += exp(scores[i] - top)
        if costs[i] <= best_cost:
            gZ += exp(scores[i] - gold_top)
    for i in range(O):
        p = exp(scores[i] - top) / Z
        if costs[i] <= best_cost:
            p -= exp(scores[i] - gold_top) / gZ
        d_scores[i] = p
|
|
|
|
|
💫 Replace ujson, msgpack and dill/pickle/cloudpickle with srsly (#3003)
Remove hacks and wrappers, keep code in sync across our libraries and move spaCy a few steps closer to only depending on packages with binary wheels 🎉
See here: https://github.com/explosion/srsly
Serialization is hard, especially across Python versions and multiple platforms. After dealing with many subtle bugs over the years (encodings, locales, large files) our libraries like spaCy and Prodigy have steadily grown a number of utility functions to wrap the multiple serialization formats we need to support (especially json, msgpack and pickle). These wrapping functions ended up duplicated across our codebases, so we wanted to put them in one place.
At the same time, we noticed that having a lot of small dependencies was making maintenance harder, and making installation slower. To solve this, we've made srsly standalone, by including the component packages directly within it. This way we can provide all the serialization utilities we need in a single binary wheel.
srsly currently includes forks of the following packages:
ujson
msgpack
msgpack-numpy
cloudpickle
* WIP: replace json/ujson with srsly
* Replace ujson in examples
Use regular json instead of srsly to make code easier to read and follow
* Update requirements
* Fix imports
* Fix typos
* Replace msgpack with srsly
* Fix warning
2018-12-03 03:28:22 +03:00
|
|
|
|
2023-07-19 13:03:31 +03:00
|
|
|
cdef int arg_max_if_gold(
    const weight_t* scores, const weight_t* costs, const int* is_valid, int n
) nogil:
    """Return the highest-scoring valid class among the minimum-cost ones.

    Returns -1 when no valid class is found.
    """
    cdef int i
    cdef int best = -1
    cdef bint found = False
    cdef weight_t cost = 1
    # Find minimum cost over valid classes. Seed the search with the first
    # valid class instead of a fixed cap of 1; with the cap, states where
    # every valid action costs more than 1 got no gold class at all.
    for i in range(n):
        if is_valid[i] and (not found or costs[i] < cost):
            cost = costs[i]
            found = True
    if not found:
        return -1
    # Now find best-scoring with that cost
    for i in range(n):
        if costs[i] <= cost and is_valid[i]:
            if best == -1 or scores[i] > scores[best]:
                best = i
    return best
|
|
|
|
|
|
|
|
|
|
|
|
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
    """Return the index of the highest score among valid classes, or -1."""
    cdef int i
    cdef int best = -1
    for i in range(n):
        if is_valid[i] < 1:
            continue
        if best < 0 or scores[i] > scores[best]:
            best = i
    return best
|
|
|
|
|
|
|
|
|
|
|
|
class ParserStepModel(Model):
    """Thinc model that scores parser transitions for a batch of states.

    layers[0] is applied to `docs` to produce token vectors, layers[1] is
    wrapped in a `precompute_hiddens`, and layers[-1] is used as the output
    layer when `has_upper` is true. Each step's backprop callback is queued
    in `self.backprops` and applied in `finish_steps`.
    """

    def __init__(
        self,
        docs,
        layers,
        *,
        has_upper,
        unseen_classes=None,
        train=True,
        dropout=0.1
    ):
        Model.__init__(self, name="parser_step_model", forward=step_forward)
        self.attrs["has_upper"] = has_upper
        self.attrs["dropout_rate"] = dropout
        self.tokvecs, self.bp_tokvecs = layers[0](docs, is_train=train)
        if layers[1].get_dim("nP") >= 2:
            activation = "maxout"
        elif has_upper:
            activation = None
        else:
            activation = "relu"
        self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
                                            activation=activation, train=train)
        if has_upper:
            self.vec2scores = layers[-1]
        else:
            self.vec2scores = None
        self.cuda_stream = util.get_cuda_stream(non_blocking=True)
        self.backprops = []
        # Class mask: 1.0 = seen, 0.0 = unseen.
        self._class_mask = numpy.ones((self.nO,), dtype='f')
        if unseen_classes is not None:
            for class_ in unseen_classes:
                self._class_mask[class_] = 0.

    def clear_memory(self):
        """Drop references to per-batch arrays and queued backprop callbacks."""
        del self.tokvecs
        del self.bp_tokvecs
        del self.state2vec
        del self.backprops
        del self._class_mask

    @property
    def nO(self):
        """Number of output classes (upper layer if present, else lower)."""
        if self.attrs["has_upper"]:
            return self.vec2scores.get_dim("nO")
        else:
            return self.state2vec.get_dim("nO")

    def class_is_unseen(self, class_):
        """Return True if `class_` is masked out as unseen.

        The mask stores 1 for seen classes, so the previous implementation
        (returning the mask value) answered the opposite question.
        """
        return bool(self._class_mask[class_] == 0)

    def mark_class_unseen(self, class_):
        self._class_mask[class_] = 0

    def mark_class_seen(self, class_):
        self._class_mask[class_] = 1

    def get_token_ids(self, states):
        """Return a C-ordered int32 array of context token ids per active state.

        Final states are skipped. Entries start at -1 before each state writes
        its context tokens.
        """
        cdef StateClass state
        states = [state for state in states if not state.is_final()]
        cdef np.ndarray ids = numpy.full((len(states), self.state2vec.nF), -1,
                                         dtype='i', order='C')
        c_ids = <int*>ids.data
        for state in states:
            state.c.set_context_tokens(c_ids, ids.shape[1])
            c_ids += ids.shape[1]
        return ids

    def backprop_step(self, token_ids, d_vector, get_d_tokvecs):
        """Queue one step's gradient for `finish_steps`."""
        if (
            isinstance(self.state2vec.ops, CupyOps)
            and not isinstance(token_ids, self.state2vec.ops.xp.ndarray)
        ):
            # Move token_ids and d_vector to GPU, asynchronously
            self.backprops.append((
                util.get_async(self.cuda_stream, token_ids),
                util.get_async(self.cuda_stream, d_vector),
                get_d_tokvecs
            ))
        else:
            self.backprops.append((token_ids, d_vector, get_d_tokvecs))

    def finish_steps(self, golds):
        """Apply all queued gradients and backprop into the token vectors.

        Returns the padded token-vector gradient (one extra final row).
        """
        # Add a padding vector to the d_tokvecs gradient, so that missing
        # values don't affect the real gradient.
        d_tokvecs = self.ops.alloc((self.tokvecs.shape[0]+1, self.tokvecs.shape[1]))
        # Tells CUDA to block, so our async copies complete.
        if self.cuda_stream is not None:
            self.cuda_stream.synchronize()
        for ids, d_vector, bp_vector in self.backprops:
            d_state_features = bp_vector((d_vector, ids))
            ids = ids.flatten()
            d_state_features = d_state_features.reshape(
                (ids.size, d_state_features.shape[2]))
            self.ops.scatter_add(d_tokvecs, ids, d_state_features)
        # Exclude the padding row before passing the gradient back.
        self.bp_tokvecs(d_tokvecs[:-1])
        return d_tokvecs
|
|
|
|
|
2023-07-19 13:03:31 +03:00
|
|
|
|
2020-07-07 02:38:15 +03:00
|
|
|
# Shared CPU ops instance. step_forward uses it for dropout masks, and
# precompute_hiddens.begin_update takes its CBLAS from it when the model
# itself runs on CupyOps.
NUMPY_OPS = NumpyOps()
|
2018-05-15 23:17:29 +03:00
|
|
|
|
2023-07-19 13:03:31 +03:00
|
|
|
|
2020-01-29 19:06:46 +03:00
|
|
|
def step_forward(model: ParserStepModel, states, is_train):
    """Forward pass of ParserStepModel: score each non-final state.

    Returns `(scores, backprop)`. The backprop callback zeroes gradients of
    unseen classes, queues the result via `model.backprop_step` (applied
    later by `model.finish_steps`) and returns None.
    """
    token_ids = model.get_token_ids(states)
    vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
    mask = None
    if model.attrs["has_upper"]:
        dropout_rate = model.attrs["dropout_rate"]
        if is_train and dropout_rate > 0:
            # Use the configured rate; a hard-coded 0.1 used to be passed
            # here, silently ignoring ParserStepModel's `dropout` argument.
            mask = NUMPY_OPS.get_dropout_mask(vector.shape, dropout_rate)
            vector *= mask
        scores, get_d_vector = model.vec2scores(vector, is_train)
    else:
        scores = NUMPY_OPS.asarray(vector)
        get_d_vector = lambda d_scores: d_scores  # no-cython-lint: E731
    # If the class is unseen, make sure its score is minimum. Skip when there
    # is nothing to fill: nanmin raises on an empty batch.
    unseen = model._class_mask == 0
    if unseen.any() and scores.size:
        scores[:, unseen] = numpy.nanmin(scores)

    def backprop_parser_step(d_scores):
        # Zero vectors for unseen classes
        d_scores *= model._class_mask
        d_vector = get_d_vector(d_scores)
        if mask is not None:
            d_vector *= mask
        model.backprop_step(token_ids, d_vector, get_d_tokvecs)
        return None

    return scores, backprop_parser_step
|
|
|
|
|
|
|
|
|
2018-05-15 23:17:29 +03:00
|
|
|
cdef class precompute_hiddens:
    """Allow a model to be "primed" by pre-computing input features in bulk.

    This is used for the parser, where we want to take a batch of documents,
    and compute vectors for each (token, position) pair. These vectors can then
    be reused, especially for beam-search.

    Let's say we're using 12 features for each state, e.g. word at start of
    buffer, three words on stack, their children, etc. In the normal arc-eager
    system, a document of length N is processed in 2*N states. This means we'll
    create 2*N*12 feature vectors --- but if we pre-compute, we only need
    N*12 vector computations. The saving for beam-search is much better:
    if we have a beam of k, we'll normally make 2*N*12*k computations --
    so we can save the factor k. This also gives a nice CPU/GPU division:
    we can do all our hard maths up front, packed into large multiplications,
    and do the hard-to-program parsing on the CPU.
    """
    # nF/nO come from the cached array's shape; nP from the lower model.
    cdef readonly int nF, nO, nP
    cdef bint _is_synchronized
    cdef public object ops
    cdef public object numpy_ops
    cdef np.ndarray _features
    cdef np.ndarray _cached
    cdef np.ndarray bias
    cdef object _cuda_stream
    cdef object _bp_hiddens
    cdef object activation

    def __init__(self, batch_size, tokvecs, lower_model, cuda_stream=None,
                 activation="maxout", train=False):
        cdef np.ndarray cached
        gpu_cached, bp_features = lower_model(tokvecs, train)
        if isinstance(gpu_cached, numpy.ndarray):
            cached = gpu_cached
        else:
            # Passing cuda_stream lets cupy make the copy asynchronously;
            # get_feat_weights() blocks on the stream before first use.
            cached = gpu_cached.get(stream=cuda_stream)
        lower_bias = lower_model.get_param("b")
        if isinstance(lower_bias, numpy.ndarray):
            self.bias = lower_bias
        else:
            self.bias = lower_bias.get(stream=cuda_stream)
        self.nF = cached.shape[1]
        self.nP = lower_model.get_dim("nP") if lower_model.has_dim("nP") else 1
        self.nO = cached.shape[2]
        self.ops = lower_model.ops
        self.numpy_ops = NumpyOps()
        assert activation in (None, "relu", "maxout")
        self.activation = activation
        self._is_synchronized = False
        self._cuda_stream = cuda_stream
        self._cached = cached
        self._bp_hiddens = bp_features

    cdef const float* get_feat_weights(self) except NULL:
        # Wait for the asynchronous device-to-host copy once, on first use.
        if self._cuda_stream is not None and not self._is_synchronized:
            self._cuda_stream.synchronize()
            self._is_synchronized = True
        return <float*>self._cached.data

    def has_dim(self, name):
        for dim_name, dim_value in (("nF", self.nF), ("nP", self.nP), ("nO", self.nO)):
            if name == dim_name:
                return dim_value if dim_value is not None else True
        return False

    def get_dim(self, name):
        for dim_name, dim_value in (("nF", self.nF), ("nP", self.nP), ("nO", self.nO)):
            if name == dim_name:
                return dim_value
        raise ValueError(Errors.E1033.format(name=name))

    def set_dim(self, name, value):
        if name == "nF":
            self.nF = value
        elif name == "nP":
            self.nP = value
        elif name == "nO":
            self.nO = value
        else:
            raise ValueError(Errors.E1033.format(name=name))

    def __call__(self, X, bint is_train):
        if not is_train:
            return self.predict(X), lambda X: X
        return self.begin_update(X)

    def predict(self, X):
        state_vector, _ = self.begin_update(X)
        return state_vector

    def begin_update(self, token_ids):
        """Sum cached features per state, add the bias and apply the activation.

        Forward input and output are CPU arrays; the returned `backward`
        receives `(d_state_vector, token_ids)` and may run on the GPU.
        """
        cdef CBlas cblas
        cdef const float* feat_weights
        cdef int[:, ::1] ids
        cdef np.ndarray state_vector = numpy.zeros(
            (token_ids.shape[0], self.nO, self.nP), dtype='f')
        bp_hiddens = self._bp_hiddens
        # The forward pass runs on CPU arrays, so GPU ops fall back to the
        # numpy CBLAS.
        if isinstance(self.ops, CupyOps):
            cblas = NUMPY_OPS.cblas()
        else:
            cblas = self.ops.cblas()
        feat_weights = self.get_feat_weights()
        ids = token_ids
        sum_state_features(
            cblas, <float*>state_vector.data,
            feat_weights, &ids[0, 0],
            token_ids.shape[0], self.nF, self.nO*self.nP
        )
        state_vector += self.bias
        state_vector, bp_nonlinearity = self._nonlinearity(state_vector)

        def backward(d_state_vector_ids):
            d_state_vector, bp_ids = d_state_vector_ids
            d_state_vector = bp_nonlinearity(d_state_vector)
            d_tokens = bp_hiddens((d_state_vector, bp_ids))
            return d_tokens

        return state_vector, backward

    def _nonlinearity(self, state_vector):
        # Anything other than "maxout" (including None) uses the ReLU path.
        if self.activation != "maxout":
            return self._relu_nonlinearity(state_vector)
        return self._maxout_nonlinearity(state_vector)

    def _maxout_nonlinearity(self, state_vector):
        best, which = self.numpy_ops.maxout(state_vector)
        # Output stays on CPU, but the backward pass uses self.ops, so the
        # argmax indices are moved there.
        which = self.ops.asarray(which)

        def backprop_maxout(d_best):
            return self.ops.backprop_maxout(d_best, which, self.nP)

        return best, backprop_maxout

    def _relu_nonlinearity(self, state_vector):
        flat = state_vector.reshape((state_vector.shape[0], -1))
        keep = flat >= 0.
        flat *= keep
        # Output stays on CPU, but the backward pass needs the mask on self.ops.
        keep = self.ops.asarray(keep)

        def backprop_relu(d_best):
            d_best *= keep
            return d_best.reshape((d_best.shape + (1,)))

        return flat, backprop_relu
|