diff --git a/setup.py b/setup.py index d5b82ec68..77a4cf283 100755 --- a/setup.py +++ b/setup.py @@ -33,10 +33,12 @@ MOD_NAMES = [ "spacy.kb.candidate", "spacy.kb.kb", "spacy.kb.kb_in_memory", - "spacy.ml.tb_framework", + "spacy.ml.parser_model", "spacy.morphology", + "spacy.pipeline.dep_parser", "spacy.pipeline._edit_tree_internals.edit_trees", "spacy.pipeline.morphologizer", + "spacy.pipeline.ner", "spacy.pipeline.pipe", "spacy.pipeline.trainable_pipe", "spacy.pipeline.sentencizer", @@ -44,7 +46,6 @@ MOD_NAMES = [ "spacy.pipeline.tagger", "spacy.pipeline.transition_parser", "spacy.pipeline._parser_internals.arc_eager", - "spacy.pipeline._parser_internals.batch", "spacy.pipeline._parser_internals.ner", "spacy.pipeline._parser_internals.nonproj", "spacy.pipeline._parser_internals.search", @@ -52,7 +53,6 @@ MOD_NAMES = [ "spacy.pipeline._parser_internals.stateclass", "spacy.pipeline._parser_internals.transition_system", "spacy.pipeline._parser_internals._beam_utils", - "spacy.pipeline._parser_internals._parser_utils", "spacy.tokenizer", "spacy.training.align", "spacy.training.gold_io", diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index 1c1650cd1..1937ea935 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -90,11 +90,12 @@ grad_factor = 1.0 factory = "parser" [components.parser.model] -@architectures = "spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v2" state_type = "parser" extra_state_tokens = false hidden_width = 128 maxout_pieces = 3 +use_upper = false nO = null [components.parser.model.tok2vec] @@ -110,11 +111,12 @@ grad_factor = 1.0 factory = "ner" [components.ner.model] -@architectures = "spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v2" state_type = "ner" extra_state_tokens = false hidden_width = 64 maxout_pieces = 2 +use_upper = false nO = null 
[components.ner.model.tok2vec] @@ -383,11 +385,12 @@ width = ${components.tok2vec.model.encode.width} factory = "parser" [components.parser.model] -@architectures = "spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v2" state_type = "parser" extra_state_tokens = false hidden_width = 128 maxout_pieces = 3 +use_upper = true nO = null [components.parser.model.tok2vec] @@ -400,11 +403,12 @@ width = ${components.tok2vec.model.encode.width} factory = "ner" [components.ner.model] -@architectures = "spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v2" state_type = "ner" extra_state_tokens = false hidden_width = 64 maxout_pieces = 2 +use_upper = true nO = null [components.ner.model.tok2vec] diff --git a/spacy/compat.py b/spacy/compat.py index 1e63807a0..30459e2e4 100644 --- a/spacy/compat.py +++ b/spacy/compat.py @@ -23,6 +23,11 @@ try: except ImportError: cupy = None +if sys.version_info[:2] >= (3, 8): # Python 3.8+ + from typing import Literal, Protocol, runtime_checkable +else: + from typing_extensions import Literal, Protocol, runtime_checkable # noqa: F401 + from thinc.api import Optimizer # noqa: F401 pickle = pickle diff --git a/spacy/errors.py b/spacy/errors.py index 0f946b14a..d152cb7b2 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -215,12 +215,6 @@ class Warnings(metaclass=ErrorsWithCodes): "key attribute for vectors, configure it through Vectors(attr=) or " "'spacy init vectors --attr'") - # v4 warning strings - W400 = ("`use_upper=False` is ignored, the upper layer is always enabled") - W401 = ("`incl_prior is True`, but the selected knowledge base type {kb_type} doesn't support prior probability " - "lookups so this setting will be ignored. If your KB does support prior probability lookups, make sure " - "to return `True` in `.supports_prior_probs`.") - class Errors(metaclass=ErrorsWithCodes): E001 = ("No component '{name}' found in pipeline. 
@registry.layers("spacy.PrecomputableAffine.v1")
def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1):
    """Construct the precomputable affine layer.

    The returned thinc Model is named "precomputable_affine", uses this
    module's `forward` and `init` functions, and starts with the parameters
    "W", "b" and "pad" registered but unset.

    nO: Value registered for the "nO" dimension.
    nI: Value registered for the "nI" dimension.
    nF: Value registered for the "nF" dimension.
    nP: Value registered for the "nP" dimension.
    dropout: Value stored in the "dropout_rate" attribute.
    RETURNS (Model): The configured layer.
    """
    dimensions = {"nO": nO, "nI": nI, "nF": nF, "nP": nP}
    # Parameters are declared up front and allocated later by `init`.
    parameters = {"W": None, "b": None, "pad": None}
    attributes = {"dropout_rate": dropout}
    return Model(
        "precomputable_affine",
        forward,
        init=init,
        dims=dimensions,
        params=parameters,
        attrs=attributes,
    )
+ dY, ids = dY_ids + assert dY.ndim == 3 + assert dY.shape[1] == nO, dY.shape + assert dY.shape[2] == nP, dY.shape + # nB = dY.shape[0] + model.inc_grad("pad", _backprop_precomputable_affine_padding(model, dY, ids)) + Xf = X[ids] + Xf = Xf.reshape((Xf.shape[0], nF * nI)) + + model.inc_grad("b", dY.sum(axis=0)) + dY = dY.reshape((dY.shape[0], nO * nP)) + + Wopfi = W.transpose((1, 2, 0, 3)) + Wopfi = Wopfi.reshape((nO * nP, nF * nI)) + dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi) + + dWopfi = model.ops.gemm(dY, Xf, trans1=True) + dWopfi = dWopfi.reshape((nO, nP, nF, nI)) + # (o, p, f, i) --> (f, o, p, i) + dWopfi = dWopfi.transpose((2, 0, 1, 3)) + model.inc_grad("W", dWopfi) + return dXf.reshape((dXf.shape[0], nF, nI)) + + return Yf, backward + + +def _backprop_precomputable_affine_padding(model, dY, ids): + nB = dY.shape[0] + nF = model.get_dim("nF") + nP = model.get_dim("nP") + nO = model.get_dim("nO") + # Backprop the "padding", used as a filler for missing values. + # Values that are missing are set to -1, and each state vector could + # have multiple missing values. The padding has different values for + # different missing features. The gradient of the padding vector is: + # + # for b in range(nB): + # for f in range(nF): + # if ids[b, f] < 0: + # d_pad[f] += dY[b] + # + # Which can be rewritten as: + # + # (ids < 0).T @ dY + mask = model.ops.asarray(ids < 0, dtype="f") + d_pad = model.ops.gemm(mask, dY.reshape(nB, nO * nP), trans1=True) + return d_pad.reshape((1, nF, nO, nP)) + + +def init(model, X=None, Y=None): + """This is like the 'layer sequential unit variance', but instead + of taking the actual inputs, we randomly generate whitened data. + + Why's this all so complicated? We have a huge number of inputs, + and the maxout unit makes guessing the dynamics tricky. Instead + we set the maxout weights to values that empirically result in + whitened outputs given whitened inputs. 
+ """ + if model.has_param("W") and model.get_param("W").any(): + return + + nF = model.get_dim("nF") + nO = model.get_dim("nO") + nP = model.get_dim("nP") + nI = model.get_dim("nI") + W = model.ops.alloc4f(nF, nO, nP, nI) + b = model.ops.alloc2f(nO, nP) + pad = model.ops.alloc4f(1, nF, nO, nP) + + ops = model.ops + W = normal_init(ops, W.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI))) + pad = normal_init(ops, pad.shape, mean=1.0) + model.set_param("W", W) + model.set_param("b", b) + model.set_param("pad", pad) + + ids = ops.alloc((5000, nF), dtype="f") + ids += ops.xp.random.uniform(0, 1000, ids.shape) + ids = ops.asarray(ids, dtype="i") + tokvecs = ops.alloc((5000, nI), dtype="f") + tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape( + tokvecs.shape + ) + + def predict(ids, tokvecs): + # nS ids. nW tokvecs. Exclude the padding array. + hiddens = model.predict(tokvecs[:-1]) # (nW, f, o, p) + vectors = model.ops.alloc((ids.shape[0], nO * nP), dtype="f") + # need nS vectors + hiddens = hiddens.reshape((hiddens.shape[0] * nF, nO * nP)) + model.ops.scatter_add(vectors, ids.flatten(), hiddens) + vectors = vectors.reshape((vectors.shape[0], nO, nP)) + vectors += b + vectors = model.ops.asarray(vectors) + if nP >= 2: + return model.ops.maxout(vectors)[0] + else: + return vectors * (vectors >= 0) + + tol_var = 0.01 + tol_mean = 0.01 + t_max = 10 + W = model.get_param("W").copy() + b = model.get_param("b").copy() + for t_i in range(t_max): + acts1 = predict(ids, tokvecs) + var = model.ops.xp.var(acts1) + mean = model.ops.xp.mean(acts1) + if abs(var - 1.0) >= tol_var: + W /= model.ops.xp.sqrt(var) + model.set_param("W", W) + elif abs(mean) >= tol_mean: + b -= mean + model.set_param("b", b) + else: + break diff --git a/spacy/ml/models/parser.py b/spacy/ml/models/parser.py index 422abf4e2..a70d84dea 100644 --- a/spacy/ml/models/parser.py +++ b/spacy/ml/models/parser.py @@ -1,66 +1,23 @@ -import warnings -from typing import Any, List, Literal, 
Optional, Tuple - -from thinc.api import Model +from typing import Optional, List, cast +from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops from thinc.types import Floats2d -from ...errors import Errors, Warnings -from ...tokens.doc import Doc +from ...errors import Errors +from ...compat import Literal from ...util import registry +from .._precomputable_affine import PrecomputableAffine from ..tb_framework import TransitionModel - -TransitionSystem = Any # TODO -State = Any # TODO - - -@registry.architectures.register("spacy.TransitionBasedParser.v2") -def transition_parser_v2( - tok2vec: Model[List[Doc], List[Floats2d]], - state_type: Literal["parser", "ner"], - extra_state_tokens: bool, - hidden_width: int, - maxout_pieces: int, - use_upper: bool, - nO: Optional[int] = None, -) -> Model: - if not use_upper: - warnings.warn(Warnings.W400) - - return build_tb_parser_model( - tok2vec, - state_type, - extra_state_tokens, - hidden_width, - maxout_pieces, - nO=nO, - ) - - -@registry.architectures.register("spacy.TransitionBasedParser.v3") -def transition_parser_v3( - tok2vec: Model[List[Doc], List[Floats2d]], - state_type: Literal["parser", "ner"], - extra_state_tokens: bool, - hidden_width: int, - maxout_pieces: int, - nO: Optional[int] = None, -) -> Model: - return build_tb_parser_model( - tok2vec, - state_type, - extra_state_tokens, - hidden_width, - maxout_pieces, - nO=nO, - ) +from ...tokens import Doc +@registry.architectures("spacy.TransitionBasedParser.v2") def build_tb_parser_model( tok2vec: Model[List[Doc], List[Floats2d]], state_type: Literal["parser", "ner"], extra_state_tokens: bool, hidden_width: int, maxout_pieces: int, + use_upper: bool, nO: Optional[int] = None, ) -> Model: """ @@ -94,7 +51,14 @@ def build_tb_parser_model( feature sets (for the NER) or 13 (for the parser). hidden_width (int): The width of the hidden layer. maxout_pieces (int): How many pieces to use in the state prediction layer. 
- Recommended values are 1, 2 or 3. + Recommended values are 1, 2 or 3. If 1, the maxout non-linearity + is replaced with a ReLu non-linearity if use_upper=True, and no + non-linearity if use_upper=False. + use_upper (bool): Whether to use an additional hidden layer after the state + vector in order to predict the action scores. It is recommended to set + this to False for large pretrained models such as transformers, and True + for smaller networks. The upper layer is computed on CPU, which becomes + a bottleneck on larger GPU-based models, where it's also less necessary. nO (int or None): The number of actions the model will predict between. Usually inferred from data at the beginning of training, or loaded from disk. @@ -105,11 +69,106 @@ def build_tb_parser_model( nr_feature_tokens = 6 if extra_state_tokens else 3 else: raise ValueError(Errors.E917.format(value=state_type)) - return TransitionModel( - tok2vec=tok2vec, - state_tokens=nr_feature_tokens, - hidden_width=hidden_width, - maxout_pieces=maxout_pieces, - nO=nO, - unseen_classes=set(), + t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None + tok2vec = chain( + tok2vec, + list2array(), + Linear(hidden_width, t2v_width), ) + tok2vec.set_dim("nO", hidden_width) + lower = _define_lower( + nO=hidden_width if use_upper else nO, + nF=nr_feature_tokens, + nI=tok2vec.get_dim("nO"), + nP=maxout_pieces, + ) + upper = None + if use_upper: + with use_ops("cpu"): + # Initialize weights at zero, as it's a classification layer. 
def resize_output(model, new_nO):
    """Resize the layer that produces the action scores.

    When the model has an upper layer, that layer is resized; otherwise the
    lower layer is resized.

    model: The transition model; its "has_upper" attribute picks the layer.
    new_nO: The new number of output classes.
    RETURNS: Whatever the selected resize helper returns (the model).
    """
    resize = _resize_upper if model.attrs["has_upper"] else _resize_lower
    return resize(model, new_nO)
+ if smaller.has_dim("nO"): + old_nO = smaller.get_dim("nO") + larger_W[:old_nO] = smaller_W + larger_b[:old_nO] = smaller_b + for i in range(old_nO, new_nO): + model.attrs["unseen_classes"].add(i) + + larger.set_param("W", larger_W) + larger.set_param("b", larger_b) + model._layers[-1] = larger + model.set_ref("upper", larger) + return model + + +def _resize_lower(model, new_nO): + lower = model.get_ref("lower") + if lower.has_dim("nO") is None: + lower.set_dim("nO", new_nO) + return model + + smaller = lower + nI = smaller.maybe_get_dim("nI") + nF = smaller.maybe_get_dim("nF") + nP = smaller.maybe_get_dim("nP") + larger = _define_lower(nO=new_nO, nI=nI, nF=nF, nP=nP) + # it could be that the model is not initialized yet, then skip this bit + if smaller.has_param("W"): + larger_W = larger.ops.alloc4f(nF, new_nO, nP, nI) + larger_b = larger.ops.alloc2f(new_nO, nP) + larger_pad = larger.ops.alloc4f(1, nF, new_nO, nP) + smaller_W = smaller.get_param("W") + smaller_b = smaller.get_param("b") + smaller_pad = smaller.get_param("pad") + # Copy the old weights and padding into the new layer + if smaller.has_dim("nO"): + old_nO = smaller.get_dim("nO") + larger_W[:, 0:old_nO, :, :] = smaller_W + larger_pad[:, :, 0:old_nO, :] = smaller_pad + larger_b[0:old_nO, :] = smaller_b + for i in range(old_nO, new_nO): + model.attrs["unseen_classes"].add(i) + + larger.set_param("W", larger_W) + larger.set_param("b", larger_b) + larger.set_param("pad", larger_pad) + model._layers[1] = larger + model.set_ref("lower", larger) + return model diff --git a/spacy/ml/parser_model.pxd b/spacy/ml/parser_model.pxd new file mode 100644 index 000000000..8def6cea5 --- /dev/null +++ b/spacy/ml/parser_model.pxd @@ -0,0 +1,49 @@ +from libc.string cimport memset, memcpy +from thinc.backends.cblas cimport CBlas +from ..typedefs cimport weight_t, hash_t +from ..pipeline._parser_internals._state cimport StateC + + +cdef struct SizesC: + int states + int classes + int hiddens + int pieces + int feats + int 
# Scratch buffers used while scoring a batch of parser states. The buffers
# are (re)allocated by `resize_activations` and released by
# `free_activations` in parser_model.pyx.
cdef struct ActivationsC:
    # Context-token ids: `states * feats` ints.
    int* token_ids
    # Pre-activation values of the lower layer: `states * hiddens * pieces`
    # floats.
    float* unmaxed
    # Output scores: `states * classes` floats.
    float* scores
    # Lower-layer values after the per-unit max: `states * hiddens` floats.
    float* hiddens
    # Validity flags: `states * classes` ints.
    int* is_valid
    # Number of states the buffers are currently being used for.
    int _curr_size
    # Largest number of states the buffers have been allocated for so far.
    int _max_size
cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
    """Make sure the activation buffers can hold `n.states` states.

    If the buffers are already large enough only `_curr_size` is updated.
    Otherwise they are allocated (first use, zero-filled via calloc) or
    grown (realloc; the added memory is not zeroed).
    """
    if n.states <= A._max_size:
        A._curr_size = n.states
        return
    # calloc/realloc return `void*`; Cython needs an explicit cast to assign
    # that to a typed pointer.
    if A._max_size == 0:
        A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
        A.scores = <float*>calloc(n.states * n.classes, sizeof(A.scores[0]))
        A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
        A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
        A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
    else:
        A.token_ids = <int*>realloc(A.token_ids,
            n.states * n.feats * sizeof(A.token_ids[0]))
        A.scores = <float*>realloc(A.scores,
            n.states * n.classes * sizeof(A.scores[0]))
        A.unmaxed = <float*>realloc(A.unmaxed,
            n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
        A.hiddens = <float*>realloc(A.hiddens,
            n.states * n.hiddens * sizeof(A.hiddens[0]))
        A.is_valid = <int*>realloc(A.is_valid,
            n.states * n.classes * sizeof(A.is_valid[0]))
    # Both branches end with buffers sized for exactly n.states states.
    A._max_size = n.states
    A._curr_size = n.states
cdef void sum_state_features(CBlas cblas, float* output,
        const float* cached, const int* token_ids, int B, int F, int O) nogil:
    """Accumulate the precomputed feature rows of each state into `output`.

    `cached` starts with one block of F*O padding values, followed by one
    block of F*O values per token. For state `b` and feature slot `f`, the
    O values selected by `token_ids[b*F + f]` (or the padding values for
    slot `f` when the id is negative) are added to `output[b*O : (b+1)*O]`.
    """
    cdef int row_stride = F * O
    cdef const float* pad_row = cached
    # Token rows start right after the padding block.
    cdef const float* token_rows = cached + row_stride
    cdef const float* src
    cdef const int* state_ids
    cdef float alpha = 1.
    cdef int b, f, tok
    for b in range(B):
        state_ids = &token_ids[b * F]
        for f in range(F):
            tok = state_ids[f]
            if tok < 0:
                src = &pad_row[f * O]
            else:
                src = &token_rows[tok * row_stride + f * O]
            saxpy(cblas)(O, alpha, src, 1, &output[b * O], 1)
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
    """Return the index of the highest score among entries whose `is_valid`
    value is at least 1, or -1 if there is no such entry. On ties the
    lowest index wins.
    """
    cdef int winner = -1
    cdef int i
    for i in range(n):
        if is_valid[i] < 1:
            continue
        if winner == -1 or scores[i] > scores[winner]:
            winner = i
    return winner
def class_is_unseen(self, class_):
    """Check whether a class is currently marked as unseen.

    `_class_mask` holds 1 for seen classes and 0 for unseen ones (see
    `mark_class_seen` / `mark_class_unseen`), so a class is unseen exactly
    when its mask entry is 0.

    class_ (int): Index of the class.
    RETURNS (bool): True if the class is marked as unseen.
    """
    return bool(self._class_mask[class_] == 0)
def step_forward(model: ParserStepModel, states, is_train):
    """Compute action scores for a batch of parser states.

    The token ids for the states are passed through the precomputed lower
    layer (`model.state2vec`) and, if the model has an upper layer, through
    `model.vec2scores`. Dropout is applied to the hidden vector only when
    training with an upper layer. Scores of classes whose `_class_mask`
    entry is 0 are set to the minimum score.

    model (ParserStepModel): The per-batch step model.
    states: The parser states to score.
    is_train (bool): Whether to run in training mode.
    RETURNS: A (scores, backprop) tuple. The callback does not return a
        gradient; it queues it on the model via `model.backprop_step` and
        returns None.
    """
    token_ids = model.get_token_ids(states)
    vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
    mask = None
    if model.attrs["has_upper"]:
        dropout_rate = model.attrs["dropout_rate"]
        if is_train and dropout_rate > 0:
            # Use the configured rate rather than a hard-coded constant.
            mask = NUMPY_OPS.get_dropout_mask(vector.shape, dropout_rate)
            vector *= mask
        scores, get_d_vector = model.vec2scores(vector, is_train)
    else:
        # Reuse the module-level ops instead of building one per step.
        scores = NUMPY_OPS.asarray(vector)

        def get_d_vector(d_scores):
            return d_scores

    # If the class is unseen, make sure its score is minimum
    scores[:, model._class_mask == 0] = numpy.nanmin(scores)

    def backprop_parser_step(d_scores):
        # Zero vectors for unseen classes
        d_scores *= model._class_mask
        d_vector = get_d_vector(d_scores)
        if mask is not None:
            d_vector *= mask
        model.backprop_step(token_ids, d_vector, get_d_tokvecs)
        return None

    return scores, backprop_parser_step
This means we'll + create 2*N*12 feature vectors --- but if we pre-compute, we only need + N*12 vector computations. The saving for beam-search is much better: + if we have a beam of k, we'll normally make 2*N*12*K computations -- + so we can save the factor k. This also gives a nice CPU/GPU division: + we can do all our hard maths up front, packed into large multiplications, + and do the hard-to-program parsing on the CPU. + """ + cdef readonly int nF, nO, nP + cdef bint _is_synchronized + cdef public object ops + cdef public object numpy_ops + cdef public object _cpu_ops + cdef np.ndarray _features + cdef np.ndarray _cached + cdef np.ndarray bias + cdef object _cuda_stream + cdef object _bp_hiddens + cdef object activation + + def __init__(self, batch_size, tokvecs, lower_model, cuda_stream=None, + activation="maxout", train=False): + gpu_cached, bp_features = lower_model(tokvecs, train) + cdef np.ndarray cached + if not isinstance(gpu_cached, numpy.ndarray): + # Note the passing of cuda_stream here: it lets + # cupy make the copy asynchronously. + # We then have to block before first use. 
def get_dim(self, name):
    """Return the value of one of the dimensions "nF", "nP" or "nO".

    name (str): Name of the dimension.
    RETURNS (int): The stored value.
    RAISES (ValueError): E1033 if the name is not one of the three above.
    """
    if name not in ("nF", "nP", "nO"):
        raise ValueError(Errors.E1033.format(name=name))
    if name == "nF":
        return self.nF
    if name == "nP":
        return self.nP
    return self.nO
available); + # - Input to forward on CPU + # - Output from forward on CPU + # - Input to backward on GPU! + # - Output from backward on GPU + bp_hiddens = self._bp_hiddens + + cdef CBlas cblas = self._cpu_ops.cblas() + + feat_weights = self.get_feat_weights() + cdef int[:, ::1] ids = token_ids + sum_state_features(cblas, state_vector.data, + feat_weights, &ids[0,0], + token_ids.shape[0], self.nF, self.nO*self.nP) + state_vector += self.bias + state_vector, bp_nonlinearity = self._nonlinearity(state_vector) + + def backward(d_state_vector_ids): + d_state_vector, token_ids = d_state_vector_ids + d_state_vector = bp_nonlinearity(d_state_vector) + d_tokens = bp_hiddens((d_state_vector, token_ids)) + return d_tokens + return state_vector, backward + + def _nonlinearity(self, state_vector): + if self.activation == "maxout": + return self._maxout_nonlinearity(state_vector) + else: + return self._relu_nonlinearity(state_vector) + + def _maxout_nonlinearity(self, state_vector): + state_vector, mask = self.numpy_ops.maxout(state_vector) + # We're outputting to CPU, but we need this variable on GPU for the + # backward pass. + mask = self.ops.asarray(mask) + + def backprop_maxout(d_best): + return self.ops.backprop_maxout(d_best, mask, self.nP) + + return state_vector, backprop_maxout + + def _relu_nonlinearity(self, state_vector): + state_vector = state_vector.reshape((state_vector.shape[0], -1)) + mask = state_vector >= 0. + state_vector *= mask + # We're outputting to CPU, but we need this variable on GPU for the + # backward pass. 
+ mask = self.ops.asarray(mask) + + def backprop_relu(d_best): + d_best *= mask + return d_best.reshape((d_best.shape + (1,))) + + return state_vector, backprop_relu + +cdef inline int _arg_max(const float* scores, const int n_classes) nogil: + if n_classes == 2: + return 0 if scores[0] > scores[1] else 1 + cdef int i + cdef int best = 0 + cdef float mode = scores[0] + for i in range(1, n_classes): + if scores[i] > mode: + mode = scores[i] + best = i + return best diff --git a/spacy/ml/tb_framework.pxd b/spacy/ml/tb_framework.pxd deleted file mode 100644 index 965508519..000000000 --- a/spacy/ml/tb_framework.pxd +++ /dev/null @@ -1,28 +0,0 @@ -from libc.stdint cimport int8_t - - -cdef struct SizesC: - int states - int classes - int hiddens - int pieces - int feats - int embed_width - int tokens - - -cdef struct WeightsC: - const float* feat_weights - const float* feat_bias - const float* hidden_bias - const float* hidden_weights - const int8_t* seen_mask - - -cdef struct ActivationsC: - int* token_ids - float* unmaxed - float* hiddens - int* is_valid - int _curr_size - int _max_size diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py new file mode 100644 index 000000000..ab4a969e2 --- /dev/null +++ b/spacy/ml/tb_framework.py @@ -0,0 +1,50 @@ +from thinc.api import Model, noop +from .parser_model import ParserStepModel +from ..util import registry + + +@registry.layers("spacy.TransitionModel.v1") +def TransitionModel( + tok2vec, lower, upper, resize_output, dropout=0.2, unseen_classes=set() +): + """Set up a stepwise transition-based model""" + if upper is None: + has_upper = False + upper = noop() + else: + has_upper = True + # don't define nO for this object, because we can't dynamically change it + return Model( + name="parser_model", + forward=forward, + dims={"nI": tok2vec.maybe_get_dim("nI")}, + layers=[tok2vec, lower, upper], + refs={"tok2vec": tok2vec, "lower": lower, "upper": upper}, + init=init, + attrs={ + "has_upper": has_upper, + 
"unseen_classes": set(unseen_classes), + "resize_output": resize_output, + }, + ) + + +def forward(model, X, is_train): + step_model = ParserStepModel( + X, + model.layers, + unseen_classes=model.attrs["unseen_classes"], + train=is_train, + has_upper=model.attrs["has_upper"], + ) + + return step_model, step_model.finish_steps + + +def init(model, X=None, Y=None): + model.get_ref("tok2vec").initialize(X=X) + lower = model.get_ref("lower") + lower.initialize() + if model.attrs["has_upper"]: + statevecs = model.ops.alloc2f(2, lower.get_dim("nO")) + model.get_ref("upper").initialize(X=statevecs) diff --git a/spacy/ml/tb_framework.pyx b/spacy/ml/tb_framework.pyx deleted file mode 100644 index e497643f0..000000000 --- a/spacy/ml/tb_framework.pyx +++ /dev/null @@ -1,639 +0,0 @@ -# cython: infer_types=True, cdivision=True, boundscheck=False -from typing import Any, List, Optional, Tuple, cast - -from libc.stdlib cimport calloc, free, realloc -from libc.string cimport memcpy, memset -from libcpp.vector cimport vector - -import numpy - -cimport numpy as np - -from thinc.api import ( - Linear, - Model, - NumpyOps, - chain, - glorot_uniform_init, - list2array, - normal_init, - uniform_init, - zero_init, -) - -from thinc.backends.cblas cimport CBlas, saxpy, sgemm - -from thinc.types import Floats2d, Floats3d, Floats4d, Ints1d, Ints2d - -from ..errors import Errors -from ..pipeline._parser_internals import _beam_utils -from ..pipeline._parser_internals.batch import GreedyBatch - -from ..pipeline._parser_internals._parser_utils cimport arg_max -from ..pipeline._parser_internals.stateclass cimport StateC, StateClass -from ..pipeline._parser_internals.transition_system cimport ( - TransitionSystem, - c_apply_actions, - c_transition_batch, -) - -from ..tokens.doc import Doc -from ..util import registry - -State = Any # TODO - - -@registry.layers("spacy.TransitionModel.v2") -def TransitionModel( - *, - tok2vec: Model[List[Doc], List[Floats2d]], - beam_width: int = 1, - beam_density: 
float = 0.0, - state_tokens: int, - hidden_width: int, - maxout_pieces: int, - nO: Optional[int] = None, - unseen_classes=set(), -) -> Model[Tuple[List[Doc], TransitionSystem], List[Tuple[State, List[Floats2d]]]]: - """Set up a transition-based parsing model, using a maxout hidden - layer and a linear output layer. - """ - t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None - tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width)) # type: ignore - tok2vec_projected.set_dim("nO", hidden_width) - - # FIXME: we use `output` as a container for the output layer's - # weights and biases. Thinc optimizers cannot handle resizing - # of parameters. So, when the parser model is resized, we - # construct a new `output` layer, which has a different key in - # the optimizer. Once the optimizer supports parameter resizing, - # we can replace the `output` layer by `output_W` and `output_b` - # parameters in this model. - output = Linear(nO=None, nI=hidden_width, init_W=zero_init) - - return Model( - name="parser_model", - forward=forward, - init=init, - layers=[tok2vec_projected, output], - refs={ - "tok2vec": tok2vec_projected, - "output": output, - }, - params={ - "hidden_W": None, # Floats2d W for the hidden layer - "hidden_b": None, # Floats1d bias for the hidden layer - "hidden_pad": None, # Floats1d padding for the hidden layer - }, - dims={ - "nO": None, # Output size - "nP": maxout_pieces, - "nH": hidden_width, - "nI": tok2vec_projected.maybe_get_dim("nO"), - "nF": state_tokens, - }, - attrs={ - "beam_width": beam_width, - "beam_density": beam_density, - "unseen_classes": set(unseen_classes), - "resize_output": resize_output, - }, - ) - - -def resize_output(model: Model, new_nO: int) -> Model: - old_nO = model.maybe_get_dim("nO") - output = model.get_ref("output") - if old_nO is None: - model.set_dim("nO", new_nO) - output.set_dim("nO", new_nO) - output.initialize() - return model - elif new_nO <= old_nO: - return model - elif 
output.has_param("W"): - nH = model.get_dim("nH") - new_output = Linear(nO=new_nO, nI=nH, init_W=zero_init) - new_output.initialize() - new_W = new_output.get_param("W") - new_b = new_output.get_param("b") - old_W = output.get_param("W") - old_b = output.get_param("b") - new_W[:old_nO] = old_W # type: ignore - new_b[:old_nO] = old_b # type: ignore - for i in range(old_nO, new_nO): - model.attrs["unseen_classes"].add(i) - model.layers[-1] = new_output - model.set_ref("output", new_output) - # TODO: Avoid this private intrusion - model._dims["nO"] = new_nO - return model - - -def init( - model, - X: Optional[Tuple[List[Doc], TransitionSystem]] = None, - Y: Optional[Tuple[List[State], List[Floats2d]]] = None, -): - if X is not None: - docs, _ = X - model.get_ref("tok2vec").initialize(X=docs) - else: - model.get_ref("tok2vec").initialize() - inferred_nO = _infer_nO(Y) - if inferred_nO is not None: - current_nO = model.maybe_get_dim("nO") - if current_nO is None or current_nO != inferred_nO: - model.attrs["resize_output"](model, inferred_nO) - nP = model.get_dim("nP") - nH = model.get_dim("nH") - nI = model.get_dim("nI") - nF = model.get_dim("nF") - ops = model.ops - - Wl = ops.alloc2f(nH * nP, nF * nI) - bl = ops.alloc1f(nH * nP) - padl = ops.alloc1f(nI) - # Wl = zero_init(ops, Wl.shape) - Wl = glorot_uniform_init(ops, Wl.shape) - padl = uniform_init(ops, padl.shape) # type: ignore - # TODO: Experiment with whether better to initialize output_W - model.set_param("hidden_W", Wl) - model.set_param("hidden_b", bl) - model.set_param("hidden_pad", padl) - # model = _lsuv_init(model) - return model - - -class TransitionModelInputs: - """ - Input to transition model. - """ - - # dataclass annotation is not yet supported in Cython 0.29.x, - # so, we'll do something close to it. 
- - actions: Optional[List[Ints1d]] - docs: List[Doc] - max_moves: int - moves: TransitionSystem - states: Optional[List[State]] - - __slots__ = [ - "actions", - "docs", - "max_moves", - "moves", - "states", - ] - - def __init__( - self, - docs: List[Doc], - moves: TransitionSystem, - actions: Optional[List[Ints1d]] = None, - max_moves: int = 0, - states: Optional[List[State]] = None, - ): - """ - actions (Optional[List[Ints1d]]): actions to apply for each Doc. - docs (List[Doc]): Docs to predict transition sequences for. - max_moves: (int): the maximum number of moves to apply, values less - than 1 will apply moves to states until they are final states. - moves (TransitionSystem): the transition system to use when predicting - the transition sequences. - states (Optional[List[States]]): the initial states to predict the - transition sequences for. When absent, the initial states are - initialized from the provided Docs. - """ - self.actions = actions - self.docs = docs - self.moves = moves - self.max_moves = max_moves - self.states = states - - -def forward(model, inputs: TransitionModelInputs, is_train: bool): - docs = inputs.docs - moves = inputs.moves - actions = inputs.actions - - beam_width = model.attrs["beam_width"] - hidden_pad = model.get_param("hidden_pad") - tok2vec = model.get_ref("tok2vec") - - states = moves.init_batch(docs) if inputs.states is None else inputs.states - tokvecs, backprop_tok2vec = tok2vec(docs, is_train) - tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad)) - feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train) - seen_mask = _get_seen_mask(model) - - if not is_train and beam_width == 1 and isinstance(model.ops, NumpyOps): - # Note: max_moves is only used during training, so we don't need to - # pass it to the greedy inference path. 
- return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions) - else: - return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, - feats, backprop_feats, seen_mask, is_train, actions=actions, - max_moves=inputs.max_moves) - - -def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats, - np.ndarray[np.npy_bool, ndim = 1] seen_mask, actions: Optional[List[Ints1d]] = None): - cdef vector[StateC*] c_states - cdef StateClass state - for state in states: - if not state.is_final(): - c_states.push_back(state.c) - weights = _get_c_weights(model, feats.data, seen_mask) - # Precomputed features have rows for each token, plus one for padding. - cdef int n_tokens = feats.shape[0] - 1 - sizes = _get_c_sizes(model, c_states.size(), n_tokens) - cdef CBlas cblas = model.ops.cblas() - scores = _parse_batch(cblas, moves, &c_states[0], weights, sizes, actions=actions) - - def backprop(dY): - raise ValueError(Errors.E4004) - - return (states, scores), backprop - - -cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states, - WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None): - cdef int i - cdef vector[StateC *] unfinished - cdef ActivationsC activations = _alloc_activations(sizes) - cdef np.ndarray step_scores - cdef np.ndarray step_actions - - scores = [] - while sizes.states >= 1: - step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f") - step_actions = actions[0] if actions is not None else None - with nogil: - _predict_states(cblas, &activations, step_scores.data, states, &weights, sizes) - if actions is None: - # Validate actions, argmax, take action. 
- c_transition_batch(moves, states, step_scores.data, sizes.classes, - sizes.states) - else: - c_apply_actions(moves, states, step_actions.data, sizes.states) - for i in range(sizes.states): - if not states[i].is_final(): - unfinished.push_back(states[i]) - for i in range(unfinished.size()): - states[i] = unfinished[i] - sizes.states = unfinished.size() - scores.append(step_scores) - unfinished.clear() - actions = actions[1:] if actions is not None else None - _free_activations(&activations) - - return scores - - -def _forward_fallback( - model: Model, - moves: TransitionSystem, - states: List[StateClass], - tokvecs, backprop_tok2vec, - feats, - backprop_feats, - seen_mask, - is_train: bool, - actions: Optional[List[Ints1d]] = None, - max_moves: int = 0, -): - nF = model.get_dim("nF") - output = model.get_ref("output") - hidden_b = model.get_param("hidden_b") - nH = model.get_dim("nH") - nP = model.get_dim("nP") - - beam_width = model.attrs["beam_width"] - beam_density = model.attrs["beam_density"] - - ops = model.ops - - all_ids = [] - all_which = [] - all_statevecs = [] - all_scores = [] - if beam_width == 1: - batch = GreedyBatch(moves, states, None) - else: - batch = _beam_utils.BeamBatch( - moves, states, None, width=beam_width, density=beam_density - ) - arange = ops.xp.arange(nF) - n_moves = 0 - while not batch.is_done: - ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i") - for i, state in enumerate(batch.get_unfinished_states()): - state.set_context_tokens(ids, i, nF) - # Sum the state features, add the bias and apply the activation (maxout) - # to create the state vectors. 
- preacts2f = feats[ids, arange].sum(axis=1) # type: ignore - preacts2f += hidden_b - preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP) - assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape - statevecs, which = ops.maxout(preacts) - # We don't use output's backprop, since we want to backprop for - # all states at once, rather than a single state. - scores = output.predict(statevecs) - scores[:, seen_mask] = ops.xp.nanmin(scores) - # Transition the states, filtering out any that are finished. - cpu_scores = ops.to_numpy(scores) - if actions is None: - batch.advance(cpu_scores) - else: - batch.advance_with_actions(actions[0]) - actions = actions[1:] - all_scores.append(scores) - if is_train: - # Remember intermediate results for the backprop. - all_ids.append(ids) - all_statevecs.append(statevecs) - all_which.append(which) - if n_moves >= max_moves >= 1: - break - n_moves += 1 - - def backprop_parser(d_states_d_scores): - ids = ops.xp.vstack(all_ids) - which = ops.xp.vstack(all_which) - statevecs = ops.xp.vstack(all_statevecs) - _, d_scores = d_states_d_scores - if model.attrs.get("unseen_classes"): - # If we have a negative gradient (i.e. the probability should - # increase) on any classes we filtered out as unseen, mark - # them as seen. - for clas in set(model.attrs["unseen_classes"]): - if (d_scores[:, clas] < 0).any(): - model.attrs["unseen_classes"].remove(clas) - d_scores *= seen_mask == False # no-cython-lint - # Calculate the gradients for the parameters of the output layer. - # The weight gemm is (nS, nO) @ (nS, nH).T - output.inc_grad("b", d_scores.sum(axis=0)) - output.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True)) - # Now calculate d_statevecs, by backproping through the output linear layer. 
- # This gemm is (nS, nO) @ (nO, nH) - output_W = output.get_param("W") - d_statevecs = ops.gemm(d_scores, output_W) - # Backprop through the maxout activation - d_preacts = ops.backprop_maxout(d_statevecs, which, nP) - d_preacts2f = ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP) - model.inc_grad("hidden_b", d_preacts2f.sum(axis=0)) - # We don't need to backprop the summation, because we pass back the IDs instead - d_state_features = backprop_feats((d_preacts2f, ids)) - d_tokvecs = ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1]) - ops.scatter_add(d_tokvecs, ids, d_state_features) - model.inc_grad("hidden_pad", d_tokvecs[-1]) - return (backprop_tok2vec(d_tokvecs[:-1]), None) - - return (list(batch), all_scores), backprop_parser - - -def _get_seen_mask(model: Model) -> numpy.array[bool, 1]: - mask = model.ops.xp.zeros(model.get_dim("nO"), dtype="bool") - for class_ in model.attrs.get("unseen_classes", set()): - mask[class_] = True - return mask - - -def _forward_precomputable_affine(model, X: Floats2d, is_train: bool): - W: Floats2d = model.get_param("hidden_W") - nF = model.get_dim("nF") - nH = model.get_dim("nH") - nP = model.get_dim("nP") - nI = model.get_dim("nI") - # The weights start out (nH * nP, nF * nI). Transpose and reshape to (nF * nH *nP, nI) - W3f = model.ops.reshape3f(W, nH * nP, nF, nI) - W3f = W3f.transpose((1, 0, 2)) - W2f = model.ops.reshape2f(W3f, nF * nH * nP, nI) - assert X.shape == (X.shape[0], nI), X.shape - Yf_ = model.ops.gemm(X, W2f, trans2=True) - Yf = model.ops.reshape3f(Yf_, Yf_.shape[0], nF, nH * nP) - - def backward(dY_ids: Tuple[Floats3d, Ints2d]): - # This backprop is particularly tricky, because we get back a different - # thing from what we put out. 
We put out an array of shape: - # (nB, nF, nH, nP), and get back: - # (nB, nH, nP) and ids (nB, nF) - # The ids tell us the values of nF, so we would have: - # - # dYf = zeros((nB, nF, nH, nP)) - # for b in range(nB): - # for f in range(nF): - # dYf[b, ids[b, f]] += dY[b] - # - # However, we avoid building that array for efficiency -- and just pass - # in the indices. - dY, ids = dY_ids - dXf = model.ops.gemm(dY, W) - Xf = X[ids].reshape((ids.shape[0], -1)) - dW = model.ops.gemm(dY, Xf, trans1=True) - model.inc_grad("hidden_W", dW) - return model.ops.reshape3f(dXf, dXf.shape[0], nF, nI) - - return Yf, backward - - -def _infer_nO(Y: Optional[Tuple[List[State], List[Floats2d]]]) -> Optional[int]: - if Y is None: - return None - _, scores = Y - if len(scores) == 0: - return None - assert scores[0].shape[0] >= 1 - assert len(scores[0].shape) == 2 - return scores[0].shape[1] - - -def _lsuv_init(model: Model): - """This is like the 'layer sequential unit variance', but instead - of taking the actual inputs, we randomly generate whitened data. - - Why's this all so complicated? We have a huge number of inputs, - and the maxout unit makes guessing the dynamics tricky. Instead - we set the maxout weights to values that empirically result in - whitened outputs given whitened inputs. 
- """ - W = model.maybe_get_param("hidden_W") - if W is not None and W.any(): - return - - nF = model.get_dim("nF") - nH = model.get_dim("nH") - nP = model.get_dim("nP") - nI = model.get_dim("nI") - W = model.ops.alloc4f(nF, nH, nP, nI) - b = model.ops.alloc2f(nH, nP) - pad = model.ops.alloc4f(1, nF, nH, nP) - - ops = model.ops - W = normal_init(ops, W.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI))) - pad = normal_init(ops, pad.shape, mean=1.0) - model.set_param("W", W) - model.set_param("b", b) - model.set_param("pad", pad) - - ids = ops.alloc_f((5000, nF), dtype="f") - ids += ops.xp.random.uniform(0, 1000, ids.shape) - ids = ops.asarray(ids, dtype="i") - tokvecs = ops.alloc_f((5000, nI), dtype="f") - tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape( - tokvecs.shape - ) - - def predict(ids, tokvecs): - # nS ids. nW tokvecs. Exclude the padding array. - hiddens, _ = _forward_precomputable_affine(model, tokvecs[:-1], False) - vectors = model.ops.alloc2f(ids.shape[0], nH * nP) - # need nS vectors - hiddens = hiddens.reshape((hiddens.shape[0] * nF, nH * nP)) - model.ops.scatter_add(vectors, ids.flatten(), hiddens) - vectors3f = model.ops.reshape3f(vectors, vectors.shape[0], nH, nP) - vectors3f += b - return model.ops.maxout(vectors3f)[0] - - tol_var = 0.01 - tol_mean = 0.01 - t_max = 10 - W = cast(Floats4d, model.get_param("hidden_W").copy()) - b = cast(Floats2d, model.get_param("hidden_b").copy()) - for t_i in range(t_max): - acts1 = predict(ids, tokvecs) - var = model.ops.xp.var(acts1) - mean = model.ops.xp.mean(acts1) - if abs(var - 1.0) >= tol_var: - W /= model.ops.xp.sqrt(var) - model.set_param("hidden_W", W) - elif abs(mean) >= tol_mean: - b -= mean - model.set_param("hidden_b", b) - else: - break - return model - - -cdef WeightsC _get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *: - output = model.get_ref("output") - cdef np.ndarray hidden_b = model.get_param("hidden_b") - cdef np.ndarray 
output_W = output.get_param("W") - cdef np.ndarray output_b = output.get_param("b") - - cdef WeightsC weights - weights.feat_weights = feats - weights.feat_bias = hidden_b.data - weights.hidden_weights = output_W.data - weights.hidden_bias = output_b.data - weights.seen_mask = seen_mask.data - - return weights - - -cdef SizesC _get_c_sizes(model, int batch_size, int tokens) except *: - cdef SizesC sizes - sizes.states = batch_size - sizes.classes = model.get_dim("nO") - sizes.hiddens = model.get_dim("nH") - sizes.pieces = model.get_dim("nP") - sizes.feats = model.get_dim("nF") - sizes.embed_width = model.get_dim("nI") - sizes.tokens = tokens - return sizes - - -cdef ActivationsC _alloc_activations(SizesC n) nogil: - cdef ActivationsC A - memset(&A, 0, sizeof(A)) - _resize_activations(&A, n) - return A - - -cdef void _free_activations(const ActivationsC* A) nogil: - free(A.token_ids) - free(A.unmaxed) - free(A.hiddens) - free(A.is_valid) - - -cdef void _resize_activations(ActivationsC* A, SizesC n) nogil: - if n.states <= A._max_size: - A._curr_size = n.states - return - if A._max_size == 0: - A.token_ids = calloc(n.states * n.feats, sizeof(A.token_ids[0])) - A.unmaxed = calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0])) - A.hiddens = calloc(n.states * n.hiddens, sizeof(A.hiddens[0])) - A.is_valid = calloc(n.states * n.classes, sizeof(A.is_valid[0])) - A._max_size = n.states - else: - A.token_ids = realloc(A.token_ids, - n.states * n.feats * sizeof(A.token_ids[0])) - A.unmaxed = realloc(A.unmaxed, - n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0])) - A.hiddens = realloc(A.hiddens, - n.states * n.hiddens * sizeof(A.hiddens[0])) - A.is_valid = realloc(A.is_valid, - n.states * n.classes * sizeof(A.is_valid[0])) - A._max_size = n.states - A._curr_size = n.states - - -cdef void _predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil: - _resize_activations(A, n) - for i in range(n.states): - 
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats) - memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float)) - _sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n) - for i in range(n.states): - saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1) - for j in range(n.hiddens): - index = i * n.hiddens * n.pieces + j * n.pieces - which = arg_max(&A.unmaxed[index], n.pieces) - A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which] - if W.hidden_weights == NULL: - memcpy(scores, A.hiddens, n.states * n.classes * sizeof(float)) - else: - # Compute hidden-to-output - sgemm(cblas)(False, True, n.states, n.classes, n.hiddens, - 1.0, A.hiddens, n.hiddens, - W.hidden_weights, n.hiddens, - 0.0, scores, n.classes) - # Add bias - for i in range(n.states): - saxpy(cblas)(n.classes, 1., W.hidden_bias, 1, &scores[i*n.classes], 1) - # Set unseen classes to minimum value - i = 0 - min_ = scores[0] - for i in range(1, n.states * n.classes): - if scores[i] < min_: - min_ = scores[i] - for i in range(n.states): - for j in range(n.classes): - if W.seen_mask[j]: - scores[i*n.classes+j] = min_ - - -cdef void _sum_state_features(CBlas cblas, float* output, const float* cached, - const int* token_ids, SizesC n) nogil: - cdef int idx, b, f - cdef const float* feature - cdef int B = n.states - cdef int O = n.hiddens * n.pieces # no-cython-lint - cdef int F = n.feats - cdef int T = n.tokens - padding = cached + (T * F * O) - cdef int id_stride = F*O - cdef float one = 1. 
- for b in range(B): - for f in range(F): - if token_ids[f] < 0: - feature = &padding[f*O] - else: - idx = token_ids[f] * id_stride + f*O - feature = &cached[idx] - saxpy(cblas)(O, one, feature, 1, &output[b*O], 1) - token_ids += F diff --git a/spacy/pipeline/_parser_internals/_beam_utils.pyx b/spacy/pipeline/_parser_internals/_beam_utils.pyx index fff8d63e9..5efc52a60 100644 --- a/spacy/pipeline/_parser_internals/_beam_utils.pyx +++ b/spacy/pipeline/_parser_internals/_beam_utils.pyx @@ -6,8 +6,6 @@ from ...typedefs cimport class_t from .transition_system cimport Transition, TransitionSystem from ...errors import Errors - -from .batch cimport Batch from .search cimport Beam, MaxViolation from .search import MaxViolation @@ -29,7 +27,7 @@ cdef int check_final_state(void* _state, void* extra_args) except -1: return state.is_final() -cdef class BeamBatch(Batch): +cdef class BeamBatch(object): cdef public TransitionSystem moves cdef public object states cdef public object docs diff --git a/spacy/pipeline/_parser_internals/_parser_utils.pxd b/spacy/pipeline/_parser_internals/_parser_utils.pxd deleted file mode 100644 index 7fee05bad..000000000 --- a/spacy/pipeline/_parser_internals/_parser_utils.pxd +++ /dev/null @@ -1,2 +0,0 @@ -cdef int arg_max(const float* scores, const int n_classes) nogil -cdef int arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil diff --git a/spacy/pipeline/_parser_internals/_parser_utils.pyx b/spacy/pipeline/_parser_internals/_parser_utils.pyx deleted file mode 100644 index 582756bf5..000000000 --- a/spacy/pipeline/_parser_internals/_parser_utils.pyx +++ /dev/null @@ -1,22 +0,0 @@ -# cython: infer_types=True - -cdef inline int arg_max(const float* scores, const int n_classes) nogil: - if n_classes == 2: - return 0 if scores[0] > scores[1] else 1 - cdef int i - cdef int best = 0 - cdef float mode = scores[0] - for i in range(1, n_classes): - if scores[i] > mode: - mode = scores[i] - best = i - return best - - -cdef inline int 
arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil: - cdef int best = -1 - for i in range(n): - if is_valid[i] >= 1: - if best == -1 or scores[i] > scores[best]: - best = i - return best diff --git a/spacy/pipeline/_parser_internals/_state.pxd b/spacy/pipeline/_parser_internals/_state.pxd index 1c61ac271..04274ce8a 100644 --- a/spacy/pipeline/_parser_internals/_state.pxd +++ b/spacy/pipeline/_parser_internals/_state.pxd @@ -7,6 +7,8 @@ from libc.string cimport memcpy, memset from libcpp.set cimport set from libcpp.unordered_map cimport unordered_map from libcpp.vector cimport vector +from libcpp.set cimport set +from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno from murmurhash.mrmr cimport hash64 from ...attrs cimport IS_SPACE @@ -26,7 +28,7 @@ cdef struct ArcC: cdef cppclass StateC: - vector[int] _heads + int* _heads const TokenC* _sent vector[int] _stack vector[int] _rebuffer @@ -34,34 +36,31 @@ cdef cppclass StateC: unordered_map[int, vector[ArcC]] _left_arcs unordered_map[int, vector[ArcC]] _right_arcs vector[libcpp.bool] _unshiftable - vector[int] history set[int] _sent_starts TokenC _empty_token int length int offset int _b_i - __init__(const TokenC* sent, int length) nogil except +: - this._heads.resize(length, -1) - this._unshiftable.resize(length, False) - - # Reserve memory ahead of time to minimize allocations during parsing. - # The initial capacity set here ideally reflects the expected average-case/majority usage. 
- cdef int init_capacity = 32 - this._stack.reserve(init_capacity) - this._rebuffer.reserve(init_capacity) - this._ents.reserve(init_capacity) - this._left_arcs.reserve(init_capacity) - this._right_arcs.reserve(init_capacity) - this.history.reserve(init_capacity) - + __init__(const TokenC* sent, int length) nogil: this._sent = sent + this._heads = calloc(length, sizeof(int)) + if not (this._sent and this._heads): + with gil: + PyErr_SetFromErrno(MemoryError) + PyErr_CheckSignals() this.offset = 0 this.length = length this._b_i = 0 + for i in range(length): + this._heads[i] = -1 + this._unshiftable.push_back(0) memset(&this._empty_token, 0, sizeof(TokenC)) this._empty_token.lex = &EMPTY_LEXEME + __dealloc__(): + free(this._heads) + void set_context_tokens(int* ids, int n) nogil: cdef int i, j if n == 1: @@ -134,20 +133,19 @@ cdef cppclass StateC: ids[i] = -1 int S(int i) nogil const: - cdef int stack_size = this._stack.size() - if i >= stack_size or i < 0: + if i >= this._stack.size(): return -1 - else: - return this._stack[stack_size - (i+1)] + elif i < 0: + return -1 + return this._stack.at(this._stack.size() - (i+1)) int B(int i) nogil const: - cdef int buf_size = this._rebuffer.size() if i < 0: return -1 - elif i < buf_size: - return this._rebuffer[buf_size - (i+1)] + elif i < this._rebuffer.size(): + return this._rebuffer.at(this._rebuffer.size() - (i+1)) else: - b_i = this._b_i + (i - buf_size) + b_i = this._b_i + (i - this._rebuffer.size()) if b_i >= this.length: return -1 else: @@ -246,7 +244,7 @@ cdef cppclass StateC: return 0 elif this._sent[word].sent_start == 1: return 1 - elif this._sent_starts.const_find(word) != this._sent_starts.const_end(): + elif this._sent_starts.count(word) >= 1: return 1 else: return 0 @@ -330,7 +328,7 @@ cdef cppclass StateC: if item >= this._unshiftable.size(): return 0 else: - return this._unshiftable[item] + return this._unshiftable.at(item) void set_reshiftable(int item) nogil: if item < this._unshiftable.size(): @@ -350,9 
+348,6 @@ cdef cppclass StateC: this._heads[child] = head void map_del_arc(unordered_map[int, vector[ArcC]]* heads_arcs, int h_i, int c_i) nogil: - cdef vector[ArcC]* arcs - cdef ArcC* arc - arcs_it = heads_arcs.find(h_i) if arcs_it == heads_arcs.end(): return @@ -361,12 +356,12 @@ cdef cppclass StateC: if arcs.size() == 0: return - arc = &arcs.back() + arc = arcs.back() if arc.head == h_i and arc.child == c_i: arcs.pop_back() else: for i in range(arcs.size()-1): - arc = &deref(arcs)[i] + arc = arcs.at(i) if arc.head == h_i and arc.child == c_i: arc.head = -1 arc.child = -1 @@ -406,11 +401,10 @@ cdef cppclass StateC: this._rebuffer = src._rebuffer this._sent_starts = src._sent_starts this._unshiftable = src._unshiftable - this._heads = src._heads + memcpy(this._heads, src._heads, this.length * sizeof(this._heads[0])) this._ents = src._ents this._left_arcs = src._left_arcs this._right_arcs = src._right_arcs this._b_i = src._b_i this.offset = src.offset this._empty_token = src._empty_token - this.history = src.history diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx index b2653bce3..cec3a38f5 100644 --- a/spacy/pipeline/_parser_internals/arc_eager.pyx +++ b/spacy/pipeline/_parser_internals/arc_eager.pyx @@ -779,8 +779,6 @@ cdef class ArcEager(TransitionSystem): return list(arcs) def has_gold(self, Example eg, start=0, end=None): - if end is not None and end < 0: - end = None for word in eg.y[start:end]: if word.dep != 0: return True @@ -865,7 +863,6 @@ cdef class ArcEager(TransitionSystem): state.print_state() ))) action.do(state.c, action.label) - state.c.history.push_back(i) break else: failed = False diff --git a/spacy/pipeline/_parser_internals/batch.pxd b/spacy/pipeline/_parser_internals/batch.pxd deleted file mode 100644 index 60734e549..000000000 --- a/spacy/pipeline/_parser_internals/batch.pxd +++ /dev/null @@ -1,2 +0,0 @@ -cdef class Batch: - pass diff --git 
a/spacy/pipeline/_parser_internals/batch.pyx b/spacy/pipeline/_parser_internals/batch.pyx deleted file mode 100644 index 91073b52e..000000000 --- a/spacy/pipeline/_parser_internals/batch.pyx +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Any - -TransitionSystem = Any # TODO - -cdef class Batch: - def advance(self, scores): - raise NotImplementedError - - def get_states(self): - raise NotImplementedError - - @property - def is_done(self): - raise NotImplementedError - - def get_unfinished_states(self): - raise NotImplementedError - - def __getitem__(self, i): - raise NotImplementedError - - def __len__(self): - raise NotImplementedError - - -class GreedyBatch(Batch): - def __init__(self, moves: TransitionSystem, states, golds): - self._moves = moves - self._states = states - self._next_states = [s for s in states if not s.is_final()] - - def advance(self, scores): - self._next_states = self._moves.transition_states(self._next_states, scores) - - def advance_with_actions(self, actions): - self._next_states = self._moves.apply_actions(self._next_states, actions) - - def get_states(self): - return self._states - - @property - def is_done(self): - return all(s.is_final() for s in self._states) - - def get_unfinished_states(self): - return [st for st in self._states if not st.is_final()] - - def __getitem__(self, i): - return self._states[i] - - def __len__(self): - return len(self._states) diff --git a/spacy/pipeline/_parser_internals/ner.pyx b/spacy/pipeline/_parser_internals/ner.pyx index 9220bb522..d6ee29397 100644 --- a/spacy/pipeline/_parser_internals/ner.pyx +++ b/spacy/pipeline/_parser_internals/ner.pyx @@ -156,7 +156,7 @@ cdef class BiluoPushDown(TransitionSystem): if token.ent_type: labels.add(token.ent_type_) return labels - + def move_name(self, int move, attr_t label): if move == OUT: return 'O' @@ -306,8 +306,6 @@ cdef class BiluoPushDown(TransitionSystem): for span in eg.y.spans.get(neg_key, []): if span.start >= start and span.end <= end: return True 
- if end is not None and end < 0: - end = None for word in eg.y[start:end]: if word.ent_iob != 0: return True @@ -643,7 +641,7 @@ cdef class Unit: cost += 1 break return cost - + cdef class Out: @staticmethod diff --git a/spacy/pipeline/_parser_internals/stateclass.pyx b/spacy/pipeline/_parser_internals/stateclass.pyx index bdb4d1cf0..fdb5004bb 100644 --- a/spacy/pipeline/_parser_internals/stateclass.pyx +++ b/spacy/pipeline/_parser_internals/stateclass.pyx @@ -19,10 +19,6 @@ cdef class StateClass: if self._borrowed != 1: del self.c - @property - def history(self): - return list(self.c.history) - @property def stack(self): return [self.S(i) for i in range(self.c.stack_depth())] @@ -179,6 +175,3 @@ cdef class StateClass: def clone(self, StateClass src): self.c.clone(src.c) - - def set_context_tokens(self, int[:, :] output, int row, int n_feats): - self.c.set_context_tokens(&output[row, 0], n_feats) diff --git a/spacy/pipeline/_parser_internals/transition_system.pxd b/spacy/pipeline/_parser_internals/transition_system.pxd index 08baed932..04cd10d88 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pxd +++ b/spacy/pipeline/_parser_internals/transition_system.pxd @@ -57,10 +57,3 @@ cdef class TransitionSystem: cdef int set_costs(self, int* is_valid, weight_t* costs, const StateC* state, gold) except -1 - - -cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions, - int batch_size) nogil - -cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores, - int nr_class, int batch_size) nogil diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx index aaafe2aa0..e084689bc 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pyx +++ b/spacy/pipeline/_parser_internals/transition_system.pyx @@ -2,16 +2,12 @@ from __future__ import print_function from cymem.cymem cimport Pool -from libc.stdlib cimport calloc, free -from 
libcpp.vector cimport vector from collections import Counter import srsly from ...structs cimport TokenC -from ...typedefs cimport attr_t, weight_t -from ._parser_utils cimport arg_max_if_valid from .stateclass cimport StateClass from ... import util @@ -76,18 +72,7 @@ cdef class TransitionSystem: offset += len(doc) return states - def follow_history(self, doc, history): - cdef int clas - cdef StateClass state = StateClass(doc) - for clas in history: - action = self.c[clas] - action.do(state.c, action.label) - state.c.history.push_back(clas) - return state - def get_oracle_sequence(self, Example example, _debug=False): - if not self.has_gold(example): - return [] states, golds, _ = self.init_gold_batch([example]) if not states: return [] @@ -99,8 +84,6 @@ cdef class TransitionSystem: return self.get_oracle_sequence_from_state(state, gold) def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None): - if state.is_final(): - return [] cdef Pool mem = Pool() # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc assert self.n_moves > 0 @@ -126,7 +109,6 @@ cdef class TransitionSystem: "S0 head?", str(state.has_head(state.S(0))), ))) action.do(state.c, action.label) - state.c.history.push_back(i) break else: if _debug: @@ -154,28 +136,6 @@ cdef class TransitionSystem: raise ValueError(Errors.E170.format(name=name)) action = self.lookup_transition(name) action.do(state.c, action.label) - state.c.history.push_back(action.clas) - - def apply_actions(self, states, const int[::1] actions): - assert len(states) == actions.shape[0] - cdef StateClass state - cdef vector[StateC*] c_states - c_states.resize(len(states)) - cdef int i - for (i, state) in enumerate(states): - c_states[i] = state.c - c_apply_actions(self, &c_states[0], &actions[0], actions.shape[0]) - return [state for state in states if not state.c.is_final()] - - def transition_states(self, states, float[:, ::1] scores): - assert len(states) == scores.shape[0] - 
cdef StateClass state - cdef float* c_scores = &scores[0, 0] - cdef vector[StateC*] c_states - for state in states: - c_states.push_back(state.c) - c_transition_batch(self, &c_states[0], c_scores, scores.shape[1], scores.shape[0]) - return [state for state in states if not state.c.is_final()] cdef Transition lookup_transition(self, object name) except *: raise NotImplementedError @@ -288,34 +248,3 @@ cdef class TransitionSystem: self.cfg.update(msg['cfg']) self.initialize_actions(labels) return self - - -cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions, - int batch_size) nogil: - cdef int i - cdef Transition action - cdef StateC* state - for i in range(batch_size): - state = states[i] - action = moves.c[actions[i]] - action.do(state, action.label) - state.history.push_back(action.clas) - - -cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores, - int nr_class, int batch_size) nogil: - is_valid = calloc(moves.n_moves, sizeof(int)) - cdef int i, guess - cdef Transition action - for i in range(batch_size): - moves.set_valid(is_valid, states[i]) - guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class) - if guess == -1: - # This shouldn't happen, but it's hard to raise an error here, - # and we don't want to infinite loop. So, force to end state. 
- states[i].force_final() - else: - action = moves.c[guess] - action.do(states[i], action.label) - states[i].history.push_back(guess) - free(is_valid) diff --git a/spacy/pipeline/dep_parser.py b/spacy/pipeline/dep_parser.pyx similarity index 96% rename from spacy/pipeline/dep_parser.py rename to spacy/pipeline/dep_parser.pyx index f2472451b..8ddcbecaf 100644 --- a/spacy/pipeline/dep_parser.py +++ b/spacy/pipeline/dep_parser.pyx @@ -4,6 +4,10 @@ from typing import Callable, Optional from thinc.api import Config, Model +from ._parser_internals.transition_system import TransitionSystem +from .transition_parser cimport Parser +from ._parser_internals.arc_eager cimport ArcEager + from ..language import Language from ..scorer import Scorer from ..training import remove_bilu_prefix @@ -17,11 +21,12 @@ from .transition_parser import Parser default_model_config = """ [model] -@architectures = "spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v2" state_type = "parser" extra_state_tokens = false hidden_width = 64 maxout_pieces = 2 +use_upper = true [model.tok2vec] @architectures = "spacy.HashEmbedCNN.v2" @@ -121,7 +126,6 @@ def make_parser( scorer=scorer, ) - @Language.factory( "beam_parser", assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"], @@ -227,7 +231,6 @@ def parser_score(examples, **kwargs): DOCS: https://spacy.io/api/dependencyparser#score """ - def has_sents(doc): return doc.has_annotation("SENT_START") @@ -235,11 +238,8 @@ def parser_score(examples, **kwargs): dep = getattr(token, attr) dep = token.vocab.strings.as_string(dep).lower() return dep - results = {} - results.update( - Scorer.score_spans(examples, "sents", has_annotation=has_sents, **kwargs) - ) + results.update(Scorer.score_spans(examples, "sents", has_annotation=has_sents, **kwargs)) kwargs.setdefault("getter", dep_getter) kwargs.setdefault("ignore_labels", ("p", "punct")) results.update(Scorer.score_deps(examples, "dep", **kwargs)) @@ -252,12 +252,11 
@@ def make_parser_scorer(): return parser_score -class DependencyParser(Parser): +cdef class DependencyParser(Parser): """Pipeline component for dependency parsing. DOCS: https://spacy.io/api/dependencyparser """ - TransitionSystem = ArcEager def __init__( @@ -277,7 +276,8 @@ class DependencyParser(Parser): incorrect_spans_key=None, scorer=parser_score, ): - """Create a DependencyParser.""" + """Create a DependencyParser. + """ super().__init__( vocab, model, diff --git a/spacy/pipeline/ner.py b/spacy/pipeline/ner.pyx similarity index 93% rename from spacy/pipeline/ner.py rename to spacy/pipeline/ner.pyx index 445ed7663..9ae5cd1ef 100644 --- a/spacy/pipeline/ner.py +++ b/spacy/pipeline/ner.pyx @@ -10,15 +10,22 @@ from ..training import remove_bilu_prefix from ..util import registry from ._parser_internals.ner import BiluoPushDown from ._parser_internals.transition_system import TransitionSystem -from .transition_parser import Parser +from .transition_parser cimport Parser +from ._parser_internals.ner cimport BiluoPushDown +from ..language import Language +from ..scorer import get_ner_prf, PRFScore +from ..util import registry +from ..training import remove_bilu_prefix + default_model_config = """ [model] -@architectures = "spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v2" state_type = "ner" extra_state_tokens = false hidden_width = 64 maxout_pieces = 2 +use_upper = true [model.tok2vec] @architectures = "spacy.HashEmbedCNN.v2" @@ -43,12 +50,8 @@ DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"] "incorrect_spans_key": None, "scorer": {"@scorers": "spacy.ner_scorer.v1"}, }, - default_score_weights={ - "ents_f": 1.0, - "ents_p": 0.0, - "ents_r": 0.0, - "ents_per_type": None, - }, + default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None}, + ) def make_ner( nlp: Language, @@ -101,7 +104,6 @@ def make_ner( scorer=scorer, ) - @Language.factory( "beam_ner", assigns=["doc.ents", 
"token.ent_iob", "token.ent_type"], @@ -115,12 +117,7 @@ def make_ner( "incorrect_spans_key": None, "scorer": None, }, - default_score_weights={ - "ents_f": 1.0, - "ents_p": 0.0, - "ents_r": 0.0, - "ents_per_type": None, - }, + default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None}, ) def make_beam_ner( nlp: Language, @@ -194,12 +191,11 @@ def make_ner_scorer(): return ner_score -class EntityRecognizer(Parser): +cdef class EntityRecognizer(Parser): """Pipeline component for named entity recognition. DOCS: https://spacy.io/api/entityrecognizer """ - TransitionSystem = BiluoPushDown def __init__( @@ -217,14 +213,15 @@ class EntityRecognizer(Parser): incorrect_spans_key=None, scorer=ner_score, ): - """Create an EntityRecognizer.""" + """Create an EntityRecognizer. + """ super().__init__( vocab, model, name, moves, update_with_oracle_cut_size=update_with_oracle_cut_size, - min_action_freq=1, # not relevant for NER + min_action_freq=1, # not relevant for NER learn_tokens=False, # not relevant for NER beam_width=beam_width, beam_density=beam_density, diff --git a/spacy/pipeline/transition_parser.pxd b/spacy/pipeline/transition_parser.pxd new file mode 100644 index 000000000..f20e69a6e --- /dev/null +++ b/spacy/pipeline/transition_parser.pxd @@ -0,0 +1,21 @@ +from cymem.cymem cimport Pool +from thinc.backends.cblas cimport CBlas + +from ..vocab cimport Vocab +from .trainable_pipe cimport TrainablePipe +from ._parser_internals.transition_system cimport Transition, TransitionSystem +from ._parser_internals._state cimport StateC +from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC + + +cdef class Parser(TrainablePipe): + cdef public object _rehearsal_model + cdef readonly TransitionSystem moves + cdef public object _multitasks + cdef object _cpu_ops + + cdef void _parseC(self, CBlas cblas, StateC** states, + WeightsC weights, SizesC sizes) nogil + + cdef void c_transition_batch(self, StateC** states, const float* scores, + int 
nr_class, int batch_size) nogil diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 1231e439e..cc54111f7 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -1,15 +1,20 @@ # cython: infer_types=True, cdivision=True, boundscheck=False, binding=True from __future__ import print_function - from typing import Dict, Iterable, List, Optional, Tuple - -cimport numpy as np from cymem.cymem cimport Pool - -import contextlib -import random +cimport numpy as np from itertools import islice +from libcpp.vector cimport vector +from libc.string cimport memset, memcpy +from libc.stdlib cimport calloc, free +import random +import srsly +from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps, Optimizer +from thinc.api import chain, softmax_activation, use_ops +from thinc.legacy import LegacySequenceCategoricalCrossentropy +from thinc.types import Floats2d +import numpy.random import numpy import numpy.random import srsly @@ -23,7 +28,16 @@ from thinc.api import ( ) from thinc.types import Floats2d, Ints1d -from ..ml.tb_framework import TransitionModelInputs +from ._parser_internals.stateclass cimport StateClass +from ._parser_internals.search cimport Beam +from ..ml.parser_model cimport alloc_activations, free_activations +from ..ml.parser_model cimport predict_states, arg_max_if_valid +from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss +from ..ml.parser_model cimport get_c_weights, get_c_sizes +from ..tokens.doc cimport Doc +from .trainable_pipe import TrainablePipe +from ._parser_internals cimport _beam_utils +from ._parser_internals import _beam_utils from ..tokens.doc cimport Doc from ..typedefs cimport weight_t @@ -46,7 +60,7 @@ from ._parser_internals import _beam_utils NUMPY_OPS = NumpyOps() -class Parser(TrainablePipe): +cdef class Parser(TrainablePipe): """ Base class of the DependencyParser and EntityRecognizer. 
""" @@ -146,9 +160,8 @@ class Parser(TrainablePipe): @property def move_names(self): names = [] - cdef TransitionSystem moves = self.moves for i in range(self.moves.n_moves): - name = self.moves.move_name(moves.c[i].move, moves.c[i].label) + name = self.moves.move_name(self.moves.c[i].move, self.moves.c[i].label) # Explicitly removing the internal "U-" token used for blocking entities if name != "U-": names.append(name) @@ -255,6 +268,15 @@ class Parser(TrainablePipe): student_docs = [eg.predicted for eg in examples] + teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples]) + student_step_model, backprop_tok2vec = self.model.begin_update(student_docs) + + # Add softmax activation, so that we can compute student losses + # with cross-entropy loss. + with use_ops("numpy"): + teacher_model = chain(teacher_step_model, softmax_activation()) + student_model = chain(student_step_model, softmax_activation()) + max_moves = self.cfg["update_with_oracle_cut_size"] if max_moves >= 1: # Chop sequences into lengths of this many words, to make the @@ -262,38 +284,50 @@ class Parser(TrainablePipe): # sequence, we use the teacher's predictions as the gold # standard. max_moves = int(random.uniform(max_moves // 2, max_moves * 2)) - states = self._init_batch(teacher_pipe, student_docs, max_moves) + states = self._init_batch(teacher_step_model, student_docs, max_moves) else: states = self.moves.init_batch(student_docs) - # We distill as follows: 1. we first let the student predict transition - # sequences (and the corresponding transition probabilities); (2) we - # let the teacher follow the student's predicted transition sequences - # to obtain the teacher's transition probabilities; (3) we compute the - # gradients of the student's transition distributions relative to the - # teacher's distributions. 
+ loss = 0.0 + n_moves = 0 + while states: + # We do distillation as follows: (1) for every state, we compute the + # transition softmax distributions: (2) we backpropagate the error of + # the student (compared to the teacher) into the student model; (3) + # for all states, we move to the next state using the student's + # predictions. + teacher_scores = teacher_model.predict(states) + student_scores, backprop = student_model.begin_update(states) + state_loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores) + backprop(d_scores) + loss += state_loss + self.transition_states(states, student_scores) + states = [state for state in states if not state.is_final()] - student_inputs = TransitionModelInputs(docs=student_docs, moves=self.moves, - max_moves=max_moves) - (student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs) - actions = states2actions(student_states) - teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples], - moves=self.moves, actions=actions) - (_, teacher_scores) = teacher_pipe.model.predict(teacher_inputs) + # Stop when we reach the maximum number of moves, otherwise we start + # to process the remainder of cut sequences again. 
+ if max_moves >= 1 and n_moves >= max_moves: + break + n_moves += 1 - loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores) - backprop_scores((student_states, d_scores)) + backprop_tok2vec(student_docs) if sgd is not None: self.finish_update(sgd) losses[self.name] += loss + del backprop + del backprop_tok2vec + teacher_step_model.clear_memory() + student_step_model.clear_memory() + del teacher_model + del student_model + return losses def get_teacher_student_loss( - self, teacher_scores: List[Floats2d], student_scores: List[Floats2d], - normalize: bool = False, + self, teacher_scores: List[Floats2d], student_scores: List[Floats2d] ) -> Tuple[float, List[Floats2d]]: """Calculate the loss and its gradient for a batch of student scores, relative to teacher scores. @@ -305,28 +339,10 @@ class Parser(TrainablePipe): DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss """ - - # We can't easily hook up a softmax layer in the parsing model, since - # the get_loss does additional masking. So, we could apply softmax - # manually here and use Thinc's cross-entropy loss. But it's a bit - # suboptimal, since we can have a lot of states that would result in - # many kernel launches. Futhermore the parsing model's backprop expects - # a XP array, so we'd have to concat the softmaxes anyway. So, like - # the get_loss implementation, we'll compute the loss and gradients - # ourselves. 
- - teacher_scores = self.model.ops.softmax(self.model.ops.xp.vstack(teacher_scores), - axis=-1, inplace=True) - student_scores = self.model.ops.softmax(self.model.ops.xp.vstack(student_scores), - axis=-1, inplace=True) - - assert teacher_scores.shape == student_scores.shape - - d_scores = student_scores - teacher_scores - if normalize: - d_scores /= d_scores.shape[0] - loss = (d_scores**2).sum() / d_scores.size - + loss_func = LegacySequenceCategoricalCrossentropy(normalize=False) + d_scores, loss = loss_func(student_scores, teacher_scores) + if self.model.ops.xp.isnan(loss): + raise ValueError(Errors.E910.format(name=self.name)) return float(loss), d_scores def init_multitask_objectives(self, get_examples, pipeline, **cfg): @@ -349,6 +365,9 @@ class Parser(TrainablePipe): stream: The sequence of documents to process. batch_size (int): Number of documents to accumulate into a working set. + error_handler (Callable[[str, List[Doc], Exception], Any]): Function that + deals with a failing batch of documents. The default function just reraises + the exception. YIELDS (Doc): Documents, in order. 
""" @@ -369,29 +388,78 @@ class Parser(TrainablePipe): def predict(self, docs): if isinstance(docs, Doc): docs = [docs] - self._ensure_labels_are_added(docs) if not any(len(doc) for doc in docs): result = self.moves.init_batch(docs) return result - with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]): - inputs = TransitionModelInputs(docs=docs, moves=self.moves) - states_or_beams, _ = self.model.predict(inputs) - return states_or_beams + if self.cfg["beam_width"] == 1: + return self.greedy_parse(docs, drop=0.0) + else: + return self.beam_parse( + docs, + drop=0.0, + beam_width=self.cfg["beam_width"], + beam_density=self.cfg["beam_density"] + ) def greedy_parse(self, docs, drop=0.): - self._resize() + cdef vector[StateC*] states + cdef StateClass state + cdef CBlas cblas = self._cpu_ops.cblas() self._ensure_labels_are_added(docs) - with _change_attrs(self.model, beam_width=1): - inputs = TransitionModelInputs(docs=docs, moves=self.moves) - states, _ = self.model.predict(inputs) - return states + set_dropout_rate(self.model, drop) + batch = self.moves.init_batch(docs) + model = self.model.predict(docs) + weights = get_c_weights(model) + for state in batch: + if not state.is_final(): + states.push_back(state.c) + sizes = get_c_sizes(model, states.size()) + with nogil: + self._parseC(cblas, &states[0], weights, sizes) + model.clear_memory() + del model + return batch def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.): + cdef Beam beam + cdef Doc doc self._ensure_labels_are_added(docs) - with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]): - inputs = TransitionModelInputs(docs=docs, moves=self.moves) - beams, _ = self.model.predict(inputs) - return beams + batch = _beam_utils.BeamBatch( + self.moves, + self.moves.init_batch(docs), + None, + beam_width, + density=beam_density + ) + model = self.model.predict(docs) + while not batch.is_done: + 
states = batch.get_unfinished_states() + if not states: + break + scores = model.predict(states) + batch.advance(scores) + model.clear_memory() + del model + return list(batch) + + cdef void _parseC(self, CBlas cblas, StateC** states, + WeightsC weights, SizesC sizes) nogil: + cdef int i, j + cdef vector[StateC*] unfinished + cdef ActivationsC activations = alloc_activations(sizes) + while sizes.states >= 1: + predict_states(cblas, &activations, states, &weights, sizes) + # Validate actions, argmax, take action. + self.c_transition_batch(states, + activations.scores, sizes.classes, sizes.states) + for i in range(sizes.states): + if not states[i].is_final(): + unfinished.push_back(states[i]) + for i in range(unfinished.size()): + states[i] = unfinished[i] + sizes.states = unfinished.size() + unfinished.clear() + free_activations(&activations) def set_annotations(self, docs, states_or_beams): cdef StateClass state @@ -402,6 +470,35 @@ class Parser(TrainablePipe): for hook in self.postprocesses: hook(doc) + def transition_states(self, states, float[:, ::1] scores): + cdef StateClass state + cdef float* c_scores = &scores[0, 0] + cdef vector[StateC*] c_states + for state in states: + c_states.push_back(state.c) + self.c_transition_batch(&c_states[0], c_scores, scores.shape[1], scores.shape[0]) + return [state for state in states if not state.c.is_final()] + + cdef void c_transition_batch(self, StateC** states, const float* scores, + int nr_class, int batch_size) nogil: + # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc + with gil: + assert self.moves.n_moves > 0, Errors.E924.format(name=self.name) + is_valid = calloc(self.moves.n_moves, sizeof(int)) + cdef int i, guess + cdef Transition action + for i in range(batch_size): + self.moves.set_valid(is_valid, states[i]) + guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class) + if guess == -1: + # This shouldn't happen, but it's hard to raise an error here, + # and we 
don't want to infinite loop. So, force to end state. + states[i].force_final() + else: + action = self.moves.c[guess] + action.do(states[i], action.label) + free(is_valid) + def update(self, examples, *, drop=0., sgd=None, losses=None): if losses is None: losses = {} @@ -412,98 +509,66 @@ class Parser(TrainablePipe): ) for multitask in self._multitasks: multitask.update(examples, drop=drop, sgd=sgd) - # We need to take care to act on the whole batch, because we might be - # getting vectors via a listener. n_examples = len([eg for eg in examples if self.moves.has_gold(eg)]) if n_examples == 0: return losses set_dropout_rate(self.model, drop) - docs = [eg.x for eg in examples if len(eg.x)] - + # The probability we use beam update, instead of falling back to + # a greedy update + beam_update_prob = self.cfg["beam_update_prob"] + if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob: + return self.update_beam( + examples, + beam_width=self.cfg["beam_width"], + sgd=sgd, + losses=losses, + beam_density=self.cfg["beam_density"] + ) max_moves = self.cfg["update_with_oracle_cut_size"] if max_moves >= 1: # Chop sequences into lengths of this many words, to make the # batch uniform length. 
- max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2)) - init_states, gold_states, _ = self._init_gold_batch( + max_moves = int(random.uniform(max_moves // 2, max_moves * 2)) + states, golds, _ = self._init_gold_batch( examples, max_length=max_moves ) else: - init_states, gold_states, _ = self.moves.init_gold_batch(examples) - - inputs = TransitionModelInputs(docs=docs, - moves=self.moves, - max_moves=max_moves, - states=[state.copy() for state in init_states]) - (pred_states, scores), backprop_scores = self.model.begin_update(inputs) - if sum(s.shape[0] for s in scores) == 0: + states, golds, _ = self.moves.init_gold_batch(examples) + if not states: return losses - d_scores = self.get_loss((gold_states, init_states, pred_states, scores), - examples, max_moves) - backprop_scores((pred_states, d_scores)) - if sgd not in (None, False): - self.finish_update(sgd) - losses[self.name] += float((d_scores**2).sum()) - # Ugh, this is annoying. If we're working on GPU, we want to free the - # memory ASAP. It seems that Python doesn't necessarily get around to - # removing these in time if we don't explicitly delete? It's confusing. 
- del backprop_scores - return losses - - def get_loss(self, states_scores, examples, max_moves): - gold_states, init_states, pred_states, scores = states_scores - scores = self.model.ops.xp.vstack(scores) - costs = self._get_costs_from_histories( - examples, - gold_states, - init_states, - [list(state.history) for state in pred_states], - max_moves - ) - xp = get_array_module(scores) - best_costs = costs.min(axis=1, keepdims=True) - gscores = scores.copy() - min_score = scores.min() - 1000 - assert costs.shape == scores.shape, (costs.shape, scores.shape) - gscores[costs > best_costs] = min_score - max_ = scores.max(axis=1, keepdims=True) - gmax = gscores.max(axis=1, keepdims=True) - exp_scores = xp.exp(scores - max_) - exp_gscores = xp.exp(gscores - gmax) - Z = exp_scores.sum(axis=1, keepdims=True) - gZ = exp_gscores.sum(axis=1, keepdims=True) - d_scores = exp_scores / Z - d_scores -= (costs <= best_costs) * (exp_gscores / gZ) - return d_scores - - def _get_costs_from_histories(self, examples, gold_states, init_states, histories, max_moves): - cdef TransitionSystem moves = self.moves - cdef StateClass state - cdef int clas - cdef int nO = moves.n_moves - cdef Pool mem = Pool() - cdef np.ndarray costs_i - is_valid = mem.alloc(nO, sizeof(int)) - batch = list(zip(init_states, histories, gold_states)) + model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples]) + + all_states = list(states) + states_golds = list(zip(states, golds)) n_moves = 0 - output = [] - while batch: - costs = numpy.zeros((len(batch), nO), dtype="f") - for i, (state, history, gold) in enumerate(batch): - costs_i = costs[i] - clas = history.pop(0) - moves.set_costs(is_valid, costs_i.data, state.c, gold) - action = moves.c[clas] - action.do(state.c, action.label) - state.c.history.push_back(clas) - output.append(costs) - batch = [(s, h, g) for s, h, g in batch if len(h) != 0] - if n_moves >= max_moves >= 1: + while states_golds: + states, golds = zip(*states_golds) + scores, 
backprop = model.begin_update(states) + d_scores = self.get_batch_loss(states, golds, scores, losses) + # Note that the gradient isn't normalized by the batch size + # here, because our "samples" are really the states...But we + # can't normalize by the number of states either, as then we'd + # be getting smaller gradients for states in long sequences. + backprop(d_scores) + # Follow the predicted action + self.transition_states(states, scores) + states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()] + if max_moves >= 1 and n_moves >= max_moves: break n_moves += 1 - return self.model.ops.xp.vstack(output) + backprop_tok2vec(golds) + if sgd not in (None, False): + self.finish_update(sgd) + # Ugh, this is annoying. If we're working on GPU, we want to free the + # memory ASAP. It seems that Python doesn't necessarily get around to + # removing these in time if we don't explicitly delete? It's confusing. + del backprop + del backprop_tok2vec + model.clear_memory() + del model + return losses def rehearse(self, examples, sgd=None, losses=None, **cfg): """Perform a "rehearsal" update, to prevent catastrophic forgetting.""" @@ -514,9 +579,10 @@ class Parser(TrainablePipe): multitask.rehearse(examples, losses=losses, sgd=sgd) if self._rehearsal_model is None: return None - losses.setdefault(self.name, 0.0) + losses.setdefault(self.name, 0.) validate_examples(examples, "Parser.rehearse") docs = [eg.predicted for eg in examples] + states = self.moves.init_batch(docs) # This is pretty dirty, but the NER can resize itself in init_batch, # if labels are missing. We therefore have to check whether we need to # expand our model output. 
@@ -524,33 +590,85 @@ class Parser(TrainablePipe): # Prepare the stepwise model, and get the callback for finishing the batch set_dropout_rate(self._rehearsal_model, 0.0) set_dropout_rate(self.model, 0.0) - student_inputs = TransitionModelInputs(docs=docs, moves=self.moves) - (student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs) - actions = states2actions(student_states) - teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions) - _, teacher_scores = self._rehearsal_model.predict(teacher_inputs) - - loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores, normalize=True) - - teacher_scores = self.model.ops.xp.vstack(teacher_scores) - student_scores = self.model.ops.xp.vstack(student_scores) - assert teacher_scores.shape == student_scores.shape - - d_scores = (student_scores - teacher_scores) / teacher_scores.shape[0] - # If all weights for an output are 0 in the original model, don't - # supervise that output. This allows us to add classes. - loss = (d_scores**2).sum() / d_scores.size - backprop_scores((student_states, d_scores)) - + tutor, _ = self._rehearsal_model.begin_update(docs) + model, backprop_tok2vec = self.model.begin_update(docs) + n_scores = 0. + loss = 0. + while states: + targets, _ = tutor.begin_update(states) + guesses, backprop = model.begin_update(states) + d_scores = (guesses - targets) / targets.shape[0] + # If all weights for an output are 0 in the original model, don't + # supervise that output. This allows us to add classes. 
+ loss += (d_scores**2).sum() + backprop(d_scores) + # Follow the predicted action + self.transition_states(states, guesses) + states = [state for state in states if not state.is_final()] + n_scores += d_scores.size + # Do the backprop + backprop_tok2vec(docs) if sgd is not None: self.finish_update(sgd) - losses[self.name] += loss - + losses[self.name] += loss / n_scores + del backprop + del backprop_tok2vec + model.clear_memory() + tutor.clear_memory() + del model + del tutor return losses - def update_beam(self, examples, *, beam_width, drop=0., - sgd=None, losses=None, beam_density=0.0): - raise NotImplementedError + def update_beam(self, examples, *, beam_width, + drop=0., sgd=None, losses=None, beam_density=0.0): + states, golds, _ = self.moves.init_gold_batch(examples) + if not states: + return losses + # Prepare the stepwise model, and get the callback for finishing the batch + model, backprop_tok2vec = self.model.begin_update( + [eg.predicted for eg in examples]) + loss = _beam_utils.update_beam( + self.moves, + states, + golds, + model, + beam_width, + beam_density=beam_density, + ) + losses[self.name] += loss + backprop_tok2vec(golds) + if sgd is not None: + self.finish_update(sgd) + + def get_batch_loss(self, states, golds, float[:, ::1] scores, losses): + cdef StateClass state + cdef Pool mem = Pool() + cdef int i + + # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc + assert self.moves.n_moves > 0, Errors.E924.format(name=self.name) + + is_valid = mem.alloc(self.moves.n_moves, sizeof(int)) + costs = mem.alloc(self.moves.n_moves, sizeof(float)) + cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves), + dtype='f', order='C') + c_d_scores = d_scores.data + unseen_classes = self.model.attrs["unseen_classes"] + for i, (state, gold) in enumerate(zip(states, golds)): + memset(is_valid, 0, self.moves.n_moves * sizeof(int)) + memset(costs, 0, self.moves.n_moves * sizeof(float)) + 
self.moves.set_costs(is_valid, costs, state.c, gold) + for j in range(self.moves.n_moves): + if costs[j] <= 0.0 and j in unseen_classes: + unseen_classes.remove(j) + cpu_log_loss(c_d_scores, + costs, is_valid, &scores[i, 0], d_scores.shape[1]) + c_d_scores += d_scores.shape[1] + # Note that we don't normalize this. See comment in update() for why. + if losses is not None: + losses.setdefault(self.name, 0.) + losses[self.name] += (d_scores**2).sum() + return d_scores def set_output(self, nO): self.model.attrs["resize_output"](self.model, nO) @@ -589,7 +707,7 @@ class Parser(TrainablePipe): for example in islice(get_examples(), 10): doc_sample.append(example.predicted) assert len(doc_sample) > 0, Errors.E923.format(name=self.name) - self.model.initialize((doc_sample, self.moves)) + self.model.initialize(doc_sample) if nlp is not None: self.init_multitask_objectives(get_examples, nlp.pipeline) @@ -682,27 +800,26 @@ class Parser(TrainablePipe): def _init_gold_batch(self, examples, max_length): """Make a square batch, of length equal to the shortest transition - sequence or a cap. A long doc will get multiple states. Let's say we - have a doc of length 2*N, where N is the shortest doc. We'll make - two states, one representing long_doc[:N], and another representing - long_doc[N:].""" + sequence or a cap. A long + doc will get multiple states. Let's say we have a doc of length 2*N, + where N is the shortest doc. 
We'll make two states, one representing + long_doc[:N], and another representing long_doc[N:].""" cdef: StateClass start_state StateClass state Transition action - TransitionSystem moves = self.moves - all_states = moves.init_batch([eg.predicted for eg in examples]) + all_states = self.moves.init_batch([eg.predicted for eg in examples]) states = [] golds = [] to_cut = [] for state, eg in zip(all_states, examples): - if moves.has_gold(eg) and not state.is_final(): - gold = moves.init_gold(state, eg) + if self.moves.has_gold(eg) and not state.is_final(): + gold = self.moves.init_gold(state, eg) if len(eg.x) < max_length: states.append(state) golds.append(gold) else: - oracle_actions = moves.get_oracle_sequence_from_state( + oracle_actions = self.moves.get_oracle_sequence_from_state( state.copy(), gold) to_cut.append((eg, state, gold, oracle_actions)) if not to_cut: @@ -712,52 +829,13 @@ class Parser(TrainablePipe): for i in range(0, len(oracle_actions), max_length): start_state = state.copy() for clas in oracle_actions[i:i+max_length]: - action = moves.c[clas] + action = self.moves.c[clas] action.do(state.c, action.label) if state.is_final(): break - if moves.has_gold(eg, start_state.B(0), state.B(0)): + if self.moves.has_gold(eg, start_state.B(0), state.B(0)): states.append(start_state) golds.append(gold) if state.is_final(): break return states, golds, max_length - - -@contextlib.contextmanager -def _change_attrs(model, **kwargs): - """Temporarily modify a thinc model's attributes.""" - unset = object() - old_attrs = {} - for key, value in kwargs.items(): - old_attrs[key] = model.attrs.get(key, unset) - model.attrs[key] = value - yield model - for key, value in old_attrs.items(): - if value is unset: - model.attrs.pop(key) - else: - model.attrs[key] = value - - -def states2actions(states: List[StateClass]) -> List[Ints1d]: - cdef int step - cdef StateClass state - cdef StateC* c_state - actions = [] - while True: - step = len(actions) - - step_actions = [] - for 
state in states: - c_state = state.c - if step < c_state.history.size(): - step_actions.append(c_state.history[step]) - - # We are done if we have exhausted all histories. - if len(step_actions) == 0: - break - - actions.append(numpy.array(step_actions, dtype="i")) - - return actions diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py index bb9b7653c..5d5fbb027 100644 --- a/spacy/tests/parser/test_ner.py +++ b/spacy/tests/parser/test_ner.py @@ -17,6 +17,7 @@ from spacy.pipeline.ner import DEFAULT_NER_MODEL from spacy.tokens import Doc, Span from spacy.training import Example, iob_to_biluo, split_bilu_label from spacy.vocab import Vocab +import logging from ..util import make_tempdir @@ -413,7 +414,7 @@ def test_train_empty(): train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) ner = nlp.add_pipe("ner", last=True) ner.add_label("PERSON") - nlp.initialize(get_examples=lambda: train_examples) + nlp.initialize() for itn in range(2): losses = {} batches = util.minibatch(train_examples, size=8) @@ -540,11 +541,11 @@ def test_block_ner(): assert [token.ent_type_ for token in doc] == expected_types -def test_overfitting_IO(): - fix_random_seed(1) +@pytest.mark.parametrize("use_upper", [True, False]) +def test_overfitting_IO(use_upper): # Simple test to try and quickly overfit the NER component nlp = English() - ner = nlp.add_pipe("ner", config={"model": {}}) + ner = nlp.add_pipe("ner", config={"model": {"use_upper": use_upper}}) train_examples = [] for text, annotations in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(text), annotations)) @@ -576,6 +577,7 @@ def test_overfitting_IO(): assert ents2[0].label_ == "LOC" # Ensure that the predictions are still the same, even after adding a new label ner2 = nlp2.get_pipe("ner") + assert ner2.model.attrs["has_upper"] == use_upper ner2.add_label("RANDOM_NEW_LABEL") doc3 = nlp2(test_text) ents3 = doc3.ents diff --git a/spacy/tests/parser/test_parse.py 
b/spacy/tests/parser/test_parse.py index 2f6c77ba8..42cf5ced9 100644 --- a/spacy/tests/parser/test_parse.py +++ b/spacy/tests/parser/test_parse.py @@ -1,6 +1,3 @@ -import itertools - -import numpy import pytest from numpy.testing import assert_equal from thinc.api import Adam, fix_random_seed @@ -62,8 +59,6 @@ PARTIAL_DATA = [ ), ] -PARSERS = ["parser"] # TODO: Test beam_parser when ready - eps = 0.1 @@ -176,57 +171,6 @@ def test_parser_parse_one_word_sentence(en_vocab, en_parser, words): assert doc[0].dep != 0 -def test_parser_apply_actions(en_vocab, en_parser): - words = ["I", "ate", "pizza"] - words2 = ["Eat", "more", "pizza", "!"] - doc1 = Doc(en_vocab, words=words) - doc2 = Doc(en_vocab, words=words2) - docs = [doc1, doc2] - - moves = en_parser.moves - moves.add_action(0, "") - moves.add_action(1, "") - moves.add_action(2, "nsubj") - moves.add_action(3, "obj") - moves.add_action(2, "amod") - - actions = [ - numpy.array([0, 0], dtype="i"), - numpy.array([2, 0], dtype="i"), - numpy.array([0, 4], dtype="i"), - numpy.array([3, 3], dtype="i"), - numpy.array([1, 1], dtype="i"), - numpy.array([1, 1], dtype="i"), - numpy.array([0], dtype="i"), - numpy.array([1], dtype="i"), - ] - - states = moves.init_batch(docs) - active_states = states - - for step_actions in actions: - active_states = moves.apply_actions(active_states, step_actions) - - assert len(active_states) == 0 - - for state, doc in zip(states, docs): - moves.set_annotations(state, doc) - - assert docs[0][0].head.i == 1 - assert docs[0][0].dep_ == "nsubj" - assert docs[0][1].head.i == 1 - assert docs[0][1].dep_ == "ROOT" - assert docs[0][2].head.i == 1 - assert docs[0][2].dep_ == "obj" - - assert docs[1][0].head.i == 0 - assert docs[1][0].dep_ == "ROOT" - assert docs[1][1].head.i == 2 - assert docs[1][1].dep_ == "amod" - assert docs[1][2].head.i == 0 - assert docs[1][2].dep_ == "obj" - - @pytest.mark.skip( reason="The step_through API was removed (but should be brought back)" ) @@ -375,7 +319,7 @@ def 
test_parser_constructor(en_vocab): DependencyParser(en_vocab, model) -@pytest.mark.parametrize("pipe_name", PARSERS) +@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"]) def test_incomplete_data(pipe_name): # Test that the parser works with incomplete information nlp = English() @@ -401,15 +345,11 @@ def test_incomplete_data(pipe_name): assert doc[2].head.i == 1 -@pytest.mark.parametrize( - "pipe_name,max_moves", itertools.product(PARSERS, [0, 1, 5, 100]) -) -def test_overfitting_IO(pipe_name, max_moves): - fix_random_seed(0) +@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"]) +def test_overfitting_IO(pipe_name): # Simple test to try and quickly overfit the dependency parser (normal or beam) nlp = English() parser = nlp.add_pipe(pipe_name) - parser.cfg["update_with_oracle_cut_size"] = max_moves train_examples = [] for text, annotations in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(text), annotations)) @@ -511,12 +451,10 @@ def test_distill(): @pytest.mark.parametrize( "parser_config", [ - # TODO: re-enable after we have a spacy-legacy release for v4. 
See - # https://github.com/explosion/spacy-legacy/pull/36 - #({"@architectures": "spacy.TransitionBasedParser.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}), + # TransitionBasedParser V1 + ({"@architectures": "spacy.TransitionBasedParser.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}), + # TransitionBasedParser V2 ({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}), - ({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": False}), - ({"@architectures": "spacy.TransitionBasedParser.v3", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2}), ], ) # fmt: on diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py index f6cefbc1f..c3c4bb6c6 100644 --- a/spacy/tests/pipeline/test_tok2vec.py +++ b/spacy/tests/pipeline/test_tok2vec.py @@ -384,7 +384,7 @@ cfg_string_multi = """ factory = "ner" [components.ner.model] - @architectures = "spacy.TransitionBasedParser.v3" + @architectures = "spacy.TransitionBasedParser.v2" [components.ner.model.tok2vec] @architectures = "spacy.Tok2VecListener.v1" diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index b351ea801..8a1c74ca9 100644 --- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -189,11 +189,33 @@ width = ${components.tok2vec.model.width} parser_config_string_upper = """ [model] -@architectures = 
"spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v2" state_type = "parser" extra_state_tokens = false hidden_width = 66 maxout_pieces = 2 +use_upper = true + +[model.tok2vec] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 333 +depth = 4 +embed_size = 5555 +window_size = 1 +maxout_pieces = 7 +subword_features = false +""" + + +parser_config_string_no_upper = """ +[model] +@architectures = "spacy.TransitionBasedParser.v2" +state_type = "parser" +extra_state_tokens = false +hidden_width = 66 +maxout_pieces = 2 +use_upper = false [model.tok2vec] @architectures = "spacy.HashEmbedCNN.v1" @@ -224,6 +246,7 @@ def my_parser(): extra_state_tokens=True, hidden_width=65, maxout_pieces=5, + use_upper=True, ) return parser @@ -337,16 +360,15 @@ def test_serialize_custom_nlp(): nlp.to_disk(d) nlp2 = spacy.load(d) model = nlp2.get_pipe("parser").model - assert model.get_ref("tok2vec") is not None - assert model.has_param("hidden_W") - assert model.has_param("hidden_b") - output = model.get_ref("output") - assert output is not None - assert output.has_param("W") - assert output.has_param("b") + model.get_ref("tok2vec") + # check that we have the correct settings, not the default ones + assert model.get_ref("upper").get_dim("nI") == 65 + assert model.get_ref("lower").get_dim("nI") == 65 -@pytest.mark.parametrize("parser_config_string", [parser_config_string_upper]) +@pytest.mark.parametrize( + "parser_config_string", [parser_config_string_upper, parser_config_string_no_upper] +) def test_serialize_parser(parser_config_string): """Create a non-default parser config to check nlp serializes it correctly""" nlp = English() @@ -359,13 +381,11 @@ def test_serialize_parser(parser_config_string): nlp.to_disk(d) nlp2 = spacy.load(d) model = nlp2.get_pipe("parser").model - assert model.get_ref("tok2vec") is not None - assert model.has_param("hidden_W") - assert model.has_param("hidden_b") - output = model.get_ref("output") - assert 
output is not None - assert output.has_param("b") - assert output.has_param("W") + model.get_ref("tok2vec") + # check that we have the correct settings, not the default ones + if model.attrs["has_upper"]: + assert model.get_ref("upper").get_dim("nI") == 66 + assert model.get_ref("lower").get_dim("nI") == 66 def test_config_nlp_roundtrip(): @@ -561,7 +581,9 @@ def test_config_auto_fill_extra_fields(): load_model_from_config(nlp.config) -@pytest.mark.parametrize("parser_config_string", [parser_config_string_upper]) +@pytest.mark.parametrize( + "parser_config_string", [parser_config_string_upper, parser_config_string_no_upper] +) def test_config_validate_literal(parser_config_string): nlp = English() config = Config().from_str(parser_config_string) diff --git a/spacy/tests/test_misc.py b/spacy/tests/test_misc.py index cd1a25768..2016f9adc 100644 --- a/spacy/tests/test_misc.py +++ b/spacy/tests/test_misc.py @@ -1,19 +1,22 @@ import ctypes import os from pathlib import Path - import pytest -from pydantic import ValidationError -from thinc.api import ( - Config, - ConfigValidationError, - CupyOps, - MPSOps, - NumpyOps, - Optimizer, - get_current_ops, - set_current_ops, -) + +try: + from pydantic.v1 import ValidationError +except ImportError: + from pydantic import ValidationError # type: ignore + +from spacy.about import __version__ as spacy_version +from spacy import util +from spacy import prefer_gpu, require_gpu, require_cpu +from spacy.ml._precomputable_affine import PrecomputableAffine +from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding +from spacy.util import dot_to_object, SimpleFrozenList, import_file +from spacy.util import to_ternary_int, find_available_port +from thinc.api import Config, Optimizer, ConfigValidationError +from thinc.api import get_current_ops, set_current_ops, NumpyOps, CupyOps, MPSOps from thinc.compat import has_cupy_gpu, has_torch_mps_gpu from spacy import prefer_gpu, require_cpu, require_gpu, util @@ -92,6 
+95,34 @@ def test_util_get_package_path(package): assert isinstance(path, Path) +def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2): + model = PrecomputableAffine(nO=nO, nI=nI, nF=nF, nP=nP).initialize() + assert model.get_param("W").shape == (nF, nO, nP, nI) + tensor = model.ops.alloc((10, nI)) + Y, get_dX = model.begin_update(tensor) + assert Y.shape == (tensor.shape[0] + 1, nF, nO, nP) + dY = model.ops.alloc((15, nO, nP)) + ids = model.ops.alloc((15, nF)) + ids[1, 2] = -1 + dY[1] = 1 + assert not model.has_grad("pad") + d_pad = _backprop_precomputable_affine_padding(model, dY, ids) + assert d_pad[0, 2, 0, 0] == 1.0 + ids.fill(0.0) + dY.fill(0.0) + dY[0] = 0 + ids[1, 2] = 0 + ids[1, 1] = -1 + ids[1, 0] = -1 + dY[1] = 1 + ids[2, 0] = -1 + dY[2] = 5 + d_pad = _backprop_precomputable_affine_padding(model, dY, ids) + assert d_pad[0, 0, 0, 0] == 6 + assert d_pad[0, 1, 0, 0] == 1 + assert d_pad[0, 2, 0, 0] == 0 + + def test_prefer_gpu(): current_ops = get_current_ops() if has_cupy_gpu: diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx index efca4bcb0..0a184f3a2 100644 --- a/spacy/training/example.pyx +++ b/spacy/training/example.pyx @@ -1,5 +1,5 @@ from collections.abc import Iterable as IterableInstance - +import warnings import numpy from murmurhash.mrmr cimport hash64 diff --git a/website/docs/api/architectures.mdx b/website/docs/api/architectures.mdx index 3c3597d6a..bab24f13b 100644 --- a/website/docs/api/architectures.mdx +++ b/website/docs/api/architectures.mdx @@ -553,17 +553,18 @@ for a Tok2Vec layer. 
## Parser & NER architectures {id="parser"} -### spacy.TransitionBasedParser.v3 {id="TransitionBasedParser",source="spacy/ml/models/parser.py"} +### spacy.TransitionBasedParser.v2 {id="TransitionBasedParser",source="spacy/ml/models/parser.py"} > #### Example Config > > ```ini > [model] -> @architectures = "spacy.TransitionBasedParser.v3" +> @architectures = "spacy.TransitionBasedParser.v2" > state_type = "ner" > extra_state_tokens = false > hidden_width = 64 > maxout_pieces = 2 +> use_upper = true > > [model.tok2vec] > @architectures = "spacy.HashEmbedCNN.v2" @@ -593,22 +594,23 @@ consists of either two or three subnetworks: state representation. If not present, the output from the lower model is used as action scores directly. -| Name | Description | -| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ | -| `state_type` | Which task to extract features for. Possible values are "ner" and "parser". ~~str~~ | -| `extra_state_tokens` | Whether to use an expanded feature set when extracting the state tokens. Slightly slower, but sometimes improves accuracy slightly. Defaults to `False`. ~~bool~~ | -| `hidden_width` | The width of the hidden layer. ~~int~~ | -| `maxout_pieces` | How many pieces to use in the state prediction layer. Recommended values are `1`, `2` or `3`. ~~int~~ | -| `nO` | The number of actions the model will predict between. Usually inferred from data at the beginning of training, or loaded from disk. ~~int~~ | -| **CREATES** | The model using the architecture. 
~~Model[List[Docs], List[List[Floats2d]]]~~ | +| Name | Description | +| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ | +| `state_type` | Which task to extract features for. Possible values are "ner" and "parser". ~~str~~ | +| `extra_state_tokens` | Whether to use an expanded feature set when extracting the state tokens. Slightly slower, but sometimes improves accuracy slightly. Defaults to `False`. ~~bool~~ | +| `hidden_width` | The width of the hidden layer. ~~int~~ | +| `maxout_pieces` | How many pieces to use in the state prediction layer. Recommended values are `1`, `2` or `3`. If `1`, the maxout non-linearity is replaced with a [`Relu`](https://thinc.ai/docs/api-layers#relu) non-linearity if `use_upper` is `True`, and no non-linearity if `False`. ~~int~~ | +| `use_upper` | Whether to use an additional hidden layer after the state vector in order to predict the action scores. It is recommended to set this to `False` for large pretrained models such as transformers, and `True` for smaller networks. The upper layer is computed on CPU, which becomes a bottleneck on larger GPU-based models, where it's also less necessary. ~~bool~~ | +| `nO` | The number of actions the model will predict between. Usually inferred from data at the beginning of training, or loaded from disk. ~~int~~ | +| **CREATES** | The model using the architecture. ~~Model[List[Docs], List[List[Floats2d]]]~~ | [TransitionBasedParser.v1](/api/legacy#TransitionBasedParser_v1) had the exact same signature, but the `use_upper` argument was `True` by default. 
- + ## Tagging architectures {id="tagger",source="spacy/ml/models/tagger.py"} diff --git a/website/docs/api/cli.mdx b/website/docs/api/cli.mdx index 002b0b39d..c61295778 100644 --- a/website/docs/api/cli.mdx +++ b/website/docs/api/cli.mdx @@ -361,7 +361,7 @@ Module spacy.language File /path/to/spacy/language.py (line 64) ℹ [components.ner.model] Registry @architectures -Name spacy.TransitionBasedParser.v3 +Name spacy.TransitionBasedParser.v1 Module spacy.ml.models.parser File /path/to/spacy/ml/models/parser.py (line 11) ℹ [components.ner.model.tok2vec] @@ -371,7 +371,7 @@ Module spacy.ml.models.tok2vec File /path/to/spacy/ml/models/tok2vec.py (line 16) ℹ [components.parser.model] Registry @architectures -Name spacy.TransitionBasedParser.v3 +Name spacy.TransitionBasedParser.v1 Module spacy.ml.models.parser File /path/to/spacy/ml/models/parser.py (line 11) ℹ [components.parser.model.tok2vec] @@ -696,7 +696,7 @@ scorer = {"@scorers":"spacy.ner_scorer.v1"} update_with_oracle_cut_size = 100 [components.ner.model] -@architectures = "spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v2" state_type = "ner" extra_state_tokens = false - hidden_width = 64 @@ -719,7 +719,7 @@ scorer = {"@scorers":"spacy.parser_scorer.v1"} update_with_oracle_cut_size = 100 [components.parser.model] -@architectures = "spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v2" state_type = "parser" extra_state_tokens = false hidden_width = 128 diff --git a/website/docs/api/legacy.mdx b/website/docs/api/legacy.mdx index 70d6223e7..ea6d3a899 100644 --- a/website/docs/api/legacy.mdx +++ b/website/docs/api/legacy.mdx @@ -225,7 +225,7 @@ the others, but may not be as accurate, especially if texts are short. 
### spacy.TransitionBasedParser.v1 {id="TransitionBasedParser_v1"} Identical to -[`spacy.TransitionBasedParser.v3`](/api/architectures#TransitionBasedParser) +[`spacy.TransitionBasedParser.v2`](/api/architectures#TransitionBasedParser) except the `use_upper` was set to `True` by default. ## Layers {id="layers"} diff --git a/website/docs/usage/embeddings-transformers.mdx b/website/docs/usage/embeddings-transformers.mdx index 5504f94b6..5f1e5b817 100644 --- a/website/docs/usage/embeddings-transformers.mdx +++ b/website/docs/usage/embeddings-transformers.mdx @@ -140,7 +140,7 @@ factory = "tok2vec" factory = "ner" [components.ner.model] -@architectures = "spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v1" [components.ner.model.tok2vec] @architectures = "spacy.Tok2VecListener.v1" @@ -156,7 +156,7 @@ same. This makes them fully independent and doesn't require an upstream factory = "ner" [components.ner.model] -@architectures = "spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v1" [components.ner.model.tok2vec] @architectures = "spacy.Tok2Vec.v2" @@ -472,7 +472,7 @@ sneakily delegates to the `Transformer` pipeline component. factory = "ner" [nlp.pipeline.ner.model] -@architectures = "spacy.TransitionBasedParser.v3" +@architectures = "spacy.TransitionBasedParser.v1" state_type = "ner" extra_state_tokens = false hidden_width = 128