Revert "Merge the parser refactor into v4 (#10940)"

This reverts commit a183db3cef.
This commit is contained in:
Daniël de Kok 2023-12-08 20:23:08 +01:00
parent 05803cfe76
commit e5ec45cb7e
36 changed files with 1385 additions and 1311 deletions

View File

@ -33,10 +33,12 @@ MOD_NAMES = [
"spacy.kb.candidate",
"spacy.kb.kb",
"spacy.kb.kb_in_memory",
"spacy.ml.tb_framework",
"spacy.ml.parser_model",
"spacy.morphology",
"spacy.pipeline.dep_parser",
"spacy.pipeline._edit_tree_internals.edit_trees",
"spacy.pipeline.morphologizer",
"spacy.pipeline.ner",
"spacy.pipeline.pipe",
"spacy.pipeline.trainable_pipe",
"spacy.pipeline.sentencizer",
@ -44,7 +46,6 @@ MOD_NAMES = [
"spacy.pipeline.tagger",
"spacy.pipeline.transition_parser",
"spacy.pipeline._parser_internals.arc_eager",
"spacy.pipeline._parser_internals.batch",
"spacy.pipeline._parser_internals.ner",
"spacy.pipeline._parser_internals.nonproj",
"spacy.pipeline._parser_internals.search",
@ -52,7 +53,6 @@ MOD_NAMES = [
"spacy.pipeline._parser_internals.stateclass",
"spacy.pipeline._parser_internals.transition_system",
"spacy.pipeline._parser_internals._beam_utils",
"spacy.pipeline._parser_internals._parser_utils",
"spacy.tokenizer",
"spacy.training.align",
"spacy.training.gold_io",

View File

@ -90,11 +90,12 @@ grad_factor = 1.0
factory = "parser"
[components.parser.model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser"
extra_state_tokens = false
hidden_width = 128
maxout_pieces = 3
use_upper = false
nO = null
[components.parser.model.tok2vec]
@ -110,11 +111,12 @@ grad_factor = 1.0
factory = "ner"
[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = false
nO = null
[components.ner.model.tok2vec]
@ -383,11 +385,12 @@ width = ${components.tok2vec.model.encode.width}
factory = "parser"
[components.parser.model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser"
extra_state_tokens = false
hidden_width = 128
maxout_pieces = 3
use_upper = true
nO = null
[components.parser.model.tok2vec]
@ -400,11 +403,12 @@ width = ${components.tok2vec.model.encode.width}
factory = "ner"
[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = true
nO = null
[components.ner.model.tok2vec]

View File

@ -23,6 +23,11 @@ try:
except ImportError:
cupy = None
if sys.version_info[:2] >= (3, 8): # Python 3.8+
from typing import Literal, Protocol, runtime_checkable
else:
from typing_extensions import Literal, Protocol, runtime_checkable # noqa: F401
from thinc.api import Optimizer # noqa: F401
pickle = pickle

View File

@ -215,12 +215,6 @@ class Warnings(metaclass=ErrorsWithCodes):
"key attribute for vectors, configure it through Vectors(attr=) or "
"'spacy init vectors --attr'")
# v4 warning strings
W400 = ("`use_upper=False` is ignored, the upper layer is always enabled")
W401 = ("`incl_prior is True`, but the selected knowledge base type {kb_type} doesn't support prior probability "
"lookups so this setting will be ignored. If your KB does support prior probability lookups, make sure "
"to return `True` in `.supports_prior_probs`.")
class Errors(metaclass=ErrorsWithCodes):
E001 = ("No component '{name}' found in pipeline. Available names: {opts}")
@ -1000,6 +994,7 @@ class Errors(metaclass=ErrorsWithCodes):
E4011 = ("Server error ({status_code}), couldn't fetch {url}")
RENAMED_LANGUAGE_CODES = {"xx": "mul", "is": "isl"}
# fmt: on

View File

@ -0,0 +1,164 @@
from thinc.api import Model, normal_init
from ..util import registry
@registry.layers("spacy.PrecomputableAffine.v1")
def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1):
model = Model(
"precomputable_affine",
forward,
init=init,
dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP},
params={"W": None, "b": None, "pad": None},
attrs={"dropout_rate": dropout},
)
return model
def forward(model, X, is_train):
nF = model.get_dim("nF")
nO = model.get_dim("nO")
nP = model.get_dim("nP")
nI = model.get_dim("nI")
W = model.get_param("W")
# Preallocate array for layer output, including padding.
Yf = model.ops.alloc2f(X.shape[0] + 1, nF * nO * nP, zeros=False)
model.ops.gemm(X, W.reshape((nF * nO * nP, nI)), trans2=True, out=Yf[1:])
Yf = Yf.reshape((Yf.shape[0], nF, nO, nP))
# Set padding. Padding has shape (1, nF, nO, nP). Unfortunately, we cannot
# change its shape to (nF, nO, nP) without breaking existing models. So
# we'll squeeze the first dimension here.
Yf[0] = model.ops.xp.squeeze(model.get_param("pad"), 0)
def backward(dY_ids):
# This backprop is particularly tricky, because we get back a different
# thing from what we put out. We put out an array of shape:
# (nB, nF, nO, nP), and get back:
# (nB, nO, nP) and ids (nB, nF)
# The ids tell us the values of nF, so we would have:
#
# dYf = zeros((nB, nF, nO, nP))
# for b in range(nB):
# for f in range(nF):
# dYf[b, ids[b, f]] += dY[b]
#
# However, we avoid building that array for efficiency -- and just pass
# in the indices.
dY, ids = dY_ids
assert dY.ndim == 3
assert dY.shape[1] == nO, dY.shape
assert dY.shape[2] == nP, dY.shape
# nB = dY.shape[0]
model.inc_grad("pad", _backprop_precomputable_affine_padding(model, dY, ids))
Xf = X[ids]
Xf = Xf.reshape((Xf.shape[0], nF * nI))
model.inc_grad("b", dY.sum(axis=0))
dY = dY.reshape((dY.shape[0], nO * nP))
Wopfi = W.transpose((1, 2, 0, 3))
Wopfi = Wopfi.reshape((nO * nP, nF * nI))
dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi)
dWopfi = model.ops.gemm(dY, Xf, trans1=True)
dWopfi = dWopfi.reshape((nO, nP, nF, nI))
# (o, p, f, i) --> (f, o, p, i)
dWopfi = dWopfi.transpose((2, 0, 1, 3))
model.inc_grad("W", dWopfi)
return dXf.reshape((dXf.shape[0], nF, nI))
return Yf, backward
def _backprop_precomputable_affine_padding(model, dY, ids):
nB = dY.shape[0]
nF = model.get_dim("nF")
nP = model.get_dim("nP")
nO = model.get_dim("nO")
# Backprop the "padding", used as a filler for missing values.
# Values that are missing are set to -1, and each state vector could
# have multiple missing values. The padding has different values for
# different missing features. The gradient of the padding vector is:
#
# for b in range(nB):
# for f in range(nF):
# if ids[b, f] < 0:
# d_pad[f] += dY[b]
#
# Which can be rewritten as:
#
# (ids < 0).T @ dY
mask = model.ops.asarray(ids < 0, dtype="f")
d_pad = model.ops.gemm(mask, dY.reshape(nB, nO * nP), trans1=True)
return d_pad.reshape((1, nF, nO, nP))
def init(model, X=None, Y=None):
"""This is like the 'layer sequential unit variance', but instead
of taking the actual inputs, we randomly generate whitened data.
Why's this all so complicated? We have a huge number of inputs,
and the maxout unit makes guessing the dynamics tricky. Instead
we set the maxout weights to values that empirically result in
whitened outputs given whitened inputs.
"""
if model.has_param("W") and model.get_param("W").any():
return
nF = model.get_dim("nF")
nO = model.get_dim("nO")
nP = model.get_dim("nP")
nI = model.get_dim("nI")
W = model.ops.alloc4f(nF, nO, nP, nI)
b = model.ops.alloc2f(nO, nP)
pad = model.ops.alloc4f(1, nF, nO, nP)
ops = model.ops
W = normal_init(ops, W.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI)))
pad = normal_init(ops, pad.shape, mean=1.0)
model.set_param("W", W)
model.set_param("b", b)
model.set_param("pad", pad)
ids = ops.alloc((5000, nF), dtype="f")
ids += ops.xp.random.uniform(0, 1000, ids.shape)
ids = ops.asarray(ids, dtype="i")
tokvecs = ops.alloc((5000, nI), dtype="f")
tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape(
tokvecs.shape
)
def predict(ids, tokvecs):
# nS ids. nW tokvecs. Exclude the padding array.
hiddens = model.predict(tokvecs[:-1]) # (nW, f, o, p)
vectors = model.ops.alloc((ids.shape[0], nO * nP), dtype="f")
# need nS vectors
hiddens = hiddens.reshape((hiddens.shape[0] * nF, nO * nP))
model.ops.scatter_add(vectors, ids.flatten(), hiddens)
vectors = vectors.reshape((vectors.shape[0], nO, nP))
vectors += b
vectors = model.ops.asarray(vectors)
if nP >= 2:
return model.ops.maxout(vectors)[0]
else:
return vectors * (vectors >= 0)
tol_var = 0.01
tol_mean = 0.01
t_max = 10
W = model.get_param("W").copy()
b = model.get_param("b").copy()
for t_i in range(t_max):
acts1 = predict(ids, tokvecs)
var = model.ops.xp.var(acts1)
mean = model.ops.xp.mean(acts1)
if abs(var - 1.0) >= tol_var:
W /= model.ops.xp.sqrt(var)
model.set_param("W", W)
elif abs(mean) >= tol_mean:
b -= mean
model.set_param("b", b)
else:
break

View File

@ -1,66 +1,23 @@
import warnings
from typing import Any, List, Literal, Optional, Tuple
from thinc.api import Model
from typing import Optional, List, cast
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
from thinc.types import Floats2d
from ...errors import Errors, Warnings
from ...tokens.doc import Doc
from ...errors import Errors
from ...compat import Literal
from ...util import registry
from .._precomputable_affine import PrecomputableAffine
from ..tb_framework import TransitionModel
TransitionSystem = Any # TODO
State = Any # TODO
@registry.architectures.register("spacy.TransitionBasedParser.v2")
def transition_parser_v2(
tok2vec: Model[List[Doc], List[Floats2d]],
state_type: Literal["parser", "ner"],
extra_state_tokens: bool,
hidden_width: int,
maxout_pieces: int,
use_upper: bool,
nO: Optional[int] = None,
) -> Model:
if not use_upper:
warnings.warn(Warnings.W400)
return build_tb_parser_model(
tok2vec,
state_type,
extra_state_tokens,
hidden_width,
maxout_pieces,
nO=nO,
)
@registry.architectures.register("spacy.TransitionBasedParser.v3")
def transition_parser_v3(
tok2vec: Model[List[Doc], List[Floats2d]],
state_type: Literal["parser", "ner"],
extra_state_tokens: bool,
hidden_width: int,
maxout_pieces: int,
nO: Optional[int] = None,
) -> Model:
return build_tb_parser_model(
tok2vec,
state_type,
extra_state_tokens,
hidden_width,
maxout_pieces,
nO=nO,
)
from ...tokens import Doc
@registry.architectures("spacy.TransitionBasedParser.v2")
def build_tb_parser_model(
tok2vec: Model[List[Doc], List[Floats2d]],
state_type: Literal["parser", "ner"],
extra_state_tokens: bool,
hidden_width: int,
maxout_pieces: int,
use_upper: bool,
nO: Optional[int] = None,
) -> Model:
"""
@ -94,7 +51,14 @@ def build_tb_parser_model(
feature sets (for the NER) or 13 (for the parser).
hidden_width (int): The width of the hidden layer.
maxout_pieces (int): How many pieces to use in the state prediction layer.
Recommended values are 1, 2 or 3.
Recommended values are 1, 2 or 3. If 1, the maxout non-linearity
is replaced with a ReLu non-linearity if use_upper=True, and no
non-linearity if use_upper=False.
use_upper (bool): Whether to use an additional hidden layer after the state
vector in order to predict the action scores. It is recommended to set
this to False for large pretrained models such as transformers, and True
for smaller networks. The upper layer is computed on CPU, which becomes
a bottleneck on larger GPU-based models, where it's also less necessary.
nO (int or None): The number of actions the model will predict between.
Usually inferred from data at the beginning of training, or loaded from
disk.
@ -105,11 +69,106 @@ def build_tb_parser_model(
nr_feature_tokens = 6 if extra_state_tokens else 3
else:
raise ValueError(Errors.E917.format(value=state_type))
return TransitionModel(
tok2vec=tok2vec,
state_tokens=nr_feature_tokens,
hidden_width=hidden_width,
maxout_pieces=maxout_pieces,
nO=nO,
unseen_classes=set(),
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
tok2vec = chain(
tok2vec,
list2array(),
Linear(hidden_width, t2v_width),
)
tok2vec.set_dim("nO", hidden_width)
lower = _define_lower(
nO=hidden_width if use_upper else nO,
nF=nr_feature_tokens,
nI=tok2vec.get_dim("nO"),
nP=maxout_pieces,
)
upper = None
if use_upper:
with use_ops("cpu"):
# Initialize weights at zero, as it's a classification layer.
upper = _define_upper(nO=nO, nI=None)
return TransitionModel(tok2vec, lower, upper, resize_output)
def _define_upper(nO, nI):
return Linear(nO=nO, nI=nI, init_W=zero_init)
def _define_lower(nO, nF, nI, nP):
return PrecomputableAffine(nO=nO, nF=nF, nI=nI, nP=nP)
def resize_output(model, new_nO):
if model.attrs["has_upper"]:
return _resize_upper(model, new_nO)
return _resize_lower(model, new_nO)
def _resize_upper(model, new_nO):
upper = model.get_ref("upper")
if upper.has_dim("nO") is None:
upper.set_dim("nO", new_nO)
return model
elif new_nO == upper.get_dim("nO"):
return model
smaller = upper
nI = smaller.maybe_get_dim("nI")
with use_ops("cpu"):
larger = _define_upper(nO=new_nO, nI=nI)
# it could be that the model is not initialized yet, then skip this bit
if smaller.has_param("W"):
larger_W = larger.ops.alloc2f(new_nO, nI)
larger_b = larger.ops.alloc1f(new_nO)
smaller_W = smaller.get_param("W")
smaller_b = smaller.get_param("b")
# Weights are stored in (nr_out, nr_in) format, so we're basically
# just adding rows here.
if smaller.has_dim("nO"):
old_nO = smaller.get_dim("nO")
larger_W[:old_nO] = smaller_W
larger_b[:old_nO] = smaller_b
for i in range(old_nO, new_nO):
model.attrs["unseen_classes"].add(i)
larger.set_param("W", larger_W)
larger.set_param("b", larger_b)
model._layers[-1] = larger
model.set_ref("upper", larger)
return model
def _resize_lower(model, new_nO):
lower = model.get_ref("lower")
if lower.has_dim("nO") is None:
lower.set_dim("nO", new_nO)
return model
smaller = lower
nI = smaller.maybe_get_dim("nI")
nF = smaller.maybe_get_dim("nF")
nP = smaller.maybe_get_dim("nP")
larger = _define_lower(nO=new_nO, nI=nI, nF=nF, nP=nP)
# it could be that the model is not initialized yet, then skip this bit
if smaller.has_param("W"):
larger_W = larger.ops.alloc4f(nF, new_nO, nP, nI)
larger_b = larger.ops.alloc2f(new_nO, nP)
larger_pad = larger.ops.alloc4f(1, nF, new_nO, nP)
smaller_W = smaller.get_param("W")
smaller_b = smaller.get_param("b")
smaller_pad = smaller.get_param("pad")
# Copy the old weights and padding into the new layer
if smaller.has_dim("nO"):
old_nO = smaller.get_dim("nO")
larger_W[:, 0:old_nO, :, :] = smaller_W
larger_pad[:, :, 0:old_nO, :] = smaller_pad
larger_b[0:old_nO, :] = smaller_b
for i in range(old_nO, new_nO):
model.attrs["unseen_classes"].add(i)
larger.set_param("W", larger_W)
larger.set_param("b", larger_b)
larger.set_param("pad", larger_pad)
model._layers[1] = larger
model.set_ref("lower", larger)
return model

49
spacy/ml/parser_model.pxd Normal file
View File

@ -0,0 +1,49 @@
from libc.string cimport memset, memcpy
from thinc.backends.cblas cimport CBlas
from ..typedefs cimport weight_t, hash_t
from ..pipeline._parser_internals._state cimport StateC
cdef struct SizesC:
int states
int classes
int hiddens
int pieces
int feats
int embed_width
cdef struct WeightsC:
const float* feat_weights
const float* feat_bias
const float* hidden_bias
const float* hidden_weights
const float* seen_classes
cdef struct ActivationsC:
int* token_ids
float* unmaxed
float* scores
float* hiddens
int* is_valid
int _curr_size
int _max_size
cdef WeightsC get_c_weights(model) except *
cdef SizesC get_c_sizes(model, int batch_size) except *
cdef ActivationsC alloc_activations(SizesC n) nogil
cdef void free_activations(const ActivationsC* A) nogil
cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
const WeightsC* W, SizesC n) nogil
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil
cdef void cpu_log_loss(float* d_scores,
const float* costs, const int* is_valid, const float* scores, int O) nogil

500
spacy/ml/parser_model.pyx Normal file
View File

@ -0,0 +1,500 @@
# cython: infer_types=True, cdivision=True, boundscheck=False
cimport numpy as np
from libc.math cimport exp
from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free, realloc
from thinc.backends.cblas cimport saxpy, sgemm
import numpy
import numpy.random
from thinc.api import Model, CupyOps, NumpyOps, get_ops
from .. import util
from ..errors import Errors
from ..typedefs cimport weight_t, class_t, hash_t
from ..pipeline._parser_internals.stateclass cimport StateClass
cdef WeightsC get_c_weights(model) except *:
cdef WeightsC output
cdef precompute_hiddens state2vec = model.state2vec
output.feat_weights = state2vec.get_feat_weights()
output.feat_bias = <const float*>state2vec.bias.data
cdef np.ndarray vec2scores_W
cdef np.ndarray vec2scores_b
if model.vec2scores is None:
output.hidden_weights = NULL
output.hidden_bias = NULL
else:
vec2scores_W = model.vec2scores.get_param("W")
vec2scores_b = model.vec2scores.get_param("b")
output.hidden_weights = <const float*>vec2scores_W.data
output.hidden_bias = <const float*>vec2scores_b.data
cdef np.ndarray class_mask = model._class_mask
output.seen_classes = <const float*>class_mask.data
return output
cdef SizesC get_c_sizes(model, int batch_size) except *:
cdef SizesC output
output.states = batch_size
if model.vec2scores is None:
output.classes = model.state2vec.get_dim("nO")
else:
output.classes = model.vec2scores.get_dim("nO")
output.hiddens = model.state2vec.get_dim("nO")
output.pieces = model.state2vec.get_dim("nP")
output.feats = model.state2vec.get_dim("nF")
output.embed_width = model.tokvecs.shape[1]
return output
cdef ActivationsC alloc_activations(SizesC n) nogil:
cdef ActivationsC A
memset(&A, 0, sizeof(A))
resize_activations(&A, n)
return A
cdef void free_activations(const ActivationsC* A) nogil:
free(A.token_ids)
free(A.scores)
free(A.unmaxed)
free(A.hiddens)
free(A.is_valid)
cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
if n.states <= A._max_size:
A._curr_size = n.states
return
if A._max_size == 0:
A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
A.scores = <float*>calloc(n.states * n.classes, sizeof(A.scores[0]))
A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
A._max_size = n.states
else:
A.token_ids = <int*>realloc(A.token_ids,
n.states * n.feats * sizeof(A.token_ids[0]))
A.scores = <float*>realloc(A.scores,
n.states * n.classes * sizeof(A.scores[0]))
A.unmaxed = <float*>realloc(A.unmaxed,
n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
A.hiddens = <float*>realloc(A.hiddens,
n.states * n.hiddens * sizeof(A.hiddens[0]))
A.is_valid = <int*>realloc(A.is_valid,
n.states * n.classes * sizeof(A.is_valid[0]))
A._max_size = n.states
A._curr_size = n.states
cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
const WeightsC* W, SizesC n) nogil:
cdef double one = 1.0
resize_activations(A, n)
for i in range(n.states):
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
memset(A.hiddens, 0, n.states * n.hiddens * sizeof(float))
sum_state_features(cblas, A.unmaxed,
W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces)
for i in range(n.states):
saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1)
for j in range(n.hiddens):
index = i * n.hiddens * n.pieces + j * n.pieces
which = _arg_max(&A.unmaxed[index], n.pieces)
A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
memset(A.scores, 0, n.states * n.classes * sizeof(float))
if W.hidden_weights == NULL:
memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float))
else:
# Compute hidden-to-output
sgemm(cblas)(False, True, n.states, n.classes, n.hiddens,
1.0, <const float *>A.hiddens, n.hiddens,
<const float *>W.hidden_weights, n.hiddens,
0.0, A.scores, n.classes)
# Add bias
for i in range(n.states):
saxpy(cblas)(n.classes, 1., W.hidden_bias, 1, &A.scores[i*n.classes], 1)
# Set unseen classes to minimum value
i = 0
min_ = A.scores[0]
for i in range(1, n.states * n.classes):
if A.scores[i] < min_:
min_ = A.scores[i]
for i in range(n.states):
for j in range(n.classes):
if not W.seen_classes[j]:
A.scores[i*n.classes+j] = min_
cdef void sum_state_features(CBlas cblas, float* output,
const float* cached, const int* token_ids, int B, int F, int O) nogil:
cdef int idx, b, f, i
cdef const float* feature
padding = cached
cached += F * O
cdef int id_stride = F*O
cdef float one = 1.
for b in range(B):
for f in range(F):
if token_ids[f] < 0:
feature = &padding[f*O]
else:
idx = token_ids[f] * id_stride + f*O
feature = &cached[idx]
saxpy(cblas)(O, one, <const float*>feature, 1, &output[b*O], 1)
token_ids += F
cdef void cpu_log_loss(float* d_scores,
const float* costs, const int* is_valid, const float* scores,
int O) nogil:
"""Do multi-label log loss"""
cdef double max_, gmax, Z, gZ
best = arg_max_if_gold(scores, costs, is_valid, O)
guess = _arg_max(scores, O)
if best == -1 or guess == -1:
# These shouldn't happen, but if they do, we want to make sure we don't
# cause an OOB access.
return
Z = 1e-10
gZ = 1e-10
max_ = scores[guess]
gmax = scores[best]
for i in range(O):
Z += exp(scores[i] - max_)
if costs[i] <= costs[best]:
gZ += exp(scores[i] - gmax)
for i in range(O):
if costs[i] <= costs[best]:
d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
else:
d_scores[i] = exp(scores[i]-max_) / Z
cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs,
const int* is_valid, int n) nogil:
# Find minimum cost
cdef float cost = 1
for i in range(n):
if is_valid[i] and costs[i] < cost:
cost = costs[i]
# Now find best-scoring with that cost
cdef int best = -1
for i in range(n):
if costs[i] <= cost and is_valid[i]:
if best == -1 or scores[i] > scores[best]:
best = i
return best
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
cdef int best = -1
for i in range(n):
if is_valid[i] >= 1:
if best == -1 or scores[i] > scores[best]:
best = i
return best
class ParserStepModel(Model):
def __init__(self, docs, layers, *, has_upper, unseen_classes=None, train=True,
dropout=0.1):
Model.__init__(self, name="parser_step_model", forward=step_forward)
self.attrs["has_upper"] = has_upper
self.attrs["dropout_rate"] = dropout
self.tokvecs, self.bp_tokvecs = layers[0](docs, is_train=train)
if layers[1].get_dim("nP") >= 2:
activation = "maxout"
elif has_upper:
activation = None
else:
activation = "relu"
self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
activation=activation, train=train)
if has_upper:
self.vec2scores = layers[-1]
else:
self.vec2scores = None
self.cuda_stream = util.get_cuda_stream(non_blocking=True)
self.backprops = []
self._class_mask = numpy.zeros((self.nO,), dtype='f')
self._class_mask.fill(1)
if unseen_classes is not None:
for class_ in unseen_classes:
self._class_mask[class_] = 0.
def clear_memory(self):
del self.tokvecs
del self.bp_tokvecs
del self.state2vec
del self.backprops
del self._class_mask
@property
def nO(self):
if self.attrs["has_upper"]:
return self.vec2scores.get_dim("nO")
else:
return self.state2vec.get_dim("nO")
def class_is_unseen(self, class_):
return self._class_mask[class_]
def mark_class_unseen(self, class_):
self._class_mask[class_] = 0
def mark_class_seen(self, class_):
self._class_mask[class_] = 1
def get_token_ids(self, states):
cdef StateClass state
states = [state for state in states if not state.is_final()]
cdef np.ndarray ids = numpy.zeros((len(states), self.state2vec.nF),
dtype='i', order='C')
ids.fill(-1)
c_ids = <int*>ids.data
for state in states:
state.c.set_context_tokens(c_ids, ids.shape[1])
c_ids += ids.shape[1]
return ids
def backprop_step(self, token_ids, d_vector, get_d_tokvecs):
if isinstance(self.state2vec.ops, CupyOps) \
and not isinstance(token_ids, self.state2vec.ops.xp.ndarray):
# Move token_ids and d_vector to GPU, asynchronously
self.backprops.append((
util.get_async(self.cuda_stream, token_ids),
util.get_async(self.cuda_stream, d_vector),
get_d_tokvecs
))
else:
self.backprops.append((token_ids, d_vector, get_d_tokvecs))
def finish_steps(self, golds):
# Add a padding vector to the d_tokvecs gradient, so that missing
# values don't affect the real gradient.
d_tokvecs = self.ops.alloc((self.tokvecs.shape[0]+1, self.tokvecs.shape[1]))
# Tells CUDA to block, so our async copies complete.
if self.cuda_stream is not None:
self.cuda_stream.synchronize()
for ids, d_vector, bp_vector in self.backprops:
d_state_features = bp_vector((d_vector, ids))
ids = ids.flatten()
d_state_features = d_state_features.reshape(
(ids.size, d_state_features.shape[2]))
self.ops.scatter_add(d_tokvecs, ids,
d_state_features)
# Padded -- see update()
self.bp_tokvecs(d_tokvecs[:-1])
return d_tokvecs
NUMPY_OPS = NumpyOps()
def step_forward(model: ParserStepModel, states, is_train):
token_ids = model.get_token_ids(states)
vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
mask = None
if model.attrs["has_upper"]:
dropout_rate = model.attrs["dropout_rate"]
if is_train and dropout_rate > 0:
mask = NUMPY_OPS.get_dropout_mask(vector.shape, 0.1)
vector *= mask
scores, get_d_vector = model.vec2scores(vector, is_train)
else:
scores = NumpyOps().asarray(vector)
get_d_vector = lambda d_scores: d_scores
# If the class is unseen, make sure its score is minimum
scores[:, model._class_mask == 0] = numpy.nanmin(scores)
def backprop_parser_step(d_scores):
# Zero vectors for unseen classes
d_scores *= model._class_mask
d_vector = get_d_vector(d_scores)
if mask is not None:
d_vector *= mask
model.backprop_step(token_ids, d_vector, get_d_tokvecs)
return None
return scores, backprop_parser_step
cdef class precompute_hiddens:
"""Allow a model to be "primed" by pre-computing input features in bulk.
This is used for the parser, where we want to take a batch of documents,
and compute vectors for each (token, position) pair. These vectors can then
be reused, especially for beam-search.
Let's say we're using 12 features for each state, e.g. word at start of
buffer, three words on stack, their children, etc. In the normal arc-eager
system, a document of length N is processed in 2*N states. This means we'll
create 2*N*12 feature vectors --- but if we pre-compute, we only need
N*12 vector computations. The saving for beam-search is much better:
if we have a beam of k, we'll normally make 2*N*12*K computations --
so we can save the factor k. This also gives a nice CPU/GPU division:
we can do all our hard maths up front, packed into large multiplications,
and do the hard-to-program parsing on the CPU.
"""
cdef readonly int nF, nO, nP
cdef bint _is_synchronized
cdef public object ops
cdef public object numpy_ops
cdef public object _cpu_ops
cdef np.ndarray _features
cdef np.ndarray _cached
cdef np.ndarray bias
cdef object _cuda_stream
cdef object _bp_hiddens
cdef object activation
def __init__(self, batch_size, tokvecs, lower_model, cuda_stream=None,
activation="maxout", train=False):
gpu_cached, bp_features = lower_model(tokvecs, train)
cdef np.ndarray cached
if not isinstance(gpu_cached, numpy.ndarray):
# Note the passing of cuda_stream here: it lets
# cupy make the copy asynchronously.
# We then have to block before first use.
cached = gpu_cached.get(stream=cuda_stream)
else:
cached = gpu_cached
if not isinstance(lower_model.get_param("b"), numpy.ndarray):
self.bias = lower_model.get_param("b").get(stream=cuda_stream)
else:
self.bias = lower_model.get_param("b")
self.nF = cached.shape[1]
if lower_model.has_dim("nP"):
self.nP = lower_model.get_dim("nP")
else:
self.nP = 1
self.nO = cached.shape[2]
self.ops = lower_model.ops
self.numpy_ops = NumpyOps()
self._cpu_ops = get_ops("cpu") if isinstance(self.ops, CupyOps) else self.ops
assert activation in (None, "relu", "maxout")
self.activation = activation
self._is_synchronized = False
self._cuda_stream = cuda_stream
self._cached = cached
self._bp_hiddens = bp_features
cdef const float* get_feat_weights(self) except NULL:
if not self._is_synchronized and self._cuda_stream is not None:
self._cuda_stream.synchronize()
self._is_synchronized = True
return <float*>self._cached.data
def has_dim(self, name):
if name == "nF":
return self.nF if self.nF is not None else True
elif name == "nP":
return self.nP if self.nP is not None else True
elif name == "nO":
return self.nO if self.nO is not None else True
else:
return False
def get_dim(self, name):
if name == "nF":
return self.nF
elif name == "nP":
return self.nP
elif name == "nO":
return self.nO
else:
raise ValueError(Errors.E1033.format(name=name))
def set_dim(self, name, value):
if name == "nF":
self.nF = value
elif name == "nP":
self.nP = value
elif name == "nO":
self.nO = value
else:
raise ValueError(Errors.E1033.format(name=name))
def __call__(self, X, bint is_train):
if is_train:
return self.begin_update(X)
else:
return self.predict(X), lambda X: X
def predict(self, X):
return self.begin_update(X)[0]
def begin_update(self, token_ids):
cdef np.ndarray state_vector = numpy.zeros(
(token_ids.shape[0], self.nO, self.nP), dtype='f')
# This is tricky, but (assuming GPU available);
# - Input to forward on CPU
# - Output from forward on CPU
# - Input to backward on GPU!
# - Output from backward on GPU
bp_hiddens = self._bp_hiddens
cdef CBlas cblas = self._cpu_ops.cblas()
feat_weights = self.get_feat_weights()
cdef int[:, ::1] ids = token_ids
sum_state_features(cblas, <float*>state_vector.data,
feat_weights, &ids[0,0],
token_ids.shape[0], self.nF, self.nO*self.nP)
state_vector += self.bias
state_vector, bp_nonlinearity = self._nonlinearity(state_vector)
def backward(d_state_vector_ids):
d_state_vector, token_ids = d_state_vector_ids
d_state_vector = bp_nonlinearity(d_state_vector)
d_tokens = bp_hiddens((d_state_vector, token_ids))
return d_tokens
return state_vector, backward
def _nonlinearity(self, state_vector):
if self.activation == "maxout":
return self._maxout_nonlinearity(state_vector)
else:
return self._relu_nonlinearity(state_vector)
def _maxout_nonlinearity(self, state_vector):
state_vector, mask = self.numpy_ops.maxout(state_vector)
# We're outputting to CPU, but we need this variable on GPU for the
# backward pass.
mask = self.ops.asarray(mask)
def backprop_maxout(d_best):
return self.ops.backprop_maxout(d_best, mask, self.nP)
return state_vector, backprop_maxout
def _relu_nonlinearity(self, state_vector):
state_vector = state_vector.reshape((state_vector.shape[0], -1))
mask = state_vector >= 0.
state_vector *= mask
# We're outputting to CPU, but we need this variable on GPU for the
# backward pass.
mask = self.ops.asarray(mask)
def backprop_relu(d_best):
d_best *= mask
return d_best.reshape((d_best.shape + (1,)))
return state_vector, backprop_relu
cdef inline int _arg_max(const float* scores, const int n_classes) nogil:
if n_classes == 2:
return 0 if scores[0] > scores[1] else 1
cdef int i
cdef int best = 0
cdef float mode = scores[0]
for i in range(1, n_classes):
if scores[i] > mode:
mode = scores[i]
best = i
return best

View File

@ -1,28 +0,0 @@
from libc.stdint cimport int8_t
cdef struct SizesC:
int states
int classes
int hiddens
int pieces
int feats
int embed_width
int tokens
cdef struct WeightsC:
const float* feat_weights
const float* feat_bias
const float* hidden_bias
const float* hidden_weights
const int8_t* seen_mask
cdef struct ActivationsC:
int* token_ids
float* unmaxed
float* hiddens
int* is_valid
int _curr_size
int _max_size

50
spacy/ml/tb_framework.py Normal file
View File

@ -0,0 +1,50 @@
from thinc.api import Model, noop
from .parser_model import ParserStepModel
from ..util import registry
@registry.layers("spacy.TransitionModel.v1")
def TransitionModel(
tok2vec, lower, upper, resize_output, dropout=0.2, unseen_classes=set()
):
"""Set up a stepwise transition-based model"""
if upper is None:
has_upper = False
upper = noop()
else:
has_upper = True
# don't define nO for this object, because we can't dynamically change it
return Model(
name="parser_model",
forward=forward,
dims={"nI": tok2vec.maybe_get_dim("nI")},
layers=[tok2vec, lower, upper],
refs={"tok2vec": tok2vec, "lower": lower, "upper": upper},
init=init,
attrs={
"has_upper": has_upper,
"unseen_classes": set(unseen_classes),
"resize_output": resize_output,
},
)
def forward(model, X, is_train):
step_model = ParserStepModel(
X,
model.layers,
unseen_classes=model.attrs["unseen_classes"],
train=is_train,
has_upper=model.attrs["has_upper"],
)
return step_model, step_model.finish_steps
def init(model, X=None, Y=None):
model.get_ref("tok2vec").initialize(X=X)
lower = model.get_ref("lower")
lower.initialize()
if model.attrs["has_upper"]:
statevecs = model.ops.alloc2f(2, lower.get_dim("nO"))
model.get_ref("upper").initialize(X=statevecs)

View File

@ -1,639 +0,0 @@
# cython: infer_types=True, cdivision=True, boundscheck=False
from typing import Any, List, Optional, Tuple, cast
from libc.stdlib cimport calloc, free, realloc
from libc.string cimport memcpy, memset
from libcpp.vector cimport vector
import numpy
cimport numpy as np
from thinc.api import (
Linear,
Model,
NumpyOps,
chain,
glorot_uniform_init,
list2array,
normal_init,
uniform_init,
zero_init,
)
from thinc.backends.cblas cimport CBlas, saxpy, sgemm
from thinc.types import Floats2d, Floats3d, Floats4d, Ints1d, Ints2d
from ..errors import Errors
from ..pipeline._parser_internals import _beam_utils
from ..pipeline._parser_internals.batch import GreedyBatch
from ..pipeline._parser_internals._parser_utils cimport arg_max
from ..pipeline._parser_internals.stateclass cimport StateC, StateClass
from ..pipeline._parser_internals.transition_system cimport (
TransitionSystem,
c_apply_actions,
c_transition_batch,
)
from ..tokens.doc import Doc
from ..util import registry
State = Any # TODO
@registry.layers("spacy.TransitionModel.v2")
def TransitionModel(
*,
tok2vec: Model[List[Doc], List[Floats2d]],
beam_width: int = 1,
beam_density: float = 0.0,
state_tokens: int,
hidden_width: int,
maxout_pieces: int,
nO: Optional[int] = None,
unseen_classes=set(),
) -> Model[Tuple[List[Doc], TransitionSystem], List[Tuple[State, List[Floats2d]]]]:
"""Set up a transition-based parsing model, using a maxout hidden
layer and a linear output layer.
"""
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width)) # type: ignore
tok2vec_projected.set_dim("nO", hidden_width)
# FIXME: we use `output` as a container for the output layer's
# weights and biases. Thinc optimizers cannot handle resizing
# of parameters. So, when the parser model is resized, we
# construct a new `output` layer, which has a different key in
# the optimizer. Once the optimizer supports parameter resizing,
# we can replace the `output` layer by `output_W` and `output_b`
# parameters in this model.
output = Linear(nO=None, nI=hidden_width, init_W=zero_init)
return Model(
name="parser_model",
forward=forward,
init=init,
layers=[tok2vec_projected, output],
refs={
"tok2vec": tok2vec_projected,
"output": output,
},
params={
"hidden_W": None, # Floats2d W for the hidden layer
"hidden_b": None, # Floats1d bias for the hidden layer
"hidden_pad": None, # Floats1d padding for the hidden layer
},
dims={
"nO": None, # Output size
"nP": maxout_pieces,
"nH": hidden_width,
"nI": tok2vec_projected.maybe_get_dim("nO"),
"nF": state_tokens,
},
attrs={
"beam_width": beam_width,
"beam_density": beam_density,
"unseen_classes": set(unseen_classes),
"resize_output": resize_output,
},
)
def resize_output(model: Model, new_nO: int) -> Model:
old_nO = model.maybe_get_dim("nO")
output = model.get_ref("output")
if old_nO is None:
model.set_dim("nO", new_nO)
output.set_dim("nO", new_nO)
output.initialize()
return model
elif new_nO <= old_nO:
return model
elif output.has_param("W"):
nH = model.get_dim("nH")
new_output = Linear(nO=new_nO, nI=nH, init_W=zero_init)
new_output.initialize()
new_W = new_output.get_param("W")
new_b = new_output.get_param("b")
old_W = output.get_param("W")
old_b = output.get_param("b")
new_W[:old_nO] = old_W # type: ignore
new_b[:old_nO] = old_b # type: ignore
for i in range(old_nO, new_nO):
model.attrs["unseen_classes"].add(i)
model.layers[-1] = new_output
model.set_ref("output", new_output)
# TODO: Avoid this private intrusion
model._dims["nO"] = new_nO
return model
def init(
model,
X: Optional[Tuple[List[Doc], TransitionSystem]] = None,
Y: Optional[Tuple[List[State], List[Floats2d]]] = None,
):
if X is not None:
docs, _ = X
model.get_ref("tok2vec").initialize(X=docs)
else:
model.get_ref("tok2vec").initialize()
inferred_nO = _infer_nO(Y)
if inferred_nO is not None:
current_nO = model.maybe_get_dim("nO")
if current_nO is None or current_nO != inferred_nO:
model.attrs["resize_output"](model, inferred_nO)
nP = model.get_dim("nP")
nH = model.get_dim("nH")
nI = model.get_dim("nI")
nF = model.get_dim("nF")
ops = model.ops
Wl = ops.alloc2f(nH * nP, nF * nI)
bl = ops.alloc1f(nH * nP)
padl = ops.alloc1f(nI)
# Wl = zero_init(ops, Wl.shape)
Wl = glorot_uniform_init(ops, Wl.shape)
padl = uniform_init(ops, padl.shape) # type: ignore
# TODO: Experiment with whether better to initialize output_W
model.set_param("hidden_W", Wl)
model.set_param("hidden_b", bl)
model.set_param("hidden_pad", padl)
# model = _lsuv_init(model)
return model
class TransitionModelInputs:
"""
Input to transition model.
"""
# dataclass annotation is not yet supported in Cython 0.29.x,
# so, we'll do something close to it.
actions: Optional[List[Ints1d]]
docs: List[Doc]
max_moves: int
moves: TransitionSystem
states: Optional[List[State]]
__slots__ = [
"actions",
"docs",
"max_moves",
"moves",
"states",
]
def __init__(
self,
docs: List[Doc],
moves: TransitionSystem,
actions: Optional[List[Ints1d]] = None,
max_moves: int = 0,
states: Optional[List[State]] = None,
):
"""
actions (Optional[List[Ints1d]]): actions to apply for each Doc.
docs (List[Doc]): Docs to predict transition sequences for.
max_moves: (int): the maximum number of moves to apply, values less
than 1 will apply moves to states until they are final states.
moves (TransitionSystem): the transition system to use when predicting
the transition sequences.
states (Optional[List[States]]): the initial states to predict the
transition sequences for. When absent, the initial states are
initialized from the provided Docs.
"""
self.actions = actions
self.docs = docs
self.moves = moves
self.max_moves = max_moves
self.states = states
def forward(model, inputs: TransitionModelInputs, is_train: bool):
docs = inputs.docs
moves = inputs.moves
actions = inputs.actions
beam_width = model.attrs["beam_width"]
hidden_pad = model.get_param("hidden_pad")
tok2vec = model.get_ref("tok2vec")
states = moves.init_batch(docs) if inputs.states is None else inputs.states
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad))
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
seen_mask = _get_seen_mask(model)
if not is_train and beam_width == 1 and isinstance(model.ops, NumpyOps):
# Note: max_moves is only used during training, so we don't need to
# pass it to the greedy inference path.
return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions)
else:
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec,
feats, backprop_feats, seen_mask, is_train, actions=actions,
max_moves=inputs.max_moves)
def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats,
np.ndarray[np.npy_bool, ndim = 1] seen_mask, actions: Optional[List[Ints1d]] = None):
cdef vector[StateC*] c_states
cdef StateClass state
for state in states:
if not state.is_final():
c_states.push_back(state.c)
weights = _get_c_weights(model, <float*>feats.data, seen_mask)
# Precomputed features have rows for each token, plus one for padding.
cdef int n_tokens = feats.shape[0] - 1
sizes = _get_c_sizes(model, c_states.size(), n_tokens)
cdef CBlas cblas = model.ops.cblas()
scores = _parse_batch(cblas, moves, &c_states[0], weights, sizes, actions=actions)
def backprop(dY):
raise ValueError(Errors.E4004)
return (states, scores), backprop
cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None):
cdef int i
cdef vector[StateC *] unfinished
cdef ActivationsC activations = _alloc_activations(sizes)
cdef np.ndarray step_scores
cdef np.ndarray step_actions
scores = []
while sizes.states >= 1:
step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
step_actions = actions[0] if actions is not None else None
with nogil:
_predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
if actions is None:
# Validate actions, argmax, take action.
c_transition_batch(moves, states, <const float*>step_scores.data, sizes.classes,
sizes.states)
else:
c_apply_actions(moves, states, <const int*>step_actions.data, sizes.states)
for i in range(sizes.states):
if not states[i].is_final():
unfinished.push_back(states[i])
for i in range(unfinished.size()):
states[i] = unfinished[i]
sizes.states = unfinished.size()
scores.append(step_scores)
unfinished.clear()
actions = actions[1:] if actions is not None else None
_free_activations(&activations)
return scores
def _forward_fallback(
model: Model,
moves: TransitionSystem,
states: List[StateClass],
tokvecs, backprop_tok2vec,
feats,
backprop_feats,
seen_mask,
is_train: bool,
actions: Optional[List[Ints1d]] = None,
max_moves: int = 0,
):
nF = model.get_dim("nF")
output = model.get_ref("output")
hidden_b = model.get_param("hidden_b")
nH = model.get_dim("nH")
nP = model.get_dim("nP")
beam_width = model.attrs["beam_width"]
beam_density = model.attrs["beam_density"]
ops = model.ops
all_ids = []
all_which = []
all_statevecs = []
all_scores = []
if beam_width == 1:
batch = GreedyBatch(moves, states, None)
else:
batch = _beam_utils.BeamBatch(
moves, states, None, width=beam_width, density=beam_density
)
arange = ops.xp.arange(nF)
n_moves = 0
while not batch.is_done:
ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i")
for i, state in enumerate(batch.get_unfinished_states()):
state.set_context_tokens(ids, i, nF)
# Sum the state features, add the bias and apply the activation (maxout)
# to create the state vectors.
preacts2f = feats[ids, arange].sum(axis=1) # type: ignore
preacts2f += hidden_b
preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
statevecs, which = ops.maxout(preacts)
# We don't use output's backprop, since we want to backprop for
# all states at once, rather than a single state.
scores = output.predict(statevecs)
scores[:, seen_mask] = ops.xp.nanmin(scores)
# Transition the states, filtering out any that are finished.
cpu_scores = ops.to_numpy(scores)
if actions is None:
batch.advance(cpu_scores)
else:
batch.advance_with_actions(actions[0])
actions = actions[1:]
all_scores.append(scores)
if is_train:
# Remember intermediate results for the backprop.
all_ids.append(ids)
all_statevecs.append(statevecs)
all_which.append(which)
if n_moves >= max_moves >= 1:
break
n_moves += 1
def backprop_parser(d_states_d_scores):
ids = ops.xp.vstack(all_ids)
which = ops.xp.vstack(all_which)
statevecs = ops.xp.vstack(all_statevecs)
_, d_scores = d_states_d_scores
if model.attrs.get("unseen_classes"):
# If we have a negative gradient (i.e. the probability should
# increase) on any classes we filtered out as unseen, mark
# them as seen.
for clas in set(model.attrs["unseen_classes"]):
if (d_scores[:, clas] < 0).any():
model.attrs["unseen_classes"].remove(clas)
d_scores *= seen_mask == False # no-cython-lint
# Calculate the gradients for the parameters of the output layer.
# The weight gemm is (nS, nO) @ (nS, nH).T
output.inc_grad("b", d_scores.sum(axis=0))
output.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True))
# Now calculate d_statevecs, by backproping through the output linear layer.
# This gemm is (nS, nO) @ (nO, nH)
output_W = output.get_param("W")
d_statevecs = ops.gemm(d_scores, output_W)
# Backprop through the maxout activation
d_preacts = ops.backprop_maxout(d_statevecs, which, nP)
d_preacts2f = ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
model.inc_grad("hidden_b", d_preacts2f.sum(axis=0))
# We don't need to backprop the summation, because we pass back the IDs instead
d_state_features = backprop_feats((d_preacts2f, ids))
d_tokvecs = ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
ops.scatter_add(d_tokvecs, ids, d_state_features)
model.inc_grad("hidden_pad", d_tokvecs[-1])
return (backprop_tok2vec(d_tokvecs[:-1]), None)
return (list(batch), all_scores), backprop_parser
def _get_seen_mask(model: Model) -> numpy.array[bool, 1]:
mask = model.ops.xp.zeros(model.get_dim("nO"), dtype="bool")
for class_ in model.attrs.get("unseen_classes", set()):
mask[class_] = True
return mask
def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
W: Floats2d = model.get_param("hidden_W")
nF = model.get_dim("nF")
nH = model.get_dim("nH")
nP = model.get_dim("nP")
nI = model.get_dim("nI")
# The weights start out (nH * nP, nF * nI). Transpose and reshape to (nF * nH *nP, nI)
W3f = model.ops.reshape3f(W, nH * nP, nF, nI)
W3f = W3f.transpose((1, 0, 2))
W2f = model.ops.reshape2f(W3f, nF * nH * nP, nI)
assert X.shape == (X.shape[0], nI), X.shape
Yf_ = model.ops.gemm(X, W2f, trans2=True)
Yf = model.ops.reshape3f(Yf_, Yf_.shape[0], nF, nH * nP)
def backward(dY_ids: Tuple[Floats3d, Ints2d]):
# This backprop is particularly tricky, because we get back a different
# thing from what we put out. We put out an array of shape:
# (nB, nF, nH, nP), and get back:
# (nB, nH, nP) and ids (nB, nF)
# The ids tell us the values of nF, so we would have:
#
# dYf = zeros((nB, nF, nH, nP))
# for b in range(nB):
# for f in range(nF):
# dYf[b, ids[b, f]] += dY[b]
#
# However, we avoid building that array for efficiency -- and just pass
# in the indices.
dY, ids = dY_ids
dXf = model.ops.gemm(dY, W)
Xf = X[ids].reshape((ids.shape[0], -1))
dW = model.ops.gemm(dY, Xf, trans1=True)
model.inc_grad("hidden_W", dW)
return model.ops.reshape3f(dXf, dXf.shape[0], nF, nI)
return Yf, backward
def _infer_nO(Y: Optional[Tuple[List[State], List[Floats2d]]]) -> Optional[int]:
if Y is None:
return None
_, scores = Y
if len(scores) == 0:
return None
assert scores[0].shape[0] >= 1
assert len(scores[0].shape) == 2
return scores[0].shape[1]
def _lsuv_init(model: Model):
"""This is like the 'layer sequential unit variance', but instead
of taking the actual inputs, we randomly generate whitened data.
Why's this all so complicated? We have a huge number of inputs,
and the maxout unit makes guessing the dynamics tricky. Instead
we set the maxout weights to values that empirically result in
whitened outputs given whitened inputs.
"""
W = model.maybe_get_param("hidden_W")
if W is not None and W.any():
return
nF = model.get_dim("nF")
nH = model.get_dim("nH")
nP = model.get_dim("nP")
nI = model.get_dim("nI")
W = model.ops.alloc4f(nF, nH, nP, nI)
b = model.ops.alloc2f(nH, nP)
pad = model.ops.alloc4f(1, nF, nH, nP)
ops = model.ops
W = normal_init(ops, W.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI)))
pad = normal_init(ops, pad.shape, mean=1.0)
model.set_param("W", W)
model.set_param("b", b)
model.set_param("pad", pad)
ids = ops.alloc_f((5000, nF), dtype="f")
ids += ops.xp.random.uniform(0, 1000, ids.shape)
ids = ops.asarray(ids, dtype="i")
tokvecs = ops.alloc_f((5000, nI), dtype="f")
tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape(
tokvecs.shape
)
def predict(ids, tokvecs):
# nS ids. nW tokvecs. Exclude the padding array.
hiddens, _ = _forward_precomputable_affine(model, tokvecs[:-1], False)
vectors = model.ops.alloc2f(ids.shape[0], nH * nP)
# need nS vectors
hiddens = hiddens.reshape((hiddens.shape[0] * nF, nH * nP))
model.ops.scatter_add(vectors, ids.flatten(), hiddens)
vectors3f = model.ops.reshape3f(vectors, vectors.shape[0], nH, nP)
vectors3f += b
return model.ops.maxout(vectors3f)[0]
tol_var = 0.01
tol_mean = 0.01
t_max = 10
W = cast(Floats4d, model.get_param("hidden_W").copy())
b = cast(Floats2d, model.get_param("hidden_b").copy())
for t_i in range(t_max):
acts1 = predict(ids, tokvecs)
var = model.ops.xp.var(acts1)
mean = model.ops.xp.mean(acts1)
if abs(var - 1.0) >= tol_var:
W /= model.ops.xp.sqrt(var)
model.set_param("hidden_W", W)
elif abs(mean) >= tol_mean:
b -= mean
model.set_param("hidden_b", b)
else:
break
return model
cdef WeightsC _get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
output = model.get_ref("output")
cdef np.ndarray hidden_b = model.get_param("hidden_b")
cdef np.ndarray output_W = output.get_param("W")
cdef np.ndarray output_b = output.get_param("b")
cdef WeightsC weights
weights.feat_weights = feats
weights.feat_bias = <const float*>hidden_b.data
weights.hidden_weights = <const float *> output_W.data
weights.hidden_bias = <const float *> output_b.data
weights.seen_mask = <const int8_t*> seen_mask.data
return weights
cdef SizesC _get_c_sizes(model, int batch_size, int tokens) except *:
cdef SizesC sizes
sizes.states = batch_size
sizes.classes = model.get_dim("nO")
sizes.hiddens = model.get_dim("nH")
sizes.pieces = model.get_dim("nP")
sizes.feats = model.get_dim("nF")
sizes.embed_width = model.get_dim("nI")
sizes.tokens = tokens
return sizes
cdef ActivationsC _alloc_activations(SizesC n) nogil:
cdef ActivationsC A
memset(&A, 0, sizeof(A))
_resize_activations(&A, n)
return A
cdef void _free_activations(const ActivationsC* A) nogil:
free(A.token_ids)
free(A.unmaxed)
free(A.hiddens)
free(A.is_valid)
cdef void _resize_activations(ActivationsC* A, SizesC n) nogil:
if n.states <= A._max_size:
A._curr_size = n.states
return
if A._max_size == 0:
A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
A._max_size = n.states
else:
A.token_ids = <int*>realloc(A.token_ids,
n.states * n.feats * sizeof(A.token_ids[0]))
A.unmaxed = <float*>realloc(A.unmaxed,
n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
A.hiddens = <float*>realloc(A.hiddens,
n.states * n.hiddens * sizeof(A.hiddens[0]))
A.is_valid = <int*>realloc(A.is_valid,
n.states * n.classes * sizeof(A.is_valid[0]))
A._max_size = n.states
A._curr_size = n.states
cdef void _predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
_resize_activations(A, n)
for i in range(n.states):
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
_sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
for i in range(n.states):
saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1)
for j in range(n.hiddens):
index = i * n.hiddens * n.pieces + j * n.pieces
which = arg_max(&A.unmaxed[index], n.pieces)
A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
if W.hidden_weights == NULL:
memcpy(scores, A.hiddens, n.states * n.classes * sizeof(float))
else:
# Compute hidden-to-output
sgemm(cblas)(False, True, n.states, n.classes, n.hiddens,
1.0, <const float *>A.hiddens, n.hiddens,
<const float *>W.hidden_weights, n.hiddens,
0.0, scores, n.classes)
# Add bias
for i in range(n.states):
saxpy(cblas)(n.classes, 1., W.hidden_bias, 1, &scores[i*n.classes], 1)
# Set unseen classes to minimum value
i = 0
min_ = scores[0]
for i in range(1, n.states * n.classes):
if scores[i] < min_:
min_ = scores[i]
for i in range(n.states):
for j in range(n.classes):
if W.seen_mask[j]:
scores[i*n.classes+j] = min_
cdef void _sum_state_features(CBlas cblas, float* output, const float* cached,
const int* token_ids, SizesC n) nogil:
cdef int idx, b, f
cdef const float* feature
cdef int B = n.states
cdef int O = n.hiddens * n.pieces # no-cython-lint
cdef int F = n.feats
cdef int T = n.tokens
padding = cached + (T * F * O)
cdef int id_stride = F*O
cdef float one = 1.
for b in range(B):
for f in range(F):
if token_ids[f] < 0:
feature = &padding[f*O]
else:
idx = token_ids[f] * id_stride + f*O
feature = &cached[idx]
saxpy(cblas)(O, one, <const float*>feature, 1, &output[b*O], 1)
token_ids += F

View File

@ -6,8 +6,6 @@ from ...typedefs cimport class_t
from .transition_system cimport Transition, TransitionSystem
from ...errors import Errors
from .batch cimport Batch
from .search cimport Beam, MaxViolation
from .search import MaxViolation
@ -29,7 +27,7 @@ cdef int check_final_state(void* _state, void* extra_args) except -1:
return state.is_final()
cdef class BeamBatch(Batch):
cdef class BeamBatch(object):
cdef public TransitionSystem moves
cdef public object states
cdef public object docs

View File

@ -1,2 +0,0 @@
cdef int arg_max(const float* scores, const int n_classes) nogil
cdef int arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil

View File

@ -1,22 +0,0 @@
# cython: infer_types=True
cdef inline int arg_max(const float* scores, const int n_classes) nogil:
if n_classes == 2:
return 0 if scores[0] > scores[1] else 1
cdef int i
cdef int best = 0
cdef float mode = scores[0]
for i in range(1, n_classes):
if scores[i] > mode:
mode = scores[i]
best = i
return best
cdef inline int arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil:
cdef int best = -1
for i in range(n):
if is_valid[i] >= 1:
if best == -1 or scores[i] > scores[best]:
best = i
return best

View File

@ -7,6 +7,8 @@ from libc.string cimport memcpy, memset
from libcpp.set cimport set
from libcpp.unordered_map cimport unordered_map
from libcpp.vector cimport vector
from libcpp.set cimport set
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
from murmurhash.mrmr cimport hash64
from ...attrs cimport IS_SPACE
@ -26,7 +28,7 @@ cdef struct ArcC:
cdef cppclass StateC:
vector[int] _heads
int* _heads
const TokenC* _sent
vector[int] _stack
vector[int] _rebuffer
@ -34,34 +36,31 @@ cdef cppclass StateC:
unordered_map[int, vector[ArcC]] _left_arcs
unordered_map[int, vector[ArcC]] _right_arcs
vector[libcpp.bool] _unshiftable
vector[int] history
set[int] _sent_starts
TokenC _empty_token
int length
int offset
int _b_i
__init__(const TokenC* sent, int length) nogil except +:
this._heads.resize(length, -1)
this._unshiftable.resize(length, False)
# Reserve memory ahead of time to minimize allocations during parsing.
# The initial capacity set here ideally reflects the expected average-case/majority usage.
cdef int init_capacity = 32
this._stack.reserve(init_capacity)
this._rebuffer.reserve(init_capacity)
this._ents.reserve(init_capacity)
this._left_arcs.reserve(init_capacity)
this._right_arcs.reserve(init_capacity)
this.history.reserve(init_capacity)
__init__(const TokenC* sent, int length) nogil:
this._sent = sent
this._heads = <int*>calloc(length, sizeof(int))
if not (this._sent and this._heads):
with gil:
PyErr_SetFromErrno(MemoryError)
PyErr_CheckSignals()
this.offset = 0
this.length = length
this._b_i = 0
for i in range(length):
this._heads[i] = -1
this._unshiftable.push_back(0)
memset(&this._empty_token, 0, sizeof(TokenC))
this._empty_token.lex = &EMPTY_LEXEME
__dealloc__():
free(this._heads)
void set_context_tokens(int* ids, int n) nogil:
cdef int i, j
if n == 1:
@ -134,20 +133,19 @@ cdef cppclass StateC:
ids[i] = -1
int S(int i) nogil const:
cdef int stack_size = this._stack.size()
if i >= stack_size or i < 0:
if i >= this._stack.size():
return -1
else:
return this._stack[stack_size - (i+1)]
elif i < 0:
return -1
return this._stack.at(this._stack.size() - (i+1))
int B(int i) nogil const:
cdef int buf_size = this._rebuffer.size()
if i < 0:
return -1
elif i < buf_size:
return this._rebuffer[buf_size - (i+1)]
elif i < this._rebuffer.size():
return this._rebuffer.at(this._rebuffer.size() - (i+1))
else:
b_i = this._b_i + (i - buf_size)
b_i = this._b_i + (i - this._rebuffer.size())
if b_i >= this.length:
return -1
else:
@ -246,7 +244,7 @@ cdef cppclass StateC:
return 0
elif this._sent[word].sent_start == 1:
return 1
elif this._sent_starts.const_find(word) != this._sent_starts.const_end():
elif this._sent_starts.count(word) >= 1:
return 1
else:
return 0
@ -330,7 +328,7 @@ cdef cppclass StateC:
if item >= this._unshiftable.size():
return 0
else:
return this._unshiftable[item]
return this._unshiftable.at(item)
void set_reshiftable(int item) nogil:
if item < this._unshiftable.size():
@ -350,9 +348,6 @@ cdef cppclass StateC:
this._heads[child] = head
void map_del_arc(unordered_map[int, vector[ArcC]]* heads_arcs, int h_i, int c_i) nogil:
cdef vector[ArcC]* arcs
cdef ArcC* arc
arcs_it = heads_arcs.find(h_i)
if arcs_it == heads_arcs.end():
return
@ -361,12 +356,12 @@ cdef cppclass StateC:
if arcs.size() == 0:
return
arc = &arcs.back()
arc = arcs.back()
if arc.head == h_i and arc.child == c_i:
arcs.pop_back()
else:
for i in range(arcs.size()-1):
arc = &deref(arcs)[i]
arc = arcs.at(i)
if arc.head == h_i and arc.child == c_i:
arc.head = -1
arc.child = -1
@ -406,11 +401,10 @@ cdef cppclass StateC:
this._rebuffer = src._rebuffer
this._sent_starts = src._sent_starts
this._unshiftable = src._unshiftable
this._heads = src._heads
memcpy(this._heads, src._heads, this.length * sizeof(this._heads[0]))
this._ents = src._ents
this._left_arcs = src._left_arcs
this._right_arcs = src._right_arcs
this._b_i = src._b_i
this.offset = src.offset
this._empty_token = src._empty_token
this.history = src.history

View File

@ -779,8 +779,6 @@ cdef class ArcEager(TransitionSystem):
return list(arcs)
def has_gold(self, Example eg, start=0, end=None):
if end is not None and end < 0:
end = None
for word in eg.y[start:end]:
if word.dep != 0:
return True
@ -865,7 +863,6 @@ cdef class ArcEager(TransitionSystem):
state.print_state()
)))
action.do(state.c, action.label)
state.c.history.push_back(i)
break
else:
failed = False

View File

@ -1,2 +0,0 @@
cdef class Batch:
pass

View File

@ -1,52 +0,0 @@
from typing import Any
TransitionSystem = Any # TODO
cdef class Batch:
def advance(self, scores):
raise NotImplementedError
def get_states(self):
raise NotImplementedError
@property
def is_done(self):
raise NotImplementedError
def get_unfinished_states(self):
raise NotImplementedError
def __getitem__(self, i):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class GreedyBatch(Batch):
def __init__(self, moves: TransitionSystem, states, golds):
self._moves = moves
self._states = states
self._next_states = [s for s in states if not s.is_final()]
def advance(self, scores):
self._next_states = self._moves.transition_states(self._next_states, scores)
def advance_with_actions(self, actions):
self._next_states = self._moves.apply_actions(self._next_states, actions)
def get_states(self):
return self._states
@property
def is_done(self):
return all(s.is_final() for s in self._states)
def get_unfinished_states(self):
return [st for st in self._states if not st.is_final()]
def __getitem__(self, i):
return self._states[i]
def __len__(self):
return len(self._states)

View File

@ -156,7 +156,7 @@ cdef class BiluoPushDown(TransitionSystem):
if token.ent_type:
labels.add(token.ent_type_)
return labels
def move_name(self, int move, attr_t label):
if move == OUT:
return 'O'
@ -306,8 +306,6 @@ cdef class BiluoPushDown(TransitionSystem):
for span in eg.y.spans.get(neg_key, []):
if span.start >= start and span.end <= end:
return True
if end is not None and end < 0:
end = None
for word in eg.y[start:end]:
if word.ent_iob != 0:
return True
@ -643,7 +641,7 @@ cdef class Unit:
cost += 1
break
return cost
cdef class Out:
@staticmethod

View File

@ -19,10 +19,6 @@ cdef class StateClass:
if self._borrowed != 1:
del self.c
@property
def history(self):
return list(self.c.history)
@property
def stack(self):
return [self.S(i) for i in range(self.c.stack_depth())]
@ -179,6 +175,3 @@ cdef class StateClass:
def clone(self, StateClass src):
self.c.clone(src.c)
def set_context_tokens(self, int[:, :] output, int row, int n_feats):
self.c.set_context_tokens(&output[row, 0], n_feats)

View File

@ -57,10 +57,3 @@ cdef class TransitionSystem:
cdef int set_costs(self, int* is_valid, weight_t* costs,
const StateC* state, gold) except -1
cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions,
int batch_size) nogil
cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
int nr_class, int batch_size) nogil

View File

@ -2,16 +2,12 @@
from __future__ import print_function
from cymem.cymem cimport Pool
from libc.stdlib cimport calloc, free
from libcpp.vector cimport vector
from collections import Counter
import srsly
from ...structs cimport TokenC
from ...typedefs cimport attr_t, weight_t
from ._parser_utils cimport arg_max_if_valid
from .stateclass cimport StateClass
from ... import util
@ -76,18 +72,7 @@ cdef class TransitionSystem:
offset += len(doc)
return states
def follow_history(self, doc, history):
cdef int clas
cdef StateClass state = StateClass(doc)
for clas in history:
action = self.c[clas]
action.do(state.c, action.label)
state.c.history.push_back(clas)
return state
def get_oracle_sequence(self, Example example, _debug=False):
if not self.has_gold(example):
return []
states, golds, _ = self.init_gold_batch([example])
if not states:
return []
@ -99,8 +84,6 @@ cdef class TransitionSystem:
return self.get_oracle_sequence_from_state(state, gold)
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
if state.is_final():
return []
cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0
@ -126,7 +109,6 @@ cdef class TransitionSystem:
"S0 head?", str(state.has_head(state.S(0))),
)))
action.do(state.c, action.label)
state.c.history.push_back(i)
break
else:
if _debug:
@ -154,28 +136,6 @@ cdef class TransitionSystem:
raise ValueError(Errors.E170.format(name=name))
action = self.lookup_transition(name)
action.do(state.c, action.label)
state.c.history.push_back(action.clas)
def apply_actions(self, states, const int[::1] actions):
assert len(states) == actions.shape[0]
cdef StateClass state
cdef vector[StateC*] c_states
c_states.resize(len(states))
cdef int i
for (i, state) in enumerate(states):
c_states[i] = state.c
c_apply_actions(self, &c_states[0], &actions[0], actions.shape[0])
return [state for state in states if not state.c.is_final()]
def transition_states(self, states, float[:, ::1] scores):
assert len(states) == scores.shape[0]
cdef StateClass state
cdef float* c_scores = &scores[0, 0]
cdef vector[StateC*] c_states
for state in states:
c_states.push_back(state.c)
c_transition_batch(self, &c_states[0], c_scores, scores.shape[1], scores.shape[0])
return [state for state in states if not state.c.is_final()]
cdef Transition lookup_transition(self, object name) except *:
raise NotImplementedError
@ -288,34 +248,3 @@ cdef class TransitionSystem:
self.cfg.update(msg['cfg'])
self.initialize_actions(labels)
return self
cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions,
int batch_size) nogil:
cdef int i
cdef Transition action
cdef StateC* state
for i in range(batch_size):
state = states[i]
action = moves.c[actions[i]]
action.do(state, action.label)
state.history.push_back(action.clas)
cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
int nr_class, int batch_size) nogil:
is_valid = <int*>calloc(moves.n_moves, sizeof(int))
cdef int i, guess
cdef Transition action
for i in range(batch_size):
moves.set_valid(is_valid, states[i])
guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class)
if guess == -1:
# This shouldn't happen, but it's hard to raise an error here,
# and we don't want to infinite loop. So, force to end state.
states[i].force_final()
else:
action = moves.c[guess]
action.do(states[i], action.label)
states[i].history.push_back(guess)
free(is_valid)

View File

@ -4,6 +4,10 @@ from typing import Callable, Optional
from thinc.api import Config, Model
from ._parser_internals.transition_system import TransitionSystem
from .transition_parser cimport Parser
from ._parser_internals.arc_eager cimport ArcEager
from ..language import Language
from ..scorer import Scorer
from ..training import remove_bilu_prefix
@ -17,11 +21,12 @@ from .transition_parser import Parser
default_model_config = """
[model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = true
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
@ -121,7 +126,6 @@ def make_parser(
scorer=scorer,
)
@Language.factory(
"beam_parser",
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
@ -227,7 +231,6 @@ def parser_score(examples, **kwargs):
DOCS: https://spacy.io/api/dependencyparser#score
"""
def has_sents(doc):
return doc.has_annotation("SENT_START")
@ -235,11 +238,8 @@ def parser_score(examples, **kwargs):
dep = getattr(token, attr)
dep = token.vocab.strings.as_string(dep).lower()
return dep
results = {}
results.update(
Scorer.score_spans(examples, "sents", has_annotation=has_sents, **kwargs)
)
results.update(Scorer.score_spans(examples, "sents", has_annotation=has_sents, **kwargs))
kwargs.setdefault("getter", dep_getter)
kwargs.setdefault("ignore_labels", ("p", "punct"))
results.update(Scorer.score_deps(examples, "dep", **kwargs))
@ -252,12 +252,11 @@ def make_parser_scorer():
return parser_score
class DependencyParser(Parser):
cdef class DependencyParser(Parser):
"""Pipeline component for dependency parsing.
DOCS: https://spacy.io/api/dependencyparser
"""
TransitionSystem = ArcEager
def __init__(
@ -277,7 +276,8 @@ class DependencyParser(Parser):
incorrect_spans_key=None,
scorer=parser_score,
):
"""Create a DependencyParser."""
"""Create a DependencyParser.
"""
super().__init__(
vocab,
model,

View File

@ -10,15 +10,22 @@ from ..training import remove_bilu_prefix
from ..util import registry
from ._parser_internals.ner import BiluoPushDown
from ._parser_internals.transition_system import TransitionSystem
from .transition_parser import Parser
from .transition_parser cimport Parser
from ._parser_internals.ner cimport BiluoPushDown
from ..language import Language
from ..scorer import get_ner_prf, PRFScore
from ..util import registry
from ..training import remove_bilu_prefix
default_model_config = """
[model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = true
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
@ -43,12 +50,8 @@ DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]
"incorrect_spans_key": None,
"scorer": {"@scorers": "spacy.ner_scorer.v1"},
},
default_score_weights={
"ents_f": 1.0,
"ents_p": 0.0,
"ents_r": 0.0,
"ents_per_type": None,
},
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
)
def make_ner(
nlp: Language,
@ -101,7 +104,6 @@ def make_ner(
scorer=scorer,
)
@Language.factory(
"beam_ner",
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
@ -115,12 +117,7 @@ def make_ner(
"incorrect_spans_key": None,
"scorer": None,
},
default_score_weights={
"ents_f": 1.0,
"ents_p": 0.0,
"ents_r": 0.0,
"ents_per_type": None,
},
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
)
def make_beam_ner(
nlp: Language,
@ -194,12 +191,11 @@ def make_ner_scorer():
return ner_score
class EntityRecognizer(Parser):
cdef class EntityRecognizer(Parser):
"""Pipeline component for named entity recognition.
DOCS: https://spacy.io/api/entityrecognizer
"""
TransitionSystem = BiluoPushDown
def __init__(
@ -217,14 +213,15 @@ class EntityRecognizer(Parser):
incorrect_spans_key=None,
scorer=ner_score,
):
"""Create an EntityRecognizer."""
"""Create an EntityRecognizer.
"""
super().__init__(
vocab,
model,
name,
moves,
update_with_oracle_cut_size=update_with_oracle_cut_size,
min_action_freq=1, # not relevant for NER
min_action_freq=1, # not relevant for NER
learn_tokens=False, # not relevant for NER
beam_width=beam_width,
beam_density=beam_density,

View File

@ -0,0 +1,21 @@
from cymem.cymem cimport Pool
from thinc.backends.cblas cimport CBlas
from ..vocab cimport Vocab
from .trainable_pipe cimport TrainablePipe
from ._parser_internals.transition_system cimport Transition, TransitionSystem
from ._parser_internals._state cimport StateC
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC
cdef class Parser(TrainablePipe):
cdef public object _rehearsal_model
cdef readonly TransitionSystem moves
cdef public object _multitasks
cdef object _cpu_ops
cdef void _parseC(self, CBlas cblas, StateC** states,
WeightsC weights, SizesC sizes) nogil
cdef void c_transition_batch(self, StateC** states, const float* scores,
int nr_class, int batch_size) nogil

View File

@ -1,15 +1,20 @@
# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
from __future__ import print_function
from typing import Dict, Iterable, List, Optional, Tuple
cimport numpy as np
from cymem.cymem cimport Pool
import contextlib
import random
cimport numpy as np
from itertools import islice
from libcpp.vector cimport vector
from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free
import random
import srsly
from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps, Optimizer
from thinc.api import chain, softmax_activation, use_ops
from thinc.legacy import LegacySequenceCategoricalCrossentropy
from thinc.types import Floats2d
import numpy.random
import numpy
import numpy.random
import srsly
@ -23,7 +28,16 @@ from thinc.api import (
)
from thinc.types import Floats2d, Ints1d
from ..ml.tb_framework import TransitionModelInputs
from ._parser_internals.stateclass cimport StateClass
from ._parser_internals.search cimport Beam
from ..ml.parser_model cimport alloc_activations, free_activations
from ..ml.parser_model cimport predict_states, arg_max_if_valid
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
from ..ml.parser_model cimport get_c_weights, get_c_sizes
from ..tokens.doc cimport Doc
from .trainable_pipe import TrainablePipe
from ._parser_internals cimport _beam_utils
from ._parser_internals import _beam_utils
from ..tokens.doc cimport Doc
from ..typedefs cimport weight_t
@ -46,7 +60,7 @@ from ._parser_internals import _beam_utils
NUMPY_OPS = NumpyOps()
class Parser(TrainablePipe):
cdef class Parser(TrainablePipe):
"""
Base class of the DependencyParser and EntityRecognizer.
"""
@ -146,9 +160,8 @@ class Parser(TrainablePipe):
@property
def move_names(self):
names = []
cdef TransitionSystem moves = self.moves
for i in range(self.moves.n_moves):
name = self.moves.move_name(moves.c[i].move, moves.c[i].label)
name = self.moves.move_name(self.moves.c[i].move, self.moves.c[i].label)
# Explicitly removing the internal "U-" token used for blocking entities
if name != "U-":
names.append(name)
@ -255,6 +268,15 @@ class Parser(TrainablePipe):
student_docs = [eg.predicted for eg in examples]
teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples])
student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
# Add softmax activation, so that we can compute student losses
# with cross-entropy loss.
with use_ops("numpy"):
teacher_model = chain(teacher_step_model, softmax_activation())
student_model = chain(student_step_model, softmax_activation())
max_moves = self.cfg["update_with_oracle_cut_size"]
if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the
@ -262,38 +284,50 @@ class Parser(TrainablePipe):
# sequence, we use the teacher's predictions as the gold
# standard.
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
states = self._init_batch(teacher_pipe, student_docs, max_moves)
states = self._init_batch(teacher_step_model, student_docs, max_moves)
else:
states = self.moves.init_batch(student_docs)
# We distill as follows: 1. we first let the student predict transition
# sequences (and the corresponding transition probabilities); (2) we
# let the teacher follow the student's predicted transition sequences
# to obtain the teacher's transition probabilities; (3) we compute the
# gradients of the student's transition distributions relative to the
# teacher's distributions.
loss = 0.0
n_moves = 0
while states:
# We do distillation as follows: (1) for every state, we compute the
# transition softmax distributions: (2) we backpropagate the error of
# the student (compared to the teacher) into the student model; (3)
# for all states, we move to the next state using the student's
# predictions.
teacher_scores = teacher_model.predict(states)
student_scores, backprop = student_model.begin_update(states)
state_loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
backprop(d_scores)
loss += state_loss
self.transition_states(states, student_scores)
states = [state for state in states if not state.is_final()]
student_inputs = TransitionModelInputs(docs=student_docs, moves=self.moves,
max_moves=max_moves)
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
actions = states2actions(student_states)
teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples],
moves=self.moves, actions=actions)
(_, teacher_scores) = teacher_pipe.model.predict(teacher_inputs)
# Stop when we reach the maximum number of moves, otherwise we start
# to process the remainder of cut sequences again.
if max_moves >= 1 and n_moves >= max_moves:
break
n_moves += 1
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
backprop_scores((student_states, d_scores))
backprop_tok2vec(student_docs)
if sgd is not None:
self.finish_update(sgd)
losses[self.name] += loss
del backprop
del backprop_tok2vec
teacher_step_model.clear_memory()
student_step_model.clear_memory()
del teacher_model
del student_model
return losses
def get_teacher_student_loss(
self, teacher_scores: List[Floats2d], student_scores: List[Floats2d],
normalize: bool = False,
self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
) -> Tuple[float, List[Floats2d]]:
"""Calculate the loss and its gradient for a batch of student
scores, relative to teacher scores.
@ -305,28 +339,10 @@ class Parser(TrainablePipe):
DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
"""
# We can't easily hook up a softmax layer in the parsing model, since
# the get_loss does additional masking. So, we could apply softmax
# manually here and use Thinc's cross-entropy loss. But it's a bit
# suboptimal, since we can have a lot of states that would result in
# many kernel launches. Futhermore the parsing model's backprop expects
# a XP array, so we'd have to concat the softmaxes anyway. So, like
# the get_loss implementation, we'll compute the loss and gradients
# ourselves.
teacher_scores = self.model.ops.softmax(self.model.ops.xp.vstack(teacher_scores),
axis=-1, inplace=True)
student_scores = self.model.ops.softmax(self.model.ops.xp.vstack(student_scores),
axis=-1, inplace=True)
assert teacher_scores.shape == student_scores.shape
d_scores = student_scores - teacher_scores
if normalize:
d_scores /= d_scores.shape[0]
loss = (d_scores**2).sum() / d_scores.size
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
d_scores, loss = loss_func(student_scores, teacher_scores)
if self.model.ops.xp.isnan(loss):
raise ValueError(Errors.E910.format(name=self.name))
return float(loss), d_scores
def init_multitask_objectives(self, get_examples, pipeline, **cfg):
@ -349,6 +365,9 @@ class Parser(TrainablePipe):
stream: The sequence of documents to process.
batch_size (int): Number of documents to accumulate into a working set.
error_handler (Callable[[str, List[Doc], Exception], Any]): Function that
deals with a failing batch of documents. The default function just reraises
the exception.
YIELDS (Doc): Documents, in order.
"""
@ -369,29 +388,78 @@ class Parser(TrainablePipe):
def predict(self, docs):
if isinstance(docs, Doc):
docs = [docs]
self._ensure_labels_are_added(docs)
if not any(len(doc) for doc in docs):
result = self.moves.init_batch(docs)
return result
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
states_or_beams, _ = self.model.predict(inputs)
return states_or_beams
if self.cfg["beam_width"] == 1:
return self.greedy_parse(docs, drop=0.0)
else:
return self.beam_parse(
docs,
drop=0.0,
beam_width=self.cfg["beam_width"],
beam_density=self.cfg["beam_density"]
)
def greedy_parse(self, docs, drop=0.):
self._resize()
cdef vector[StateC*] states
cdef StateClass state
cdef CBlas cblas = self._cpu_ops.cblas()
self._ensure_labels_are_added(docs)
with _change_attrs(self.model, beam_width=1):
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
states, _ = self.model.predict(inputs)
return states
set_dropout_rate(self.model, drop)
batch = self.moves.init_batch(docs)
model = self.model.predict(docs)
weights = get_c_weights(model)
for state in batch:
if not state.is_final():
states.push_back(state.c)
sizes = get_c_sizes(model, states.size())
with nogil:
self._parseC(cblas, &states[0], weights, sizes)
model.clear_memory()
del model
return batch
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
cdef Beam beam
cdef Doc doc
self._ensure_labels_are_added(docs)
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
beams, _ = self.model.predict(inputs)
return beams
batch = _beam_utils.BeamBatch(
self.moves,
self.moves.init_batch(docs),
None,
beam_width,
density=beam_density
)
model = self.model.predict(docs)
while not batch.is_done:
states = batch.get_unfinished_states()
if not states:
break
scores = model.predict(states)
batch.advance(scores)
model.clear_memory()
del model
return list(batch)
cdef void _parseC(self, CBlas cblas, StateC** states,
WeightsC weights, SizesC sizes) nogil:
cdef int i, j
cdef vector[StateC*] unfinished
cdef ActivationsC activations = alloc_activations(sizes)
while sizes.states >= 1:
predict_states(cblas, &activations, states, &weights, sizes)
# Validate actions, argmax, take action.
self.c_transition_batch(states,
activations.scores, sizes.classes, sizes.states)
for i in range(sizes.states):
if not states[i].is_final():
unfinished.push_back(states[i])
for i in range(unfinished.size()):
states[i] = unfinished[i]
sizes.states = unfinished.size()
unfinished.clear()
free_activations(&activations)
def set_annotations(self, docs, states_or_beams):
cdef StateClass state
@ -402,6 +470,35 @@ class Parser(TrainablePipe):
for hook in self.postprocesses:
hook(doc)
def transition_states(self, states, float[:, ::1] scores):
cdef StateClass state
cdef float* c_scores = &scores[0, 0]
cdef vector[StateC*] c_states
for state in states:
c_states.push_back(state.c)
self.c_transition_batch(&c_states[0], c_scores, scores.shape[1], scores.shape[0])
return [state for state in states if not state.c.is_final()]
cdef void c_transition_batch(self, StateC** states, const float* scores,
int nr_class, int batch_size) nogil:
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
with gil:
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
is_valid = <int*>calloc(self.moves.n_moves, sizeof(int))
cdef int i, guess
cdef Transition action
for i in range(batch_size):
self.moves.set_valid(is_valid, states[i])
guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class)
if guess == -1:
# This shouldn't happen, but it's hard to raise an error here,
# and we don't want to infinite loop. So, force to end state.
states[i].force_final()
else:
action = self.moves.c[guess]
action.do(states[i], action.label)
free(is_valid)
def update(self, examples, *, drop=0., sgd=None, losses=None):
if losses is None:
losses = {}
@ -412,98 +509,66 @@ class Parser(TrainablePipe):
)
for multitask in self._multitasks:
multitask.update(examples, drop=drop, sgd=sgd)
# We need to take care to act on the whole batch, because we might be
# getting vectors via a listener.
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
if n_examples == 0:
return losses
set_dropout_rate(self.model, drop)
docs = [eg.x for eg in examples if len(eg.x)]
# The probability we use beam update, instead of falling back to
# a greedy update
beam_update_prob = self.cfg["beam_update_prob"]
if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob:
return self.update_beam(
examples,
beam_width=self.cfg["beam_width"],
sgd=sgd,
losses=losses,
beam_density=self.cfg["beam_density"]
)
max_moves = self.cfg["update_with_oracle_cut_size"]
if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the
# batch uniform length.
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
init_states, gold_states, _ = self._init_gold_batch(
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
states, golds, _ = self._init_gold_batch(
examples,
max_length=max_moves
)
else:
init_states, gold_states, _ = self.moves.init_gold_batch(examples)
inputs = TransitionModelInputs(docs=docs,
moves=self.moves,
max_moves=max_moves,
states=[state.copy() for state in init_states])
(pred_states, scores), backprop_scores = self.model.begin_update(inputs)
if sum(s.shape[0] for s in scores) == 0:
states, golds, _ = self.moves.init_gold_batch(examples)
if not states:
return losses
d_scores = self.get_loss((gold_states, init_states, pred_states, scores),
examples, max_moves)
backprop_scores((pred_states, d_scores))
if sgd not in (None, False):
self.finish_update(sgd)
losses[self.name] += float((d_scores**2).sum())
# Ugh, this is annoying. If we're working on GPU, we want to free the
# memory ASAP. It seems that Python doesn't necessarily get around to
# removing these in time if we don't explicitly delete? It's confusing.
del backprop_scores
return losses
def get_loss(self, states_scores, examples, max_moves):
gold_states, init_states, pred_states, scores = states_scores
scores = self.model.ops.xp.vstack(scores)
costs = self._get_costs_from_histories(
examples,
gold_states,
init_states,
[list(state.history) for state in pred_states],
max_moves
)
xp = get_array_module(scores)
best_costs = costs.min(axis=1, keepdims=True)
gscores = scores.copy()
min_score = scores.min() - 1000
assert costs.shape == scores.shape, (costs.shape, scores.shape)
gscores[costs > best_costs] = min_score
max_ = scores.max(axis=1, keepdims=True)
gmax = gscores.max(axis=1, keepdims=True)
exp_scores = xp.exp(scores - max_)
exp_gscores = xp.exp(gscores - gmax)
Z = exp_scores.sum(axis=1, keepdims=True)
gZ = exp_gscores.sum(axis=1, keepdims=True)
d_scores = exp_scores / Z
d_scores -= (costs <= best_costs) * (exp_gscores / gZ)
return d_scores
def _get_costs_from_histories(self, examples, gold_states, init_states, histories, max_moves):
cdef TransitionSystem moves = self.moves
cdef StateClass state
cdef int clas
cdef int nO = moves.n_moves
cdef Pool mem = Pool()
cdef np.ndarray costs_i
is_valid = <int*>mem.alloc(nO, sizeof(int))
batch = list(zip(init_states, histories, gold_states))
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
all_states = list(states)
states_golds = list(zip(states, golds))
n_moves = 0
output = []
while batch:
costs = numpy.zeros((len(batch), nO), dtype="f")
for i, (state, history, gold) in enumerate(batch):
costs_i = costs[i]
clas = history.pop(0)
moves.set_costs(is_valid, <weight_t*>costs_i.data, state.c, gold)
action = moves.c[clas]
action.do(state.c, action.label)
state.c.history.push_back(clas)
output.append(costs)
batch = [(s, h, g) for s, h, g in batch if len(h) != 0]
if n_moves >= max_moves >= 1:
while states_golds:
states, golds = zip(*states_golds)
scores, backprop = model.begin_update(states)
d_scores = self.get_batch_loss(states, golds, scores, losses)
# Note that the gradient isn't normalized by the batch size
# here, because our "samples" are really the states...But we
# can't normalize by the number of states either, as then we'd
# be getting smaller gradients for states in long sequences.
backprop(d_scores)
# Follow the predicted action
self.transition_states(states, scores)
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
if max_moves >= 1 and n_moves >= max_moves:
break
n_moves += 1
return self.model.ops.xp.vstack(output)
backprop_tok2vec(golds)
if sgd not in (None, False):
self.finish_update(sgd)
# Ugh, this is annoying. If we're working on GPU, we want to free the
# memory ASAP. It seems that Python doesn't necessarily get around to
# removing these in time if we don't explicitly delete? It's confusing.
del backprop
del backprop_tok2vec
model.clear_memory()
del model
return losses
def rehearse(self, examples, sgd=None, losses=None, **cfg):
"""Perform a "rehearsal" update, to prevent catastrophic forgetting."""
@ -514,9 +579,10 @@ class Parser(TrainablePipe):
multitask.rehearse(examples, losses=losses, sgd=sgd)
if self._rehearsal_model is None:
return None
losses.setdefault(self.name, 0.0)
losses.setdefault(self.name, 0.)
validate_examples(examples, "Parser.rehearse")
docs = [eg.predicted for eg in examples]
states = self.moves.init_batch(docs)
# This is pretty dirty, but the NER can resize itself in init_batch,
# if labels are missing. We therefore have to check whether we need to
# expand our model output.
@ -524,33 +590,85 @@ class Parser(TrainablePipe):
# Prepare the stepwise model, and get the callback for finishing the batch
set_dropout_rate(self._rehearsal_model, 0.0)
set_dropout_rate(self.model, 0.0)
student_inputs = TransitionModelInputs(docs=docs, moves=self.moves)
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
actions = states2actions(student_states)
teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions)
_, teacher_scores = self._rehearsal_model.predict(teacher_inputs)
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores, normalize=True)
teacher_scores = self.model.ops.xp.vstack(teacher_scores)
student_scores = self.model.ops.xp.vstack(student_scores)
assert teacher_scores.shape == student_scores.shape
d_scores = (student_scores - teacher_scores) / teacher_scores.shape[0]
# If all weights for an output are 0 in the original model, don't
# supervise that output. This allows us to add classes.
loss = (d_scores**2).sum() / d_scores.size
backprop_scores((student_states, d_scores))
tutor, _ = self._rehearsal_model.begin_update(docs)
model, backprop_tok2vec = self.model.begin_update(docs)
n_scores = 0.
loss = 0.
while states:
targets, _ = tutor.begin_update(states)
guesses, backprop = model.begin_update(states)
d_scores = (guesses - targets) / targets.shape[0]
# If all weights for an output are 0 in the original model, don't
# supervise that output. This allows us to add classes.
loss += (d_scores**2).sum()
backprop(d_scores)
# Follow the predicted action
self.transition_states(states, guesses)
states = [state for state in states if not state.is_final()]
n_scores += d_scores.size
# Do the backprop
backprop_tok2vec(docs)
if sgd is not None:
self.finish_update(sgd)
losses[self.name] += loss
losses[self.name] += loss / n_scores
del backprop
del backprop_tok2vec
model.clear_memory()
tutor.clear_memory()
del model
del tutor
return losses
def update_beam(self, examples, *, beam_width, drop=0.,
sgd=None, losses=None, beam_density=0.0):
raise NotImplementedError
def update_beam(self, examples, *, beam_width,
drop=0., sgd=None, losses=None, beam_density=0.0):
states, golds, _ = self.moves.init_gold_batch(examples)
if not states:
return losses
# Prepare the stepwise model, and get the callback for finishing the batch
model, backprop_tok2vec = self.model.begin_update(
[eg.predicted for eg in examples])
loss = _beam_utils.update_beam(
self.moves,
states,
golds,
model,
beam_width,
beam_density=beam_density,
)
losses[self.name] += loss
backprop_tok2vec(golds)
if sgd is not None:
self.finish_update(sgd)
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
cdef StateClass state
cdef Pool mem = Pool()
cdef int i
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
costs = <float*>mem.alloc(self.moves.n_moves, sizeof(float))
cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves),
dtype='f', order='C')
c_d_scores = <float*>d_scores.data
unseen_classes = self.model.attrs["unseen_classes"]
for i, (state, gold) in enumerate(zip(states, golds)):
memset(is_valid, 0, self.moves.n_moves * sizeof(int))
memset(costs, 0, self.moves.n_moves * sizeof(float))
self.moves.set_costs(is_valid, costs, state.c, gold)
for j in range(self.moves.n_moves):
if costs[j] <= 0.0 and j in unseen_classes:
unseen_classes.remove(j)
cpu_log_loss(c_d_scores,
costs, is_valid, &scores[i, 0], d_scores.shape[1])
c_d_scores += d_scores.shape[1]
# Note that we don't normalize this. See comment in update() for why.
if losses is not None:
losses.setdefault(self.name, 0.)
losses[self.name] += (d_scores**2).sum()
return d_scores
def set_output(self, nO):
self.model.attrs["resize_output"](self.model, nO)
@ -589,7 +707,7 @@ class Parser(TrainablePipe):
for example in islice(get_examples(), 10):
doc_sample.append(example.predicted)
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize((doc_sample, self.moves))
self.model.initialize(doc_sample)
if nlp is not None:
self.init_multitask_objectives(get_examples, nlp.pipeline)
@ -682,27 +800,26 @@ class Parser(TrainablePipe):
def _init_gold_batch(self, examples, max_length):
"""Make a square batch, of length equal to the shortest transition
sequence or a cap. A long doc will get multiple states. Let's say we
have a doc of length 2*N, where N is the shortest doc. We'll make
two states, one representing long_doc[:N], and another representing
long_doc[N:]."""
sequence or a cap. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
where N is the shortest doc. We'll make two states, one representing
long_doc[:N], and another representing long_doc[N:]."""
cdef:
StateClass start_state
StateClass state
Transition action
TransitionSystem moves = self.moves
all_states = moves.init_batch([eg.predicted for eg in examples])
all_states = self.moves.init_batch([eg.predicted for eg in examples])
states = []
golds = []
to_cut = []
for state, eg in zip(all_states, examples):
if moves.has_gold(eg) and not state.is_final():
gold = moves.init_gold(state, eg)
if self.moves.has_gold(eg) and not state.is_final():
gold = self.moves.init_gold(state, eg)
if len(eg.x) < max_length:
states.append(state)
golds.append(gold)
else:
oracle_actions = moves.get_oracle_sequence_from_state(
oracle_actions = self.moves.get_oracle_sequence_from_state(
state.copy(), gold)
to_cut.append((eg, state, gold, oracle_actions))
if not to_cut:
@ -712,52 +829,13 @@ class Parser(TrainablePipe):
for i in range(0, len(oracle_actions), max_length):
start_state = state.copy()
for clas in oracle_actions[i:i+max_length]:
action = moves.c[clas]
action = self.moves.c[clas]
action.do(state.c, action.label)
if state.is_final():
break
if moves.has_gold(eg, start_state.B(0), state.B(0)):
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
states.append(start_state)
golds.append(gold)
if state.is_final():
break
return states, golds, max_length
@contextlib.contextmanager
def _change_attrs(model, **kwargs):
"""Temporarily modify a thinc model's attributes."""
unset = object()
old_attrs = {}
for key, value in kwargs.items():
old_attrs[key] = model.attrs.get(key, unset)
model.attrs[key] = value
yield model
for key, value in old_attrs.items():
if value is unset:
model.attrs.pop(key)
else:
model.attrs[key] = value
def states2actions(states: List[StateClass]) -> List[Ints1d]:
cdef int step
cdef StateClass state
cdef StateC* c_state
actions = []
while True:
step = len(actions)
step_actions = []
for state in states:
c_state = state.c
if step < c_state.history.size():
step_actions.append(c_state.history[step])
# We are done if we have exhausted all histories.
if len(step_actions) == 0:
break
actions.append(numpy.array(step_actions, dtype="i"))
return actions

View File

@ -17,6 +17,7 @@ from spacy.pipeline.ner import DEFAULT_NER_MODEL
from spacy.tokens import Doc, Span
from spacy.training import Example, iob_to_biluo, split_bilu_label
from spacy.vocab import Vocab
import logging
from ..util import make_tempdir
@ -413,7 +414,7 @@ def test_train_empty():
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
ner = nlp.add_pipe("ner", last=True)
ner.add_label("PERSON")
nlp.initialize(get_examples=lambda: train_examples)
nlp.initialize()
for itn in range(2):
losses = {}
batches = util.minibatch(train_examples, size=8)
@ -540,11 +541,11 @@ def test_block_ner():
assert [token.ent_type_ for token in doc] == expected_types
def test_overfitting_IO():
fix_random_seed(1)
@pytest.mark.parametrize("use_upper", [True, False])
def test_overfitting_IO(use_upper):
# Simple test to try and quickly overfit the NER component
nlp = English()
ner = nlp.add_pipe("ner", config={"model": {}})
ner = nlp.add_pipe("ner", config={"model": {"use_upper": use_upper}})
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
@ -576,6 +577,7 @@ def test_overfitting_IO():
assert ents2[0].label_ == "LOC"
# Ensure that the predictions are still the same, even after adding a new label
ner2 = nlp2.get_pipe("ner")
assert ner2.model.attrs["has_upper"] == use_upper
ner2.add_label("RANDOM_NEW_LABEL")
doc3 = nlp2(test_text)
ents3 = doc3.ents

View File

@ -1,6 +1,3 @@
import itertools
import numpy
import pytest
from numpy.testing import assert_equal
from thinc.api import Adam, fix_random_seed
@ -62,8 +59,6 @@ PARTIAL_DATA = [
),
]
PARSERS = ["parser"] # TODO: Test beam_parser when ready
eps = 0.1
@ -176,57 +171,6 @@ def test_parser_parse_one_word_sentence(en_vocab, en_parser, words):
assert doc[0].dep != 0
def test_parser_apply_actions(en_vocab, en_parser):
words = ["I", "ate", "pizza"]
words2 = ["Eat", "more", "pizza", "!"]
doc1 = Doc(en_vocab, words=words)
doc2 = Doc(en_vocab, words=words2)
docs = [doc1, doc2]
moves = en_parser.moves
moves.add_action(0, "")
moves.add_action(1, "")
moves.add_action(2, "nsubj")
moves.add_action(3, "obj")
moves.add_action(2, "amod")
actions = [
numpy.array([0, 0], dtype="i"),
numpy.array([2, 0], dtype="i"),
numpy.array([0, 4], dtype="i"),
numpy.array([3, 3], dtype="i"),
numpy.array([1, 1], dtype="i"),
numpy.array([1, 1], dtype="i"),
numpy.array([0], dtype="i"),
numpy.array([1], dtype="i"),
]
states = moves.init_batch(docs)
active_states = states
for step_actions in actions:
active_states = moves.apply_actions(active_states, step_actions)
assert len(active_states) == 0
for state, doc in zip(states, docs):
moves.set_annotations(state, doc)
assert docs[0][0].head.i == 1
assert docs[0][0].dep_ == "nsubj"
assert docs[0][1].head.i == 1
assert docs[0][1].dep_ == "ROOT"
assert docs[0][2].head.i == 1
assert docs[0][2].dep_ == "obj"
assert docs[1][0].head.i == 0
assert docs[1][0].dep_ == "ROOT"
assert docs[1][1].head.i == 2
assert docs[1][1].dep_ == "amod"
assert docs[1][2].head.i == 0
assert docs[1][2].dep_ == "obj"
@pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
@ -375,7 +319,7 @@ def test_parser_constructor(en_vocab):
DependencyParser(en_vocab, model)
@pytest.mark.parametrize("pipe_name", PARSERS)
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
def test_incomplete_data(pipe_name):
# Test that the parser works with incomplete information
nlp = English()
@ -401,15 +345,11 @@ def test_incomplete_data(pipe_name):
assert doc[2].head.i == 1
@pytest.mark.parametrize(
"pipe_name,max_moves", itertools.product(PARSERS, [0, 1, 5, 100])
)
def test_overfitting_IO(pipe_name, max_moves):
fix_random_seed(0)
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
def test_overfitting_IO(pipe_name):
# Simple test to try and quickly overfit the dependency parser (normal or beam)
nlp = English()
parser = nlp.add_pipe(pipe_name)
parser.cfg["update_with_oracle_cut_size"] = max_moves
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
@ -511,12 +451,10 @@ def test_distill():
@pytest.mark.parametrize(
"parser_config",
[
# TODO: re-enable after we have a spacy-legacy release for v4. See
# https://github.com/explosion/spacy-legacy/pull/36
#({"@architectures": "spacy.TransitionBasedParser.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}),
# TransitionBasedParser V1
({"@architectures": "spacy.TransitionBasedParser.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}),
# TransitionBasedParser V2
({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}),
({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": False}),
({"@architectures": "spacy.TransitionBasedParser.v3", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2}),
],
)
# fmt: on

View File

@ -384,7 +384,7 @@ cfg_string_multi = """
factory = "ner"
[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v2"
[components.ner.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"

View File

@ -189,11 +189,33 @@ width = ${components.tok2vec.model.width}
parser_config_string_upper = """
[model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser"
extra_state_tokens = false
hidden_width = 66
maxout_pieces = 2
use_upper = true
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 333
depth = 4
embed_size = 5555
window_size = 1
maxout_pieces = 7
subword_features = false
"""
parser_config_string_no_upper = """
[model]
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser"
extra_state_tokens = false
hidden_width = 66
maxout_pieces = 2
use_upper = false
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
@ -224,6 +246,7 @@ def my_parser():
extra_state_tokens=True,
hidden_width=65,
maxout_pieces=5,
use_upper=True,
)
return parser
@ -337,16 +360,15 @@ def test_serialize_custom_nlp():
nlp.to_disk(d)
nlp2 = spacy.load(d)
model = nlp2.get_pipe("parser").model
assert model.get_ref("tok2vec") is not None
assert model.has_param("hidden_W")
assert model.has_param("hidden_b")
output = model.get_ref("output")
assert output is not None
assert output.has_param("W")
assert output.has_param("b")
model.get_ref("tok2vec")
# check that we have the correct settings, not the default ones
assert model.get_ref("upper").get_dim("nI") == 65
assert model.get_ref("lower").get_dim("nI") == 65
@pytest.mark.parametrize("parser_config_string", [parser_config_string_upper])
@pytest.mark.parametrize(
"parser_config_string", [parser_config_string_upper, parser_config_string_no_upper]
)
def test_serialize_parser(parser_config_string):
"""Create a non-default parser config to check nlp serializes it correctly"""
nlp = English()
@ -359,13 +381,11 @@ def test_serialize_parser(parser_config_string):
nlp.to_disk(d)
nlp2 = spacy.load(d)
model = nlp2.get_pipe("parser").model
assert model.get_ref("tok2vec") is not None
assert model.has_param("hidden_W")
assert model.has_param("hidden_b")
output = model.get_ref("output")
assert output is not None
assert output.has_param("b")
assert output.has_param("W")
model.get_ref("tok2vec")
# check that we have the correct settings, not the default ones
if model.attrs["has_upper"]:
assert model.get_ref("upper").get_dim("nI") == 66
assert model.get_ref("lower").get_dim("nI") == 66
def test_config_nlp_roundtrip():
@ -561,7 +581,9 @@ def test_config_auto_fill_extra_fields():
load_model_from_config(nlp.config)
@pytest.mark.parametrize("parser_config_string", [parser_config_string_upper])
@pytest.mark.parametrize(
"parser_config_string", [parser_config_string_upper, parser_config_string_no_upper]
)
def test_config_validate_literal(parser_config_string):
nlp = English()
config = Config().from_str(parser_config_string)

View File

@ -1,19 +1,22 @@
import ctypes
import os
from pathlib import Path
import pytest
from pydantic import ValidationError
from thinc.api import (
Config,
ConfigValidationError,
CupyOps,
MPSOps,
NumpyOps,
Optimizer,
get_current_ops,
set_current_ops,
)
try:
from pydantic.v1 import ValidationError
except ImportError:
from pydantic import ValidationError # type: ignore
from spacy.about import __version__ as spacy_version
from spacy import util
from spacy import prefer_gpu, require_gpu, require_cpu
from spacy.ml._precomputable_affine import PrecomputableAffine
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
from spacy.util import dot_to_object, SimpleFrozenList, import_file
from spacy.util import to_ternary_int, find_available_port
from thinc.api import Config, Optimizer, ConfigValidationError
from thinc.api import get_current_ops, set_current_ops, NumpyOps, CupyOps, MPSOps
from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
from spacy import prefer_gpu, require_cpu, require_gpu, util
@ -92,6 +95,34 @@ def test_util_get_package_path(package):
assert isinstance(path, Path)
def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2):
model = PrecomputableAffine(nO=nO, nI=nI, nF=nF, nP=nP).initialize()
assert model.get_param("W").shape == (nF, nO, nP, nI)
tensor = model.ops.alloc((10, nI))
Y, get_dX = model.begin_update(tensor)
assert Y.shape == (tensor.shape[0] + 1, nF, nO, nP)
dY = model.ops.alloc((15, nO, nP))
ids = model.ops.alloc((15, nF))
ids[1, 2] = -1
dY[1] = 1
assert not model.has_grad("pad")
d_pad = _backprop_precomputable_affine_padding(model, dY, ids)
assert d_pad[0, 2, 0, 0] == 1.0
ids.fill(0.0)
dY.fill(0.0)
dY[0] = 0
ids[1, 2] = 0
ids[1, 1] = -1
ids[1, 0] = -1
dY[1] = 1
ids[2, 0] = -1
dY[2] = 5
d_pad = _backprop_precomputable_affine_padding(model, dY, ids)
assert d_pad[0, 0, 0, 0] == 6
assert d_pad[0, 1, 0, 0] == 1
assert d_pad[0, 2, 0, 0] == 0
def test_prefer_gpu():
current_ops = get_current_ops()
if has_cupy_gpu:

View File

@ -1,5 +1,5 @@
from collections.abc import Iterable as IterableInstance
import warnings
import numpy
from murmurhash.mrmr cimport hash64

View File

@ -553,17 +553,18 @@ for a Tok2Vec layer.
## Parser & NER architectures {id="parser"}
### spacy.TransitionBasedParser.v3 {id="TransitionBasedParser",source="spacy/ml/models/parser.py"}
### spacy.TransitionBasedParser.v2 {id="TransitionBasedParser",source="spacy/ml/models/parser.py"}
> #### Example Config
>
> ```ini
> [model]
> @architectures = "spacy.TransitionBasedParser.v3"
> @architectures = "spacy.TransitionBasedParser.v2"
> state_type = "ner"
> extra_state_tokens = false
> hidden_width = 64
> maxout_pieces = 2
> use_upper = true
>
> [model.tok2vec]
> @architectures = "spacy.HashEmbedCNN.v2"
@ -593,22 +594,23 @@ consists of either two or three subnetworks:
state representation. If not present, the output from the lower model is used
as action scores directly.
| Name | Description |
| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ |
| `state_type` | Which task to extract features for. Possible values are "ner" and "parser". ~~str~~ |
| `extra_state_tokens` | Whether to use an expanded feature set when extracting the state tokens. Slightly slower, but sometimes improves accuracy slightly. Defaults to `False`. ~~bool~~ |
| `hidden_width` | The width of the hidden layer. ~~int~~ |
| `maxout_pieces` | How many pieces to use in the state prediction layer. Recommended values are `1`, `2` or `3`. ~~int~~ |
| `nO` | The number of actions the model will predict between. Usually inferred from data at the beginning of training, or loaded from disk. ~~int~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Docs], List[List[Floats2d]]]~~ |
| Name | Description |
| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ |
| `state_type` | Which task to extract features for. Possible values are "ner" and "parser". ~~str~~ |
| `extra_state_tokens` | Whether to use an expanded feature set when extracting the state tokens. Slightly slower, but sometimes improves accuracy slightly. Defaults to `False`. ~~bool~~ |
| `hidden_width` | The width of the hidden layer. ~~int~~ |
| `maxout_pieces` | How many pieces to use in the state prediction layer. Recommended values are `1`, `2` or `3`. If `1`, the maxout non-linearity is replaced with a [`Relu`](https://thinc.ai/docs/api-layers#relu) non-linearity if `use_upper` is `True`, and no non-linearity if `False`. ~~int~~ |
| `use_upper` | Whether to use an additional hidden layer after the state vector in order to predict the action scores. It is recommended to set this to `False` for large pretrained models such as transformers, and `True` for smaller networks. The upper layer is computed on CPU, which becomes a bottleneck on larger GPU-based models, where it's also less necessary. ~~bool~~ |
| `nO` | The number of actions the model will predict between. Usually inferred from data at the beginning of training, or loaded from disk. ~~int~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Docs], List[List[Floats2d]]]~~ |
<Accordion title="spacy.TransitionBasedParser.v1 definition" spaced>
[TransitionBasedParser.v1](/api/legacy#TransitionBasedParser_v1) had the exact
same signature, but the `use_upper` argument was `True` by default.
</Accordion>
</Accordion>
## Tagging architectures {id="tagger",source="spacy/ml/models/tagger.py"}

View File

@ -361,7 +361,7 @@ Module spacy.language
File /path/to/spacy/language.py (line 64)
[components.ner.model]
Registry @architectures
Name spacy.TransitionBasedParser.v3
Name spacy.TransitionBasedParser.v1
Module spacy.ml.models.parser
File /path/to/spacy/ml/models/parser.py (line 11)
[components.ner.model.tok2vec]
@ -371,7 +371,7 @@ Module spacy.ml.models.tok2vec
File /path/to/spacy/ml/models/tok2vec.py (line 16)
[components.parser.model]
Registry @architectures
Name spacy.TransitionBasedParser.v3
Name spacy.TransitionBasedParser.v1
Module spacy.ml.models.parser
File /path/to/spacy/ml/models/parser.py (line 11)
[components.parser.model.tok2vec]
@ -696,7 +696,7 @@ scorer = {"@scorers":"spacy.ner_scorer.v1"}
update_with_oracle_cut_size = 100
[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
- hidden_width = 64
@ -719,7 +719,7 @@ scorer = {"@scorers":"spacy.parser_scorer.v1"}
update_with_oracle_cut_size = 100
[components.parser.model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser"
extra_state_tokens = false
hidden_width = 128

View File

@ -225,7 +225,7 @@ the others, but may not be as accurate, especially if texts are short.
### spacy.TransitionBasedParser.v1 {id="TransitionBasedParser_v1"}
Identical to
[`spacy.TransitionBasedParser.v3`](/api/architectures#TransitionBasedParser)
[`spacy.TransitionBasedParser.v2`](/api/architectures#TransitionBasedParser)
except the `use_upper` was set to `True` by default.
## Layers {id="layers"}

View File

@ -140,7 +140,7 @@ factory = "tok2vec"
factory = "ner"
[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v1"
[components.ner.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
@ -156,7 +156,7 @@ same. This makes them fully independent and doesn't require an upstream
factory = "ner"
[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v1"
[components.ner.model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"
@ -472,7 +472,7 @@ sneakily delegates to the `Transformer` pipeline component.
factory = "ner"
[nlp.pipeline.ner.model]
@architectures = "spacy.TransitionBasedParser.v3"
@architectures = "spacy.TransitionBasedParser.v1"
state_type = "ner"
extra_state_tokens = false
hidden_width = 128