Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-12 09:12:21 +03:00)

Commit de8c88babb (parent 267ffb5605): New progress on parser model refactor
@@ -208,50 +208,41 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ParserStepModel(Model):
|
def ParserStepModel(
|
||||||
def __init__(self, docs, layers, *, has_upper, unseen_classes=None, train=True,
|
tokvecs: Floats2d,
|
||||||
dropout=0.1):
|
bp_tokvecs: Callable,
|
||||||
Model.__init__(self, name="parser_step_model", forward=step_forward)
|
upper: Model[Floats2d, Floats2d],
|
||||||
self.attrs["has_upper"] = has_upper
|
dropout: float = 0.1,
|
||||||
self.attrs["dropout_rate"] = dropout
|
unseen_classes: Optional[List[int]] = None
|
||||||
self.tokvecs, self.bp_tokvecs = layers[0](docs, is_train=train)
|
) -> Model[Ints2d, Floats2d]:
|
||||||
if layers[1].get_dim("nP") >= 2:
|
# TODO: Keep working on replacing all of this with just 'chain'
|
||||||
activation = "maxout"
|
state2vec = precompute_hiddens(
|
||||||
elif has_upper:
|
tokvecs,
|
||||||
activation = None
|
bp_tokvecs
|
||||||
else:
|
)
|
||||||
activation = "relu"
|
class_mask = numpy.zeros((self.nO,), dtype='f')
|
||||||
self.state2vec = precompute_hiddens(
|
class_mask.fill(1)
|
||||||
len(docs),
|
if unseen_classes is not None:
|
||||||
self.tokvecs,
|
for class_ in unseen_classes:
|
||||||
layers[1],
|
class_mask[class_] = 0.
|
||||||
activation=activation,
|
|
||||||
train=train
|
|
||||||
)
|
|
||||||
if has_upper:
|
|
||||||
self.vec2scores = layers[-1]
|
|
||||||
else:
|
|
||||||
self.vec2scores = None
|
|
||||||
self._class_mask = numpy.zeros((self.nO,), dtype='f')
|
|
||||||
self._class_mask.fill(1)
|
|
||||||
if unseen_classes is not None:
|
|
||||||
for class_ in unseen_classes:
|
|
||||||
self._class_mask[class_] = 0.
|
|
||||||
|
|
||||||
@property
|
return _ParserStepModel(
|
||||||
def nO(self):
|
"ParserStep",
|
||||||
if self.attrs["has_upper"]:
|
step_forward,
|
||||||
return self.vec2scores.get_dim("nO")
|
init=None,
|
||||||
else:
|
dims={"nO": upper.get_dim("nO")},
|
||||||
return self.state2vec.get_dim("nO")
|
layers=[state2vec, upper],
|
||||||
|
attrs={
|
||||||
|
"tokvecs": tokvecs,
|
||||||
|
"bp_tokvecs": bp_tokvecs,
|
||||||
|
"dropout_rate": dropout,
|
||||||
|
"class_mask": class_mask
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def clear_memory(self):
|
|
||||||
del self.tokvecs
|
|
||||||
del self.bp_tokvecs
|
|
||||||
del self.state2vec
|
|
||||||
del self.backprops
|
|
||||||
del self._class_mask
|
|
||||||
|
|
||||||
|
class _ParserStepModel(Model):
|
||||||
|
# TODO: Remove need for all this stuff, so we can normalize this
|
||||||
def class_is_unseen(self, class_):
|
def class_is_unseen(self, class_):
|
||||||
return self._class_mask[class_]
|
return self._class_mask[class_]
|
||||||
|
|
||||||
|
@@ -274,21 +265,22 @@ class ParserStepModel(Model):
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
|
||||||
def step_forward(model: ParserStepModel, token_ids, is_train):
|
def step_forward(model: _ParserStepModel, token_ids, is_train):
|
||||||
vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
|
# TODO: Eventually we hopefully can get rid of all of this?
|
||||||
|
# If we make the 'class_mask' thing its own layer, we can just
|
||||||
|
# have chain() here, right?
|
||||||
|
state2vec, upper = model.layers
|
||||||
|
vector, get_d_tokvecs = state2vec(token_ids, is_train)
|
||||||
mask = None
|
mask = None
|
||||||
if model.attrs["has_upper"]:
|
vec2scores = ensure_same_device(model.ops, upper)
|
||||||
vec2scores = ensure_same_device(model.ops, model.vec2scores)
|
dropout_rate = model.attrs["dropout_rate"]
|
||||||
dropout_rate = model.attrs["dropout_rate"]
|
if is_train and dropout_rate > 0:
|
||||||
if is_train and dropout_rate > 0:
|
mask = model.ops.get_dropout_mask(vector.shape, dropout_rate)
|
||||||
mask = model.ops.get_dropout_mask(vector.shape, dropout_rate)
|
vector *= mask
|
||||||
vector *= mask
|
scores, get_d_vector = vec2scores(vector, is_train)
|
||||||
scores, get_d_vector = vec2scores(vector, is_train)
|
|
||||||
else:
|
|
||||||
scores = vector
|
|
||||||
get_d_vector = lambda d_scores: d_scores
|
|
||||||
# If the class is unseen, make sure its score is minimum
|
# If the class is unseen, make sure its score is minimum
|
||||||
scores[:, model._class_mask == 0] = model.ops.xp.nanmin(scores)
|
class_mask = model.attrs["class_mask"]
|
||||||
|
scores[:, class_mask == 0] = model.ops.xp.nanmin(scores)
|
||||||
|
|
||||||
def backprop_parser_step(d_scores):
|
def backprop_parser_step(d_scores):
|
||||||
# Zero vectors for unseen classes
|
# Zero vectors for unseen classes
|
||||||
|
@@ -301,127 +293,45 @@ def step_forward(model: ParserStepModel, token_ids, is_train):
|
||||||
return scores, backprop_parser_step
|
return scores, backprop_parser_step
|
||||||
|
|
||||||
|
|
||||||
def ensure_same_device(ops, model):
|
def precompute_hiddens(lower_model, feat_weights: Floats3d, bp_hiddens: Callable) -> Model:
|
||||||
"""Ensure a model is on the same device as a given ops"""
|
return Model(
|
||||||
if not isinstance(model.ops, ops.__class__):
|
"precompute_hiddens",
|
||||||
model._to_ops(ops)
|
init=None,
|
||||||
return model
|
forward=_precompute_forward,
|
||||||
|
dims={
|
||||||
|
"nO": feat_weights.shape[2],
|
||||||
|
"nP": lower_model.get_dim("nP") if lower_model.has_dim("nP") else 1,
|
||||||
|
"nF": cached.shape[1]
|
||||||
|
},
|
||||||
|
ops=lower_model.ops
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
cdef class precompute_hiddens:
|
def _precompute_forward(
|
||||||
"""Allow a model to be "primed" by pre-computing input features in bulk.
|
model: Model[Ints2d, Floats2d],
|
||||||
|
token_ids: Ints2d,
|
||||||
|
is_train: bool
|
||||||
|
) -> Tuple[Floats2d, Callable]:
|
||||||
|
nO = model.get_dim("nO")
|
||||||
|
nP = model.get_dim("nP")
|
||||||
|
bp_hiddens = model.attrs["bp_hiddens"]
|
||||||
|
feat_weights = model.attrs["feat_weights"]
|
||||||
|
bias = model.attrs["bias"]
|
||||||
|
hidden = model.ops.alloc2f(
|
||||||
|
token_ids.shape[0],
|
||||||
|
nO * nP
|
||||||
|
)
|
||||||
|
# TODO: This is probably wrong, right?
|
||||||
|
model.ops.scatter_add(
|
||||||
|
hidden,
|
||||||
|
feat_weights,
|
||||||
|
token_ids
|
||||||
|
)
|
||||||
|
statevec, mask = model.ops.maxout(hidden.reshape((-1, nO, nP)))
|
||||||
|
|
||||||
This is used for the parser, where we want to take a batch of documents,
|
def backward(d_statevec):
|
||||||
and compute vectors for each (token, position) pair. These vectors can then
|
return bp_hiddens(
|
||||||
be reused, especially for beam-search.
|
model.ops.backprop_maxout(d_statevec, mask, nP)
|
||||||
|
|
||||||
Let's say we're using 12 features for each state, e.g. word at start of
|
|
||||||
buffer, three words on stack, their children, etc. In the normal arc-eager
|
|
||||||
system, a document of length N is processed in 2*N states. This means we'll
|
|
||||||
create 2*N*12 feature vectors --- but if we pre-compute, we only need
|
|
||||||
N*12 vector computations. The saving for beam-search is much better:
|
|
||||||
if we have a beam of k, we'll normally make 2*N*12*K computations --
|
|
||||||
so we can save the factor k. This also gives a nice CPU/GPU division:
|
|
||||||
we can do all our hard maths up front, packed into large multiplications,
|
|
||||||
and do the hard-to-program parsing on the CPU.
|
|
||||||
"""
|
|
||||||
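A quick back-of-the-envelope check of the savings described in the docstring above, using its own illustrative numbers (12 features per state, roughly 2*N states per document, beam width k). This is a sketch for illustration only, not part of the model code:

def feature_vector_computations(doc_length, n_feats=12, beam_width=1):
    # Computing features at every state: ~2*N states, n_feats vectors each,
    # multiplied again by the beam width.
    per_state = 2 * doc_length * n_feats * beam_width
    # Pre-computing: one vector per (token, feature) pair, reused by every state.
    precomputed = doc_length * n_feats
    return per_state, precomputed

print(feature_vector_computations(100))                 # (2400, 1200)
print(feature_vector_computations(100, beam_width=8))   # (19200, 1200)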
cdef readonly int nF, nO, nP
|
|
||||||
cdef public object ops
|
|
||||||
cdef readonly object bias
|
|
||||||
cdef readonly object activation
|
|
||||||
cdef readonly object _features
|
|
||||||
cdef readonly object _cached
|
|
||||||
cdef readonly object _bp_hiddens
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
batch_size,
|
|
||||||
tokvecs,
|
|
||||||
lower_model,
|
|
||||||
activation="maxout",
|
|
||||||
train=False
|
|
||||||
):
|
|
||||||
cached, bp_features = lower_model(tokvecs, train)
|
|
||||||
self.bias = lower_model.get_param("b")
|
|
||||||
self.nF = cached.shape[1]
|
|
||||||
if lower_model.has_dim("nP"):
|
|
||||||
self.nP = lower_model.get_dim("nP")
|
|
||||||
else:
|
|
||||||
self.nP = 1
|
|
||||||
self.nO = cached.shape[2]
|
|
||||||
self.ops = lower_model.ops
|
|
||||||
assert activation in (None, "relu", "maxout")
|
|
||||||
self.activation = activation
|
|
||||||
self._cached = cached
|
|
||||||
self._bp_hiddens = bp_features
|
|
||||||
|
|
||||||
cdef const float* get_feat_weights(self) except NULL:
|
|
||||||
cdef np.ndarray cached
|
|
||||||
if isinstance(self._cached, numpy.ndarray):
|
|
||||||
cached = self._cached
|
|
||||||
else:
|
|
||||||
cached = self._cached.get()
|
|
||||||
return <float*>cached.data
|
|
||||||
|
|
||||||
def has_dim(self, name):
|
|
||||||
if name == "nF":
|
|
||||||
return self.nF if self.nF is not None else True
|
|
||||||
elif name == "nP":
|
|
||||||
return self.nP if self.nP is not None else True
|
|
||||||
elif name == "nO":
|
|
||||||
return self.nO if self.nO is not None else True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_dim(self, name):
|
|
||||||
if name == "nF":
|
|
||||||
return self.nF
|
|
||||||
elif name == "nP":
|
|
||||||
return self.nP
|
|
||||||
elif name == "nO":
|
|
||||||
return self.nO
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Dimension {name} invalid -- only nO, nF, nP")
|
|
||||||
|
|
||||||
def set_dim(self, name, value):
|
|
||||||
if name == "nF":
|
|
||||||
self.nF = value
|
|
||||||
elif name == "nP":
|
|
||||||
self.nP = value
|
|
||||||
elif name == "nO":
|
|
||||||
self.nO = value
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Dimension {name} invalid -- only nO, nF, nP")
|
|
||||||
|
|
||||||
def __call__(self, X, bint is_train):
|
|
||||||
if is_train:
|
|
||||||
return self.begin_update(X)
|
|
||||||
else:
|
|
||||||
return self.predict(X), lambda X: X
|
|
||||||
|
|
||||||
def predict(self, X):
|
|
||||||
return self.begin_update(X)[0]
|
|
||||||
|
|
||||||
def begin_update(self, token_ids):
|
|
||||||
nO = self.nO
|
|
||||||
nP = self.nP
|
|
||||||
hidden = self.model.ops.alloc2f(
|
|
||||||
token_ids.shape[0],
|
|
||||||
nO * nP
|
|
||||||
)
|
|
||||||
bp_hiddens = self._bp_hiddens
|
|
||||||
feat_weights = self.cached
|
|
||||||
self.ops.scatter_add(
|
|
||||||
hidden,
|
|
||||||
feat_weights,
|
|
||||||
token_ids
|
|
||||||
)
|
)
|
||||||
hidden += self.bias
|
|
||||||
statevec, mask = self.ops.maxout(hidden.reshape((-1, nO, nP)))
|
|
||||||
|
|
||||||
def backward(d_statevec):
|
|
||||||
return bp_hiddens(
|
|
||||||
self.ops.backprop_maxout(d_statevec, mask, nP)
|
|
||||||
)
|
|
||||||
|
|
||||||
return statevec, backward
|
return statevec, backward
|
||||||
|
|
|
@@ -1,48 +1,314 @@
|
||||||
from thinc.api import Model, noop
|
from typing import List, Tuple, Any, Optional
|
||||||
from .parser_model import ParserStepModel
|
from thinc.api import Ops, Model, normal_init
|
||||||
|
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
|
||||||
|
from ..tokens.doc import Doc
|
||||||
|
|
||||||
|
|
||||||
|
TransitionSystem = Any # TODO
|
||||||
|
State = Any # TODO
|
||||||
|
|
||||||
|
|
||||||
def TransitionModel(
|
def TransitionModel(
|
||||||
tok2vec, lower, upper, resize_output, dropout=0.2, unseen_classes=set()
|
*,
|
||||||
):
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
"""Set up a stepwise transition-based model"""
|
state_tokens: int,
|
||||||
if upper is None:
|
hidden_width: int,
|
||||||
has_upper = False
|
maxout_pieces: int,
|
||||||
upper = noop()
|
nO: Optional[int] = None,
|
||||||
else:
|
unseen_classes=set(),
|
||||||
has_upper = True
|
) -> Model[Tuple[List[Doc], TransitionSystem], List[Tuple[State, List[Floats2d]]]]:
|
||||||
# don't define nO for this object, because we can't dynamically change it
|
"""Set up a transition-based parsing model, using a maxout hidden
|
||||||
|
layer and a linear output layer.
|
||||||
|
"""
|
||||||
return Model(
|
return Model(
|
||||||
name="parser_model",
|
name="parser_model",
|
||||||
forward=forward,
|
forward=forward,
|
||||||
dims={"nI": tok2vec.get_dim("nI") if tok2vec.has_dim("nI") else None},
|
|
||||||
layers=[tok2vec, lower, upper],
|
|
||||||
refs={"tok2vec": tok2vec, "lower": lower, "upper": upper},
|
|
||||||
init=init,
|
init=init,
|
||||||
|
layers=[tok2vec],
|
||||||
|
refs={"tok2vec": tok2vec},
|
||||||
|
params={
|
||||||
|
"lower_W": None, # Floats2d W for the hidden layer
|
||||||
|
"lower_b": None, # Floats1d bias for the hidden layer
|
||||||
|
"lower_pad": None, # Floats1d bias for the hidden layer
|
||||||
|
"upper_W": None, # Floats2d W for the output layer
|
||||||
|
"upper_b": None, # Floats1d bias for the output layer
|
||||||
|
},
|
||||||
|
dims={
|
||||||
|
"nO": None, # Output size
|
||||||
|
"nP": maxout_pieces,
|
||||||
|
"nH": hidden_width,
|
||||||
|
"nI": tok2vec.maybe_get_dim("nO"),
|
||||||
|
"nF": state_tokens,
|
||||||
|
},
|
||||||
attrs={
|
attrs={
|
||||||
"has_upper": has_upper,
|
|
||||||
"unseen_classes": set(unseen_classes),
|
"unseen_classes": set(unseen_classes),
|
||||||
"resize_output": resize_output,
|
"resize_output": resize_output,
|
||||||
|
"make_step_model": make_step_model,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def forward(model, X, is_train):
|
def make_step_model(model: Model) -> Model[List[State], Floats2d]:
|
||||||
step_model = ParserStepModel(
|
...
|
||||||
X,
|
|
||||||
model.layers,
|
|
||||||
unseen_classes=model.attrs["unseen_classes"],
|
def resize_output(model: Model) -> Model:
|
||||||
train=is_train,
|
...
|
||||||
has_upper=model.attrs["has_upper"],
|
|
||||||
|
|
||||||
|
def init(
|
||||||
|
model,
|
||||||
|
X: Optional[Tuple[List[Doc], TransitionSystem]] = None,
|
||||||
|
Y: Optional[Tuple[List[State], List[Floats2d]]] = None,
|
||||||
|
):
|
||||||
|
if X is not None:
|
||||||
|
docs, states = X
|
||||||
|
model.get_ref("tok2vec").initialize(X=docs)
|
||||||
|
inferred_nO = _infer_nO(Y)
|
||||||
|
if inferred_nO is not None:
|
||||||
|
current_nO = model.maybe_get_dim("nO")
|
||||||
|
if current_nO is None:
|
||||||
|
model.set_dim("nO", inferred_nO)
|
||||||
|
elif current_nO != inferred_nO:
|
||||||
|
model.attrs["resize_output"](model, inferred_nO)
|
||||||
|
nO = model.get_dim("nO")
|
||||||
|
nP = model.get_dim("nP")
|
||||||
|
nH = model.get_dim("nH")
|
||||||
|
nI = model.get_dim("nI")
|
||||||
|
nF = model.get_dim("nF")
|
||||||
|
ops = model.ops
|
||||||
|
|
||||||
|
Wl = ops.alloc4f(nF, nH, nP, nI)
|
||||||
|
bl = ops.alloc2f(nH, nP)
|
||||||
|
padl = ops.alloc4f(1, nF, nH, nP)
|
||||||
|
Wu = ops.alloc2f(nO, nH)
|
||||||
|
bu = ops.alloc1f(nO)
|
||||||
|
Wl = normal_init(ops, Wl.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI)))
|
||||||
|
padl = normal_init(ops, padl.shape, mean=1.0)
|
||||||
|
# TODO: Experiment with whether better to initialize Wu
|
||||||
|
model.set_param("lower_W", Wl)
|
||||||
|
model.set_param("lower_b", bl)
|
||||||
|
model.set_param("lower_pad", padl)
|
||||||
|
model.set_param("upper_W", Wu)
|
||||||
|
model.set_param("upper_b", bu)
|
||||||
|
|
||||||
|
_lsuv_init(model)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(model, docs_moves, is_train):
|
||||||
|
tok2vec = model.get_ref("tok2vec")
|
||||||
|
state2scores = model.get_ref("state2scores")
|
||||||
|
# Get a reference to the parameters. We need to work with
|
||||||
|
# stable references through the forward/backward pass, to make
|
||||||
|
# sure we don't have a stale reference if there's concurrent shenanigans.
|
||||||
|
params = {name: model.get_param(name) for name in model.param_names}
|
||||||
|
ops = model.ops
|
||||||
|
docs, moves = docs_moves
|
||||||
|
states = moves.init_batch(docs)
|
||||||
|
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
|
||||||
|
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
|
||||||
|
memory = []
|
||||||
|
all_scores = []
|
||||||
|
while states:
|
||||||
|
states, scores, memory = _step_parser(
|
||||||
|
ops, params, moves, states, feats, memory, is_train
|
||||||
|
)
|
||||||
|
all_scores.append(scores)
|
||||||
|
|
||||||
|
def backprop_parser(d_states_d_scores):
|
||||||
|
_, d_scores = d_states_d_scores
|
||||||
|
d_feats, ids = _backprop_parser_steps(ops, params, memory, d_scores)
|
||||||
|
d_tokvecs = backprop_feats((d_feats, ids))
|
||||||
|
return backprop_tok2vec(d_tokvecs), None
|
||||||
|
|
||||||
|
return (states, all_scores), backprop_parser
|
||||||
|
|
||||||
|
|
||||||
|
def _step_parser(ops, params, moves, states, feats, memory, is_train):
|
||||||
|
ids = moves.get_state_ids(states)
|
||||||
|
statevecs, which, scores = _score_ids(ops, params, ids, feats, is_train)
|
||||||
|
next_states = moves.transition_states(states, scores)
|
||||||
|
if is_train:
|
||||||
|
memory.append((ids, statevecs, which))
|
||||||
|
return next_states, scores, memory
|
||||||
|
|
||||||
|
|
||||||
|
def _score_ids(ops, params, ids, feats, is_train):
|
||||||
|
lower_pad = params["lower_pad"]
|
||||||
|
lower_b = params["lower_b"]
|
||||||
|
upper_W = params["upper_W"]
|
||||||
|
upper_b = params["upper_b"]
|
||||||
|
# During each step of the parser, we do:
|
||||||
|
# * Index into the features, to get the pre-activated vector
|
||||||
|
# for each (token, feature) and sum the feature vectors
|
||||||
|
preacts = _sum_state_features(feats, lower_pad, ids)
|
||||||
|
# * Add the bias
|
||||||
|
preacts += lower_b
|
||||||
|
# * Apply the activation (maxout)
|
||||||
|
statevecs, which = ops.maxout(preacts)
|
||||||
|
# * Multiply the state-vector by the scores weights
|
||||||
|
scores = ops.gemm(statevecs, upper_W, trans2=True)
|
||||||
|
# * Add the bias
|
||||||
|
scores += upper_b
|
||||||
|
# * Apply the is-class-unseen masking
|
||||||
|
# TODO
|
||||||
|
return statevecs, which, scores
|
||||||
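A minimal NumPy sketch of the scoring pipeline that the comments in _score_ids walk through (gather the per-feature rows, sum them, add the bias, apply maxout, then the upper linear layer). The shapes and random weights below are illustrative assumptions, not the model's real dimensions:

import numpy as np

rng = np.random.default_rng(0)
nB, nF, nH, nP, nO = 4, 6, 8, 3, 5              # states, features, hidden, pieces, classes
feats = rng.random((10, nF, nH, nP))            # precomputed rows for 10 tokens
ids = rng.integers(0, 10, size=(nB, nF))        # token id per (state, feature slot)
lower_b = rng.standard_normal((nH, nP))
upper_W = rng.standard_normal((nO, nH))
upper_b = rng.standard_normal(nO)

preacts = feats[ids, np.arange(nF)].sum(axis=1)   # sum the feature vectors -> (nB, nH, nP)
preacts += lower_b                                # add the bias
statevecs = preacts.max(axis=-1)                  # maxout over the pieces -> (nB, nH)
scores = statevecs @ upper_W.T + upper_b          # upper linear layer -> (nB, nO)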
|
|
||||||
|
|
||||||
|
def _sum_state_features(ops: Ops, feats: Floats3d, ids: Ints2d) -> Floats2d:
|
||||||
|
# Here's what we're trying to implement here:
|
||||||
|
#
|
||||||
|
# for i in range(ids.shape[0]):
|
||||||
|
# for j in range(ids.shape[1]):
|
||||||
|
# output[i] += feats[ids[i, j], j]
|
||||||
|
#
|
||||||
|
# Reshape the feats into 2d, to make indexing easier. Instead of getting an
|
||||||
|
# array of indices where the cell at (4, 2) needs to refer to the row at
|
||||||
|
# feats[4, 2], we'll translate the index so that it directly addresses
|
||||||
|
# feats[18]. This lets us make the indices array 1d, leading to fewer
|
||||||
|
# numpy shenanigans.
|
||||||
|
feats2d = ops.reshape2f(feats, feats.shape[0] * feats.shape[1], feats.shape[2])
|
||||||
|
# Now translate the ids. If we're looking for the row that used to be at
|
||||||
|
# (4, 1) and we have 4 features, we'll find it at (4*4)+1=17.
|
||||||
|
oob_ids = ids < 0 # Retain the -1 values
|
||||||
|
ids = ids * feats.shape[1] + ops.xp.arange(feats.shape[1])
|
||||||
|
ids[oob_ids] = -1
|
||||||
|
unsummed2d = feats2d[ops.reshape1i(ids, ids.size)]
|
||||||
|
unsummed3d = ops.reshape3f(
|
||||||
|
unsummed2d, feats.shape[0], feats.shape[1], feats.shape[2]
|
||||||
)
|
)
|
||||||
|
summed = unsummed3d.sum(axis=1) # type: ignore
|
||||||
return step_model, step_model.finish_steps
|
return summed
|
||||||
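A small plain-NumPy check (illustrative shapes) that the reshape-and-gather trick above reproduces the naive double loop from the comment:

import numpy as np

n_tokens, nF, nH = 7, 4, 5
feats = np.random.rand(n_tokens, nF, nH)
ids = np.random.randint(0, n_tokens, size=(3, nF))   # 3 states, one token id per feature slot

# Naive version from the comment.
naive = np.zeros((ids.shape[0], nH))
for i in range(ids.shape[0]):
    for j in range(ids.shape[1]):
        naive[i] += feats[ids[i, j], j]

# Flattened version: translate (token, feature) into a row of the 2d view.
feats2d = feats.reshape(n_tokens * nF, nH)
flat_ids = ids * nF + np.arange(nF)
vectorized = feats2d[flat_ids.ravel()].reshape(ids.shape[0], nF, nH).sum(axis=1)

assert np.allclose(naive, vectorized)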
|
|
||||||
|
|
||||||
def init(model, X=None, Y=None):
|
def _process_memory(ops, memory):
|
||||||
model.get_ref("tok2vec").initialize(X=X)
|
"""Concatenate the memory buffers from each state into contiguous
|
||||||
lower = model.get_ref("lower")
|
buffers for the whole batch.
|
||||||
lower.initialize()
|
"""
|
||||||
if model.attrs["has_upper"]:
|
return [ops.xp.concatenate(item) for item in zip(*memory)]
|
||||||
statevecs = model.ops.alloc2f(2, lower.get_dim("nO"))
|
|
||||||
model.get_ref("upper").initialize(X=statevecs)
|
|
||||||
|
def _backprop_parser_steps(model, upper_W, memory, d_scores):
|
||||||
|
# During each step of the parser, we do:
|
||||||
|
# * Index into the features, to get the pre-activated vector
|
||||||
|
# for each (token, feature)
|
||||||
|
# * Sum the feature vectors
|
||||||
|
# * Add the bias
|
||||||
|
# * Apply the activation (maxout)
|
||||||
|
# * Multiply the state-vector by the scores weights
|
||||||
|
# * Add the bias
|
||||||
|
# * Apply the is-class-unseen masking
|
||||||
|
#
|
||||||
|
# So we have to backprop through all those steps.
|
||||||
|
ids, statevecs, whiches = _process_memory(model.ops, memory)
|
||||||
|
# TODO: Unseen class masking
|
||||||
|
# Calculate the gradients for the parameters of the upper layer.
|
||||||
|
model.inc_grad("upper_b", d_scores.sum(axis=0))
|
||||||
|
model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
|
||||||
|
# Now calculate d_statevecs, by backproping through the upper linear layer.
|
||||||
|
d_statevecs = model.ops.gemm(d_scores, upper_W)
|
||||||
|
# Backprop through the maxout activation
|
||||||
|
d_preacts = model.ops.backprop_maxout(d_statevecs, whiches, model.get_dim("nP"))
|
||||||
|
# We don't need to backprop the summation, because we pass back the IDs instead
|
||||||
|
return d_preacts, ids
|
||||||
|
|
||||||
|
|
||||||
|
def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
|
||||||
|
|
||||||
|
W: Floats4d = model.get_param("lower_W")
|
||||||
|
b: Floats2d = model.get_param("lower_b")
|
||||||
|
pad: Floats4d = model.get_param("lower_pad")
|
||||||
|
nF = model.get_dim("nF")
|
||||||
|
nO = model.get_dim("nO")
|
||||||
|
nP = model.get_dim("nP")
|
||||||
|
nI = model.get_dim("nI")
|
||||||
|
Yf_ = model.ops.gemm(X, model.ops.reshape2f(W, nF * nO * nP, nI), trans2=True)
|
||||||
|
Yf = model.ops.reshape4f(Yf_, Yf_.shape[0], nF, nO, nP)
|
||||||
|
Yf = model.ops.xp.vstack((Yf, pad))
|
||||||
|
|
||||||
|
def backward(dY_ids: Tuple[Floats3d, Ints2d]):
|
||||||
|
# This backprop is particularly tricky, because we get back a different
|
||||||
|
# thing from what we put out. We put out an array of shape:
|
||||||
|
# (nB, nF, nO, nP), and get back:
|
||||||
|
# (nB, nO, nP) and ids (nB, nF)
|
||||||
|
# The ids tell us the values of nF, so we would have:
|
||||||
|
#
|
||||||
|
# dYf = zeros((nB, nF, nO, nP))
|
||||||
|
# for b in range(nB):
|
||||||
|
# for f in range(nF):
|
||||||
|
# dYf[b, ids[b, f]] += dY[b]
|
||||||
|
#
|
||||||
|
# However, we avoid building that array for efficiency -- and just pass
|
||||||
|
# in the indices.
|
||||||
|
dY, ids = dY_ids
|
||||||
|
assert dY.ndim == 3
|
||||||
|
assert dY.shape[1] == nO, dY.shape
|
||||||
|
assert dY.shape[2] == nP, dY.shape
|
||||||
|
# nB = dY.shape[0]
|
||||||
|
model.inc_grad(
|
||||||
|
"lower_pad", _backprop_precomputable_affine_padding(model, dY, ids)
|
||||||
|
)
|
||||||
|
Xf = model.ops.reshape2f(X[ids], ids.shape[0], nF * nI)
|
||||||
|
|
||||||
|
model.inc_grad("lower_b", dY.sum(axis=0)) # type: ignore
|
||||||
|
dY = model.ops.reshape2f(dY, dY.shape[0], nO * nP)
|
||||||
|
|
||||||
|
Wopfi = W.transpose((1, 2, 0, 3))
|
||||||
|
Wopfi = Wopfi.reshape((nO * nP, nF * nI))
|
||||||
|
dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi)
|
||||||
|
|
||||||
|
dWopfi = model.ops.gemm(dY, Xf, trans1=True)
|
||||||
|
dWopfi = dWopfi.reshape((nO, nP, nF, nI))
|
||||||
|
# (o, p, f, i) --> (f, o, p, i)
|
||||||
|
dWopfi = dWopfi.transpose((2, 0, 1, 3))
|
||||||
|
model.inc_grad("W", dWopfi)
|
||||||
|
return model.ops.reshape3f(dXf, dXf.shape[0], nF, nI)
|
||||||
|
|
||||||
|
return Yf, backward
|
||||||
|
|
||||||
|
|
||||||
|
def _backprop_precomputable_affine_padding(model, dY, ids):
|
||||||
|
nB = dY.shape[0]
|
||||||
|
nF = model.get_dim("nF")
|
||||||
|
nP = model.get_dim("nP")
|
||||||
|
nO = model.get_dim("nO")
|
||||||
|
# Backprop the "padding", used as a filler for missing values.
|
||||||
|
# Values that are missing are set to -1, and each state vector could
|
||||||
|
# have multiple missing values. The padding has different values for
|
||||||
|
# different missing features. The gradient of the padding vector is:
|
||||||
|
#
|
||||||
|
# for b in range(nB):
|
||||||
|
# for f in range(nF):
|
||||||
|
# if ids[b, f] < 0:
|
||||||
|
# d_pad[f] += dY[b]
|
||||||
|
#
|
||||||
|
# Which can be rewritten as:
|
||||||
|
#
|
||||||
|
# (ids < 0).T @ dY
|
||||||
|
mask = model.ops.asarray(ids < 0, dtype="f")
|
||||||
|
d_pad = model.ops.gemm(mask, dY.reshape(nB, nO * nP), trans1=True)
|
||||||
|
return d_pad.reshape((1, nF, nO, nP))
|
||||||
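A small plain-NumPy check (illustrative shapes) that the single GEMM used above matches the per-example loop from the comment; the real code only adds a leading axis of size 1 on top of this:

import numpy as np

nB, nF, nO, nP = 6, 4, 3, 2
dY = np.random.rand(nB, nO, nP)
ids = np.random.randint(-1, 5, size=(nB, nF))   # -1 marks a missing feature

# Loop version from the comment.
d_pad_loop = np.zeros((nF, nO, nP))
for b in range(nB):
    for f in range(nF):
        if ids[b, f] < 0:
            d_pad_loop[f] += dY[b]

# Matrix version: one GEMM over the flattened gradients.
mask = (ids < 0).astype("f")                                        # (nB, nF)
d_pad_gemm = (mask.T @ dY.reshape(nB, nO * nP)).reshape(nF, nO, nP)

assert np.allclose(d_pad_loop, d_pad_gemm)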
|
|
||||||
|
|
||||||
|
def _infer_nO(Y: Optional[Tuple[List[State], List[Floats2d]]]) -> Optional[int]:
|
||||||
|
if Y is None:
|
||||||
|
return None
|
||||||
|
_, scores = Y
|
||||||
|
if len(scores) == 0:
|
||||||
|
return None
|
||||||
|
assert scores[0].shape[0] >= 1
|
||||||
|
assert len(scores[0].shape) == 2
|
||||||
|
return scores[0].shape[1]
|
||||||
|
|
||||||
|
|
||||||
|
def _lsuv_init(model):
|
||||||
|
"""This is like the 'layer sequential unit variance', but instead
|
||||||
|
of taking the actual inputs, we randomly generate whitened data.
|
||||||
|
|
||||||
|
Why's this all so complicated? We have a huge number of inputs,
|
||||||
|
and the maxout unit makes guessing the dynamics tricky. Instead
|
||||||
|
we set the maxout weights to values that empirically result in
|
||||||
|
whitened outputs given whitened inputs.
|
||||||
|
"""
|
||||||
|
# TODO
|
||||||
|
return None
|
||||||
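The body above is still a TODO. Purely as a hypothetical sketch of the idea the docstring describes (push randomly generated whitened data through the maxout layer and rescale lower_W and lower_b until the activations are roughly unit variance), something along these lines could work; every detail below is an assumption, not the author's implementation:

def _lsuv_init_sketch(model, n_samples=1000, tol=0.1, max_iters=10):
    xp = model.ops.xp
    nF, nI, nH, nP = (model.get_dim(name) for name in ("nF", "nI", "nH", "nP"))
    W = model.get_param("lower_W")                       # (nF, nH, nP, nI)
    b = model.get_param("lower_b")                       # (nH, nP)
    # Randomly generated whitened input, one flattened feature block per sample.
    X = xp.random.normal(size=(n_samples, nF * nI)).astype("float32")
    for _ in range(max_iters):
        W2d = W.transpose((0, 3, 1, 2)).reshape(nF * nI, nH * nP)
        preacts = X @ W2d + b.reshape(nH * nP)
        acts = preacts.reshape(n_samples, nH, nP).max(axis=-1)   # maxout
        var = float(acts.var())
        if abs(var - 1.0) < tol:
            break
        # Shrink or grow the weights so the output variance moves towards 1.
        W = W / (var ** 0.5)
        b = b / (var ** 0.5)
    model.set_param("lower_W", W)
    model.set_param("lower_b", b)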
|
|
|
@@ -208,70 +208,11 @@ cdef class Parser(TrainablePipe):
|
||||||
# if labels are missing. We therefore have to check whether we need to
|
# if labels are missing. We therefore have to check whether we need to
|
||||||
# expand our model output.
|
# expand our model output.
|
||||||
self._resize()
|
self._resize()
|
||||||
model = self.model.predict(docs)
|
states, scores = self.model.predict((docs, self.moves))
|
||||||
batch = self.moves.init_batch(docs)
|
|
||||||
states = self._predict_states(model, batch)
|
|
||||||
model.clear_memory()
|
|
||||||
del model
|
|
||||||
return states
|
return states
|
||||||
|
|
||||||
def _predict_states(self, model, batch):
|
|
||||||
cdef vector[StateC*] states
|
|
||||||
cdef StateClass state
|
|
||||||
weights = get_c_weights(model)
|
|
||||||
for state in batch:
|
|
||||||
if not state.is_final():
|
|
||||||
states.push_back(state.c)
|
|
||||||
sizes = get_c_sizes(model, states.size())
|
|
||||||
with nogil:
|
|
||||||
self._parseC(&states[0],
|
|
||||||
weights, sizes)
|
|
||||||
return batch
|
|
||||||
|
|
||||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||||
cdef Beam beam
|
raise NotImplementedError
|
||||||
cdef Doc doc
|
|
||||||
batch = _beam_utils.BeamBatch(
|
|
||||||
self.moves,
|
|
||||||
self.moves.init_batch(docs),
|
|
||||||
None,
|
|
||||||
beam_width,
|
|
||||||
density=beam_density
|
|
||||||
)
|
|
||||||
# This is pretty dirty, but the NER can resize itself in init_batch,
|
|
||||||
# if labels are missing. We therefore have to check whether we need to
|
|
||||||
# expand our model output.
|
|
||||||
self._resize()
|
|
||||||
model = self.model.predict(docs)
|
|
||||||
while not batch.is_done:
|
|
||||||
states = batch.get_unfinished_states()
|
|
||||||
if not states:
|
|
||||||
break
|
|
||||||
scores = model.predict(states)
|
|
||||||
batch.advance(scores)
|
|
||||||
model.clear_memory()
|
|
||||||
del model
|
|
||||||
return list(batch)
|
|
||||||
|
|
||||||
cdef void _parseC(self, StateC** states,
|
|
||||||
WeightsC weights, SizesC sizes) nogil:
|
|
||||||
cdef int i, j
|
|
||||||
cdef vector[StateC*] unfinished
|
|
||||||
cdef ActivationsC activations = alloc_activations(sizes)
|
|
||||||
while sizes.states >= 1:
|
|
||||||
predict_states(&activations,
|
|
||||||
states, &weights, sizes)
|
|
||||||
# Validate actions, argmax, take action.
|
|
||||||
self.c_transition_batch(states,
|
|
||||||
activations.scores, sizes.classes, sizes.states)
|
|
||||||
for i in range(sizes.states):
|
|
||||||
if not states[i].is_final():
|
|
||||||
unfinished.push_back(states[i])
|
|
||||||
for i in range(unfinished.size()):
|
|
||||||
states[i] = unfinished[i]
|
|
||||||
sizes.states = unfinished.size()
|
|
||||||
unfinished.clear()
|
|
||||||
free_activations(&activations)
|
|
||||||
|
|
||||||
def set_annotations(self, docs, states_or_beams):
|
def set_annotations(self, docs, states_or_beams):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
|
@@ -283,36 +224,6 @@ cdef class Parser(TrainablePipe):
|
||||||
for hook in self.postprocesses:
|
for hook in self.postprocesses:
|
||||||
hook(doc)
|
hook(doc)
|
||||||
|
|
||||||
def transition_states(self, states, float[:, ::1] scores):
|
|
||||||
cdef StateClass state
|
|
||||||
cdef float* c_scores = &scores[0, 0]
|
|
||||||
cdef vector[StateC*] c_states
|
|
||||||
for state in states:
|
|
||||||
c_states.push_back(state.c)
|
|
||||||
self.c_transition_batch(&c_states[0], c_scores, scores.shape[1], scores.shape[0])
|
|
||||||
return [state for state in states if not state.c.is_final()]
|
|
||||||
|
|
||||||
cdef void c_transition_batch(self, StateC** states, const float* scores,
|
|
||||||
int nr_class, int batch_size) nogil:
|
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
|
||||||
with gil:
|
|
||||||
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
|
|
||||||
is_valid = <int*>calloc(self.moves.n_moves, sizeof(int))
|
|
||||||
cdef int i, guess
|
|
||||||
cdef Transition action
|
|
||||||
for i in range(batch_size):
|
|
||||||
self.moves.set_valid(is_valid, states[i])
|
|
||||||
guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class)
|
|
||||||
if guess == -1:
|
|
||||||
# This shouldn't happen, but it's hard to raise an error here,
|
|
||||||
# and we don't want to infinite loop. So, force to end state.
|
|
||||||
states[i].force_final()
|
|
||||||
else:
|
|
||||||
action = self.moves.c[guess]
|
|
||||||
action.do(states[i], action.label)
|
|
||||||
states[i].history.push_back(guess)
|
|
||||||
free(is_valid)
|
|
||||||
|
|
||||||
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
if losses is None:
|
if losses is None:
|
||||||
|
@@ -327,58 +238,48 @@ cdef class Parser(TrainablePipe):
|
||||||
if n_examples == 0:
|
if n_examples == 0:
|
||||||
return losses
|
return losses
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
# The probability we use beam update, instead of falling back to
|
|
||||||
# a greedy update
|
|
||||||
beam_update_prob = self.cfg["beam_update_prob"]
|
|
||||||
if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob:
|
|
||||||
return self.update_beam(
|
|
||||||
examples,
|
|
||||||
beam_width=self.cfg["beam_width"],
|
|
||||||
sgd=sgd,
|
|
||||||
losses=losses,
|
|
||||||
beam_density=self.cfg["beam_density"]
|
|
||||||
)
|
|
||||||
docs = [eg.x for eg in examples]
|
docs = [eg.x for eg in examples]
|
||||||
model, backprop_tok2vec = self.model.begin_update(docs)
|
(states, scores), backprop_scores = self.model.begin_update((docs, self.moves))
|
||||||
states = self.moves.init_batch(docs)
|
d_scores = self.get_loss((states, scores), examples)
|
||||||
self._predict_states(states)
|
backprop_scores(d_scores)
|
||||||
# I've separated the prediction from getting the batch because
|
|
||||||
# I like the idea of trying to store the histories or maybe compute
|
|
||||||
# them in another process or something. Just walking the states
|
|
||||||
# and transitioning isn't expensive anyway.
|
|
||||||
ids, costs = self._get_ids_and_costs_from_histories(
|
|
||||||
examples,
|
|
||||||
[list(state.history) for state in states]
|
|
||||||
)
|
|
||||||
scores, backprop_states = model.begin_update(ids)
|
|
||||||
d_scores = self.get_loss(scores, costs)
|
|
||||||
d_tokvecs = backprop_states(d_scores)
|
|
||||||
backprop_tok2vec(d_tokvecs)
|
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
self.set_annotations(docs, states)
|
|
||||||
losses[self.name] += (d_scores**2).sum()
|
losses[self.name] += (d_scores**2).sum()
|
||||||
# Ugh, this is annoying. If we're working on GPU, we want to free the
|
# Ugh, this is annoying. If we're working on GPU, we want to free the
|
||||||
# memory ASAP. It seems that Python doesn't necessarily get around to
|
# memory ASAP. It seems that Python doesn't necessarily get around to
|
||||||
# removing these in time if we don't explicitly delete? It's confusing.
|
# removing these in time if we don't explicitly delete? It's confusing.
|
||||||
del backprop_states
|
del backprop_scores
|
||||||
del backprop_tok2vec
|
|
||||||
model.clear_memory()
|
|
||||||
del model
|
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def _get_ids_and_costs_from_histories(self, examples, histories):
|
def get_loss(self, states_scores, examples):
|
||||||
|
states, scores = states_scores
|
||||||
|
costs = self._get_costs_from_histories(
|
||||||
|
examples,
|
||||||
|
[list(state.history) for state in states]
|
||||||
|
)
|
||||||
|
xp = get_array_module(scores)
|
||||||
|
best_costs = costs.min(axis=1, keepdims=True)
|
||||||
|
is_gold = costs <= costs.min(axis=1, keepdims=True)
|
||||||
|
gscores = scores[is_gold]
|
||||||
|
max_ = scores.max(axis=1)
|
||||||
|
gmax = gscores.max(axis=1, keepdims=True)
|
||||||
|
exp_scores = xp.exp(scores - max_)
|
||||||
|
exp_gscores = xp.exp(gscores - gmax)
|
||||||
|
Z = exp_scores.sum(axis=1, keepdims=True)
|
||||||
|
gZ = exp_gscores.sum(axis=1, keepdims=True)
|
||||||
|
d_scores = exp_scores / Z
|
||||||
|
d_scores[is_gold] -= exp_gscores / gZ
|
||||||
|
return d_scores
|
||||||
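The array returned above is the gradient of the negative log-likelihood of the zero-cost ("gold") actions under a softmax, L = -log(sum over gold classes of softmax(scores)): its derivative is the softmax minus, on the gold columns, the softmax renormalized over the gold classes. A small finite-difference sanity check of that identity for a single state (plain NumPy, illustrative numbers only):

import numpy as np

scores = np.array([[2.0, 0.5, -1.0, 0.3]])
is_gold = np.array([[True, False, True, False]])

def loss(s):
    p = np.exp(s - s.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    return -np.log(p[is_gold].sum())

# Analytic gradient, mirroring get_loss: softmax minus the gold-renormalized softmax.
p = np.exp(scores - scores.max(axis=1, keepdims=True))
p /= p.sum(axis=1, keepdims=True)
d_scores = p.copy()
d_scores[is_gold] -= p[is_gold] / p[is_gold].sum()

# Finite-difference check.
eps = 1e-5
num_grad = np.zeros_like(scores)
for i in range(scores.shape[1]):
    up, down = scores.copy(), scores.copy()
    up[0, i] += eps
    down[0, i] -= eps
    num_grad[0, i] = (loss(up) - loss(down)) / (2 * eps)

assert np.allclose(d_scores, num_grad, atol=1e-4)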
|
|
||||||
|
def _get_costs_from_histories(self, examples, histories):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef int clas
|
cdef int clas
|
||||||
cdef int nF = self.model.state2vec.nF
|
cdef int nF = self.model.state2vec.nF
|
||||||
cdef int nO = self.moves.n_moves
|
cdef int nO = self.moves.n_moves
|
||||||
cdef int nS = sum([len(history) for history in histories])
|
cdef int nS = sum([len(history) for history in histories])
|
||||||
# ids and costs have one row per state in the whole batch.
|
|
||||||
cdef np.ndarray ids = numpy.zeros((nS, nF), dtype="i")
|
|
||||||
cdef np.ndarray costs = numpy.zeros((nS, nO), dtype="f")
|
cdef np.ndarray costs = numpy.zeros((nS, nO), dtype="f")
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
is_valid = <int*>mem.alloc(nO, sizeof(int))
|
is_valid = <int*>mem.alloc(nO, sizeof(int))
|
||||||
c_ids = <int*>ids.data
|
|
||||||
c_costs = <float*>costs.data
|
c_costs = <float*>costs.data
|
||||||
states = self.moves.init_states([eg.x for eg in examples])
|
states = self.moves.init_states([eg.x for eg in examples])
|
||||||
cdef int i = 0
|
cdef int i = 0
|
||||||
|
@@ -394,92 +295,15 @@ cdef class Parser(TrainablePipe):
|
||||||
i += 1
|
i += 1
|
||||||
# If the model is on GPU, copy the costs to device.
|
# If the model is on GPU, copy the costs to device.
|
||||||
costs = self.model.ops.asarray(costs)
|
costs = self.model.ops.asarray(costs)
|
||||||
return ids, costs
|
return costs
|
||||||
|
|
||||||
def get_loss(self, scores, costs):
|
|
||||||
xp = get_array_module(scores)
|
|
||||||
best_costs = costs.min(axis=1, keepdims=True)
|
|
||||||
is_gold = costs <= costs.min(axis=1, keepdims=True)
|
|
||||||
gscores = scores[is_gold]
|
|
||||||
max_ = scores.max(axis=1)
|
|
||||||
gmax = gscores.max(axis=1, keepdims=True)
|
|
||||||
exp_scores = xp.exp(scores - max_)
|
|
||||||
exp_gscores = xp.exp(gscores - gmax)
|
|
||||||
Z = exp_scores.sum(axis=1, keepdims=True)
|
|
||||||
gZ = exp_gscores.sum(axis=1, keepdims=True)
|
|
||||||
d_scores = exp_scores / Z
|
|
||||||
d_scores[is_gold] -= exp_gscores / gZ
|
|
||||||
return d_scores
|
|
||||||
|
|
||||||
def rehearse(self, examples, sgd=None, losses=None, **cfg):
|
def rehearse(self, examples, sgd=None, losses=None, **cfg):
|
||||||
"""Perform a "rehearsal" update, to prevent catastrophic forgetting."""
|
"""Perform a "rehearsal" update, to prevent catastrophic forgetting."""
|
||||||
if losses is None:
|
raise NotImplementedError
|
||||||
losses = {}
|
|
||||||
for multitask in self._multitasks:
|
|
||||||
if hasattr(multitask, 'rehearse'):
|
|
||||||
multitask.rehearse(examples, losses=losses, sgd=sgd)
|
|
||||||
if self._rehearsal_model is None:
|
|
||||||
return None
|
|
||||||
losses.setdefault(self.name, 0.)
|
|
||||||
validate_examples(examples, "Parser.rehearse")
|
|
||||||
docs = [eg.predicted for eg in examples]
|
|
||||||
states = self.moves.init_batch(docs)
|
|
||||||
# This is pretty dirty, but the NER can resize itself in init_batch,
|
|
||||||
# if labels are missing. We therefore have to check whether we need to
|
|
||||||
# expand our model output.
|
|
||||||
self._resize()
|
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
|
||||||
set_dropout_rate(self._rehearsal_model, 0.0)
|
|
||||||
set_dropout_rate(self.model, 0.0)
|
|
||||||
tutor, _ = self._rehearsal_model.begin_update(docs)
|
|
||||||
model, backprop_tok2vec = self.model.begin_update(docs)
|
|
||||||
n_scores = 0.
|
|
||||||
loss = 0.
|
|
||||||
while states:
|
|
||||||
targets, _ = tutor.begin_update(states)
|
|
||||||
guesses, backprop = model.begin_update(states)
|
|
||||||
d_scores = (guesses - targets) / targets.shape[0]
|
|
||||||
# If all weights for an output are 0 in the original model, don't
|
|
||||||
# supervise that output. This allows us to add classes.
|
|
||||||
loss += (d_scores**2).sum()
|
|
||||||
backprop(d_scores)
|
|
||||||
# Follow the predicted action
|
|
||||||
self.transition_states(states, guesses)
|
|
||||||
states = [state for state in states if not state.is_final()]
|
|
||||||
n_scores += d_scores.size
|
|
||||||
# Do the backprop
|
|
||||||
backprop_tok2vec(docs)
|
|
||||||
if sgd is not None:
|
|
||||||
self.finish_update(sgd)
|
|
||||||
losses[self.name] += loss / n_scores
|
|
||||||
del backprop
|
|
||||||
del backprop_tok2vec
|
|
||||||
model.clear_memory()
|
|
||||||
tutor.clear_memory()
|
|
||||||
del model
|
|
||||||
del tutor
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def update_beam(self, examples, *, beam_width,
|
def update_beam(self, examples, *, beam_width,
|
||||||
drop=0., sgd=None, losses=None, beam_density=0.0):
|
drop=0., sgd=None, losses=None, beam_density=0.0):
|
||||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
raise NotImplementedError
|
||||||
if not states:
|
|
||||||
return losses
|
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
|
||||||
model, backprop_tok2vec = self.model.begin_update(
|
|
||||||
[eg.predicted for eg in examples])
|
|
||||||
loss = _beam_utils.update_beam(
|
|
||||||
self.moves,
|
|
||||||
states,
|
|
||||||
golds,
|
|
||||||
model,
|
|
||||||
beam_width,
|
|
||||||
beam_density=beam_density,
|
|
||||||
)
|
|
||||||
losses[self.name] += loss
|
|
||||||
backprop_tok2vec(golds)
|
|
||||||
if sgd is not None:
|
|
||||||
self.finish_update(sgd)
|
|
||||||
|
|
||||||
def set_output(self, nO):
|
def set_output(self, nO):
|
||||||
self.model.attrs["resize_output"](self.model, nO)
|
self.model.attrs["resize_output"](self.model, nO)
|
||||||
|
|