New progress on parser model refactor

Matthew Honnibal 2021-10-25 03:13:31 +02:00
parent 267ffb5605
commit de8c88babb
3 changed files with 410 additions and 410 deletions

View File

@@ -208,50 +208,41 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
class ParserStepModel(Model):
def __init__(self, docs, layers, *, has_upper, unseen_classes=None, train=True,
dropout=0.1):
Model.__init__(self, name="parser_step_model", forward=step_forward)
self.attrs["has_upper"] = has_upper
self.attrs["dropout_rate"] = dropout
self.tokvecs, self.bp_tokvecs = layers[0](docs, is_train=train)
if layers[1].get_dim("nP") >= 2:
activation = "maxout"
elif has_upper:
activation = None
else:
activation = "relu"
self.state2vec = precompute_hiddens(
len(docs),
self.tokvecs,
layers[1],
activation=activation,
train=train
)
if has_upper:
self.vec2scores = layers[-1]
else:
self.vec2scores = None
self._class_mask = numpy.zeros((self.nO,), dtype='f')
self._class_mask.fill(1)
if unseen_classes is not None:
for class_ in unseen_classes:
self._class_mask[class_] = 0.
def ParserStepModel(
tokvecs: Floats2d,
bp_tokvecs: Callable,
upper: Model[Floats2d, Floats2d],
dropout: float=0.1,
unseen_classes: Optional[List[int]]=None
) -> Model[Ints2d, Floats2d]:
# TODO: Keep working on replacing all of this with just 'chain'
state2vec = precompute_hiddens(
tokvecs,
bp_tokvecs
)
class_mask = numpy.zeros((upper.get_dim("nO"),), dtype='f')
class_mask.fill(1)
if unseen_classes is not None:
for class_ in unseen_classes:
class_mask[class_] = 0.
@property
def nO(self):
if self.attrs["has_upper"]:
return self.vec2scores.get_dim("nO")
else:
return self.state2vec.get_dim("nO")
return _ParserStepModel(
"ParserStep",
step_forward,
init=None,
dims={"nO": upper.get_dim("nO")},
layers=[state2vec, upper],
attrs={
"tokvecs": tokvecs,
"bp_tokvecs": bp_tokvecs,
"dropout_rate": dropout,
"class_mask": class_mask
}
)
def clear_memory(self):
del self.tokvecs
del self.bp_tokvecs
del self.state2vec
del self.backprops
del self._class_mask
class _ParserStepModel(Model):
# TODO: Remove need for all this stuff, so we can normalize this
def class_is_unseen(self, class_):
return self._class_mask[class_]
@@ -274,21 +265,22 @@ class ParserStepModel(Model):
return ids
def step_forward(model: ParserStepModel, token_ids, is_train):
vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
def step_forward(model: _ParserStepModel, token_ids, is_train):
# TODO: Eventually we hopefully can get rid of all of this?
# If we make the 'class_mask' thing its own layer, we can just
# have chain() here, right?
state2vec, upper = model.layers
vector, get_d_tokvecs = state2vec(token_ids, is_train)
mask = None
if model.attrs["has_upper"]:
vec2scores = ensure_same_device(model.ops, model.vec2scores)
dropout_rate = model.attrs["dropout_rate"]
if is_train and dropout_rate > 0:
mask = model.ops.get_dropout_mask(vector.shape, dropout_rate)
vector *= mask
scores, get_d_vector = vec2scores(vector, is_train)
else:
scores = vector
get_d_vector = lambda d_scores: d_scores
vec2scores = ensure_same_device(model.ops, upper)
dropout_rate = model.attrs["dropout_rate"]
if is_train and dropout_rate > 0:
mask = model.ops.get_dropout_mask(vector.shape, dropout_rate)
vector *= mask
scores, get_d_vector = vec2scores(vector, is_train)
# If the class is unseen, make sure its score is minimum
scores[:, model._class_mask == 0] = model.ops.xp.nanmin(scores)
class_mask = model.attrs["class_mask"]
scores[:, class_mask == 0] = model.ops.xp.nanmin(scores)
def backprop_parser_step(d_scores):
# Zero vectors for unseen classes
@@ -301,127 +293,45 @@ def step_forward(model: ParserStepModel, token_ids, is_train):
return scores, backprop_parser_step
def ensure_same_device(ops, model):
"""Ensure a model is on the same device as a given ops"""
if not isinstance(model.ops, ops.__class__):
model._to_ops(ops)
return model
def precompute_hiddens(lower_model, feat_weights: Floats3d, bp_hiddens: Callable) -> Model:
return Model(
"precompute_hiddens",
init=None,
forward=_precomputed_forward,
dims={
"nO": feat_weights.shape[2],
"nP": lower_model.get_dim("nP") if lower_model.has_dim("nP") else 1,
"nF": feat_weights.shape[1]
},
attrs={
"feat_weights": feat_weights,
"bp_hiddens": bp_hiddens,
"bias": lower_model.get_param("b")
},
ops=lower_model.ops
)
cdef class precompute_hiddens:
"""Allow a model to be "primed" by pre-computing input features in bulk.
def _precomputed_forward(
model: Model[Ints2d, Floats2d],
token_ids: Ints2d,
is_train: bool
) -> Tuple[Floats2d, Callable]:
nO = model.get_dim("nO")
nP = model.get_dim("nP")
bp_hiddens = model.attrs["bp_hiddens"]
feat_weights = model.attrs["feat_weights"]
bias = model.attrs["bias"]
hidden = model.ops.alloc2f(
token_ids.shape[0],
nO * nP
)
# TODO: This is probably wrong, right?
model.ops.scatter_add(
hidden,
feat_weights,
token_ids
)
hidden += bias
statevec, mask = model.ops.maxout(hidden.reshape((-1, nO, nP)))
This is used for the parser, where we want to take a batch of documents,
and compute vectors for each (token, position) pair. These vectors can then
be reused, especially for beam-search.
Let's say we're using 12 features for each state, e.g. word at start of
buffer, three words on stack, their children, etc. In the normal arc-eager
system, a document of length N is processed in 2*N states. This means we'll
create 2*N*12 feature vectors --- but if we pre-compute, we only need
N*12 vector computations. The saving for beam-search is much better:
if we have a beam of k, we'll normally make 2*N*12*k computations --
so we can save the factor k. This also gives a nice CPU/GPU division:
we can do all our hard maths up front, packed into large multiplications,
and do the hard-to-program parsing on the CPU.
"""
cdef readonly int nF, nO, nP
cdef public object ops
cdef readonly object bias
cdef readonly object activation
cdef readonly object _features
cdef readonly object _cached
cdef readonly object _bp_hiddens
def __init__(
self,
batch_size,
tokvecs,
lower_model,
activation="maxout",
train=False
):
cached, bp_features = lower_model(tokvecs, train)
self.bias = lower_model.get_param("b")
self.nF = cached.shape[1]
if lower_model.has_dim("nP"):
self.nP = lower_model.get_dim("nP")
else:
self.nP = 1
self.nO = cached.shape[2]
self.ops = lower_model.ops
assert activation in (None, "relu", "maxout")
self.activation = activation
self._cached = cached
self._bp_hiddens = bp_features
cdef const float* get_feat_weights(self) except NULL:
cdef np.ndarray cached
if isinstance(self._cached, numpy.ndarray):
cached = self._cached
else:
cached = self._cached.get()
return <float*>cached.data
def has_dim(self, name):
if name == "nF":
return self.nF if self.nF is not None else True
elif name == "nP":
return self.nP if self.nP is not None else True
elif name == "nO":
return self.nO if self.nO is not None else True
else:
return False
def get_dim(self, name):
if name == "nF":
return self.nF
elif name == "nP":
return self.nP
elif name == "nO":
return self.nO
else:
raise ValueError(f"Dimension {name} invalid -- only nO, nF, nP")
def set_dim(self, name, value):
if name == "nF":
self.nF = value
elif name == "nP":
self.nP = value
elif name == "nO":
self.nO = value
else:
raise ValueError(f"Dimension {name} invalid -- only nO, nF, nP")
def __call__(self, X, bint is_train):
if is_train:
return self.begin_update(X)
else:
return self.predict(X), lambda X: X
def predict(self, X):
return self.begin_update(X)[0]
def begin_update(self, token_ids):
nO = self.nO
nP = self.nP
hidden = self.model.ops.alloc2f(
token_ids.shape[0],
nO * nP
)
bp_hiddens = self._bp_hiddens
feat_weights = self.cached
self.ops.scatter_add(
hidden,
feat_weights,
token_ids
def backward(d_statevec):
return bp_hiddens(
model.ops.backprop_maxout(d_statevec, mask, nP)
)
hidden += self.bias
statevec, mask = self.ops.maxout(hidden.reshape((-1, nO, nP)))
def backward(d_statevec):
return bp_hiddens(
self.ops.backprop_maxout(d_statevec, mask, nP)
)
return statevec, backward
return statevec, backward
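
To make the arithmetic in the precompute_hiddens docstring above concrete, here is a tiny illustrative calculation; the document length and beam width are made-up numbers, not values from this commit.

# Illustrative arithmetic only: how many (state, feature) vectors get computed.
N = 1000                        # tokens in a document (assumed)
F = 12                          # features per state, as in the docstring
k = 8                           # beam width (assumed)
greedy_no_cache = 2 * N * F     # 2*N states, F feature vectors each
greedy_cached = N * F           # one vector per (token, feature), reused
beam_no_cache = 2 * N * F * k   # every beam candidate recomputes its features
beam_cached = N * F             # the per-token cache is shared across the beam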

View File

@@ -1,48 +1,314 @@
from thinc.api import Model, noop
from .parser_model import ParserStepModel
from typing import List, Tuple, Any, Optional
from thinc.api import Ops, Model, normal_init
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
from ..tokens.doc import Doc
TransitionSystem = Any # TODO
State = Any # TODO
def TransitionModel(
tok2vec, lower, upper, resize_output, dropout=0.2, unseen_classes=set()
):
"""Set up a stepwise transition-based model"""
if upper is None:
has_upper = False
upper = noop()
else:
has_upper = True
# don't define nO for this object, because we can't dynamically change it
*,
tok2vec: Model[List[Doc], List[Floats2d]],
state_tokens: int,
hidden_width: int,
maxout_pieces: int,
nO: Optional[int] = None,
unseen_classes=set(),
) -> Model[Tuple[List[Doc], TransitionSystem], List[Tuple[State, List[Floats2d]]]]:
"""Set up a transition-based parsing model, using a maxout hidden
layer and a linear output layer.
"""
return Model(
name="parser_model",
forward=forward,
dims={"nI": tok2vec.get_dim("nI") if tok2vec.has_dim("nI") else None},
layers=[tok2vec, lower, upper],
refs={"tok2vec": tok2vec, "lower": lower, "upper": upper},
init=init,
layers=[tok2vec],
refs={"tok2vec": tok2vec},
params={
"lower_W": None, # Floats2d W for the hidden layer
"lower_b": None, # Floats1d bias for the hidden layer
"lower_pad": None, # Floats4d padding for the hidden layer
"upper_W": None, # Floats2d W for the output layer
"upper_b": None, # Floats1d bias for the output layer
},
dims={
"nO": None, # Output size
"nP": maxout_pieces,
"nH": hidden_width,
"nI": tok2vec.maybe_get_dim("nO"),
"nF": state_tokens,
},
attrs={
"has_upper": has_upper,
"unseen_classes": set(unseen_classes),
"resize_output": resize_output,
"make_step_model": make_step_model,
},
)
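
As a usage sketch of the new keyword-only signature (not taken from this commit; the tok2vec value and the hyperparameters are placeholders):

# Hypothetical call; every value here is a placeholder.
model = TransitionModel(
    tok2vec=my_tok2vec,    # any Model[List[Doc], List[Floats2d]]
    state_tokens=13,       # stored as the "nF" dim
    hidden_width=64,       # stored as the "nH" dim
    maxout_pieces=2,       # stored as the "nP" dim
    nO=None,               # inferred from data in init()
)
# The model is then called with a (docs, moves) tuple, as forward() below expects.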
def forward(model, X, is_train):
step_model = ParserStepModel(
X,
model.layers,
unseen_classes=model.attrs["unseen_classes"],
train=is_train,
has_upper=model.attrs["has_upper"],
def make_step_model(model: Model) -> Model[List[State], Floats2d]:
...
def resize_output(model: Model) -> Model:
...
def init(
model,
X: Optional[Tuple[List[Doc], TransitionSystem]] = None,
Y: Optional[Tuple[List[State], List[Floats2d]]] = None,
):
if X is not None:
docs, states = X
model.get_ref("tok2vec").initialize(X=docs)
inferred_nO = _infer_nO(Y)
if inferred_nO is not None:
current_nO = model.maybe_get_dim("nO")
if current_nO is None:
model.set_dim("nO", inferred_nO)
elif current_nO != inferred_nO:
model.attrs["resize_output"](model, inferred_nO)
nO = model.get_dim("nO")
nP = model.get_dim("nP")
nH = model.get_dim("nH")
nI = model.get_dim("nI")
nF = model.get_dim("nF")
ops = model.ops
Wl = ops.alloc4f(nF, nH, nP, nI)
bl = ops.alloc2f(nH, nP)
padl = ops.alloc4f(1, nF, nH, nP)
Wu = ops.alloc2f(nO, nH)
bu = ops.alloc1f(nO)
Wl = normal_init(ops, Wl.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI)))
padl = normal_init(ops, padl.shape, mean=1.0)
# TODO: Experiment with whether better to initialize Wu
model.set_param("lower_W", Wl)
model.set_param("lower_b", bl)
model.set_param("lower_pad", padl)
model.set_param("upper_W", Wu)
model.set_param("upper_b", bu)
_lsuv_init(model)
def forward(model, docs_moves, is_train):
tok2vec = model.get_ref("tok2vec")
state2scores = model.get_ref("state2scores")
# Get a reference to the parameters. We need to work with
# stable references through the forward/backward pass, to make
# sure we don't have a stale reference if there's concurrent shenanigans.
params = {name: model.get_param(name) for name in model.param_names}
ops = model.ops
docs, moves = docs_moves
states = moves.init_batch(docs)
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
memory = []
all_scores = []
while states:
states, scores, memory = _step_parser(
ops, params, moves, states, feats, memory, is_train
)
all_scores.append(scores)
def backprop_parser(d_states_d_scores):
_, d_scores = d_states_d_scores
d_feats, ids = _backprop_parser_steps(model, params["upper_W"], memory, d_scores)
d_tokvecs = backprop_feats((d_feats, ids))
return backprop_tok2vec(d_tokvecs), None
return (states, all_scores), backprop_parser
def _step_parser(ops, params, moves, states, feats, memory, is_train):
ids = moves.get_state_ids(states)
statevecs, which, scores = _score_ids(ops, params, ids, feats, is_train)
next_states = moves.transition_states(states, scores)
if is_train:
memory.append((ids, statevecs, which))
return next_states, scores, memory
def _score_ids(ops, params, ids, feats, is_train):
lower_pad = params["lower_pad"]
lower_b = params["lower_b"]
upper_W = params["upper_W"]
upper_b = params["upper_b"]
# During each step of the parser, we do:
# * Index into the features, to get the pre-activated vector
# for each (token, feature) and sum the feature vectors
preacts = _sum_state_features(ops, feats, ids)
# * Add the bias
preacts += lower_b
# * Apply the activation (maxout)
statevecs, which = ops.maxout(preacts)
# * Multiply the state-vector by the scores weights
scores = ops.gemm(statevecs, upper_W, trans2=True)
# * Add the bias
scores += upper_b
# * Apply the is-class-unseen masking
# TODO
return statevecs, which, scores
def _sum_state_features(ops: Ops, feats: Floats3d, ids: Ints2d) -> Floats2d:
# Here's what we're trying to implement here:
#
# for i in range(ids.shape[0]):
# for j in range(ids.shape[1]):
# output[i] += feats[ids[i, j], j]
#
# Reshape the feats into 2d, to make indexing easier. Instead of getting an
# array of indices where the cell at (4, 2) needs to refer to the row at
# feats[4, 2], we'll translate the index so that it directly addresses
# feats[18]. This lets us make the indices array 1d, leading to fewer
# numpy shenanigans.
feats2d = ops.reshape2f(feats, feats.shape[0] * feats.shape[1], feats.shape[2])
# Now translate the ids. If we're looking for the row that used to be at
# (4, 1) and we have 4 features, we'll find it at (4*4)+1=17.
oob_ids = ids < 0 # Retain the -1 values
ids = ids * feats.shape[1] + ops.xp.arange(feats.shape[1])
ids[oob_ids] = -1
unsummed2d = feats2d[ops.reshape1i(ids, ids.size)]
unsummed3d = ops.reshape3f(
unsummed2d, ids.shape[0], ids.shape[1], feats.shape[2]
)
return step_model, step_model.finish_steps
summed = unsummed3d.sum(axis=1) # type: ignore
return summed
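
The index translation above is easy to get wrong, so here is a small self-contained check of the reshaping trick described in the comment (shapes and values are made up, and the -1 padding handling is omitted):

import numpy

feats = numpy.random.rand(5, 3, 4).astype("f")   # (n_tokens, nF, nO*nP)
ids = numpy.array([[0, 2, 4], [1, 1, 3]])        # (n_states, nF)

# Reference: the nested loop from the comment.
expected = numpy.zeros((ids.shape[0], feats.shape[2]), dtype="f")
for i in range(ids.shape[0]):
    for j in range(ids.shape[1]):
        expected[i] += feats[ids[i, j], j]

# Flattened version: translate each (token, feature) pair to a single row index.
feats2d = feats.reshape(feats.shape[0] * feats.shape[1], feats.shape[2])
flat_ids = ids * feats.shape[1] + numpy.arange(feats.shape[1])
summed = feats2d[flat_ids.ravel()].reshape(ids.shape[0], ids.shape[1], -1).sum(axis=1)

assert numpy.allclose(summed, expected)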
def init(model, X=None, Y=None):
model.get_ref("tok2vec").initialize(X=X)
lower = model.get_ref("lower")
lower.initialize()
if model.attrs["has_upper"]:
statevecs = model.ops.alloc2f(2, lower.get_dim("nO"))
model.get_ref("upper").initialize(X=statevecs)
def _process_memory(ops, memory):
"""Concatenate the memory buffers from each state into contiguous
buffers for the whole batch.
"""
return [ops.xp.concatenate(item) for item in zip(*memory)]
def _backprop_parser_steps(model, upper_W, memory, d_scores):
# During each step of the parser, we do:
# * Index into the features, to get the pre-activated vector
# for each (token, feature)
# * Sum the feature vectors
# * Add the bias
# * Apply the activation (maxout)
# * Multiply the state-vector by the scores weights
# * Add the bias
# * Apply the is-class-unseen masking
#
# So we have to backprop through all those steps.
ids, statevecs, whiches = _process_memory(model.ops, memory)
# TODO: Unseen class masking
# Calculate the gradients for the parameters of the upper layer.
model.inc_grad("upper_b", d_scores.sum(axis=0))
model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
# Now calculate d_statevecs, by backproping through the upper linear layer.
d_statevecs = model.ops.gemm(d_scores, upper_W)
# Backprop through the maxout activation
d_preacts = model.ops.backprop_maxout(d_statevecs, whiches, model.get_dim("nP"))
# We don't need to backprop the summation, because we pass back the IDs instead
return d_preacts, ids
def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
W: Floats4d = model.get_param("lower_W")
b: Floats2d = model.get_param("lower_b")
pad: Floats4d = model.get_param("lower_pad")
nF = model.get_dim("nF")
nO = model.get_dim("nO")
nP = model.get_dim("nP")
nI = model.get_dim("nI")
Yf_ = model.ops.gemm(X, model.ops.reshape2f(W, nF * nO * nP, nI), trans2=True)
Yf = model.ops.reshape4f(Yf_, Yf_.shape[0], nF, nO, nP)
Yf = model.ops.xp.vstack((Yf, pad))
def backward(dY_ids: Tuple[Floats3d, Ints2d]):
# This backprop is particularly tricky, because we get back a different
# thing from what we put out. We put out an array of shape:
# (nB, nF, nO, nP), and get back:
# (nB, nO, nP) and ids (nB, nF)
# The ids tell us the values of nF, so we would have:
#
# dYf = zeros((nB, nF, nO, nP))
# for b in range(nB):
# for f in range(nF):
# dYf[b, ids[b, f]] += dY[b]
#
# However, we avoid building that array for efficiency -- and just pass
# in the indices.
dY, ids = dY_ids
assert dY.ndim == 3
assert dY.shape[1] == nO, dY.shape
assert dY.shape[2] == nP, dY.shape
# nB = dY.shape[0]
model.inc_grad(
"lower_pad", _backprop_precomputable_affine_padding(model, dY, ids)
)
Xf = model.ops.reshape2f(X[ids], ids.shape[0], nF * nI)
model.inc_grad("lower_b", dY.sum(axis=0)) # type: ignore
dY = model.ops.reshape2f(dY, dY.shape[0], nO * nP)
Wopfi = W.transpose((1, 2, 0, 3))
Wopfi = Wopfi.reshape((nO * nP, nF * nI))
dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi)
dWopfi = model.ops.gemm(dY, Xf, trans1=True)
dWopfi = dWopfi.reshape((nO, nP, nF, nI))
# (o, p, f, i) --> (f, o, p, i)
dWopfi = dWopfi.transpose((2, 0, 1, 3))
model.inc_grad("lower_W", dWopfi)
return model.ops.reshape3f(dXf, dXf.shape[0], nF, nI)
return Yf, backward
def _backprop_precomputable_affine_padding(model, dY, ids):
nB = dY.shape[0]
nF = model.get_dim("nF")
nP = model.get_dim("nP")
nO = model.get_dim("nO")
# Backprop the "padding", used as a filler for missing values.
# Values that are missing are set to -1, and each state vector could
# have multiple missing values. The padding has different values for
# different missing features. The gradient of the padding vector is:
#
# for b in range(nB):
# for f in range(nF):
# if ids[b, f] < 0:
# d_pad[f] += dY[b]
#
# Which can be rewritten as:
#
# (ids < 0).T @ dY
mask = model.ops.asarray(ids < 0, dtype="f")
d_pad = model.ops.gemm(mask, dY.reshape(nB, nO * nP), trans1=True)
return d_pad.reshape((1, nF, nO, nP))
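
A quick sanity check of the "(ids < 0).T @ dY" rewrite in the comment above, with made-up shapes (not part of the commit):

import numpy

nB, nF, nO, nP = 4, 3, 2, 2
dY = numpy.random.rand(nB, nO, nP).astype("f")
ids = numpy.array([[0, -1, 2], [-1, -1, 1], [3, 0, -1], [1, 2, 0]])

# Reference: the nested loop from the comment.
d_pad_loop = numpy.zeros((nF, nO, nP), dtype="f")
for b in range(nB):
    for f in range(nF):
        if ids[b, f] < 0:
            d_pad_loop[f] += dY[b]

# Matrix form: mask.T @ dY, with dY flattened to (nB, nO * nP).
mask = (ids < 0).astype("f")
d_pad = (mask.T @ dY.reshape(nB, nO * nP)).reshape(nF, nO, nP)

assert numpy.allclose(d_pad, d_pad_loop)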
def _infer_nO(Y: Optional[Tuple[List[State], List[Floats2d]]]) -> Optional[int]:
if Y is None:
return None
_, scores = Y
if len(scores) == 0:
return None
assert scores[0].shape[0] >= 1
assert len(scores[0].shape) == 2
return scores[0].shape[1]
def _lsuv_init(model):
"""This is like the 'layer sequential unit variance', but instead
of taking the actual inputs, we randomly generate whitened data.
Why's this all so complicated? We have a huge number of inputs,
and the maxout unit makes guessing the dynamics tricky. Instead
we set the maxout weights to values that empirically result in
whitened outputs given whitened inputs.
"""
# TODO
return None
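
Since the body is still a TODO, here is one possible shape such an initializer could take, purely as a sketch of the idea in the docstring (generate whitened random inputs, then rescale the maxout weights until the outputs are roughly unit-variance). The function name and shapes are hypothetical, not the intended implementation:

import numpy

def _lsuv_scale_maxout(W, b, nP, n_samples=1024, tol=0.05, max_iter=10):
    # W: maxout weights with trailing input dim nI, e.g. (nH, nP, nI); b: matching bias.
    nI = W.shape[-1]
    X = numpy.random.normal(size=(n_samples, nI))         # whitened random inputs
    for _ in range(max_iter):
        pre = X @ W.reshape(-1, nI).T + b.ravel()         # pre-activations
        out = pre.reshape(n_samples, -1, nP).max(axis=-1)  # maxout over the pieces
        var = out.var()
        if abs(var - 1.0) < tol:
            break
        W /= numpy.sqrt(var)                              # push output variance toward 1
    return W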

View File

@@ -208,70 +208,11 @@ cdef class Parser(TrainablePipe):
# if labels are missing. We therefore have to check whether we need to
# expand our model output.
self._resize()
model = self.model.predict(docs)
batch = self.moves.init_batch(docs)
states = self._predict_states(model, batch)
model.clear_memory()
del model
states, scores = self.model.predict((docs, self.moves))
return states
def _predict_states(self, model, batch):
cdef vector[StateC*] states
cdef StateClass state
weights = get_c_weights(model)
for state in batch:
if not state.is_final():
states.push_back(state.c)
sizes = get_c_sizes(model, states.size())
with nogil:
self._parseC(&states[0],
weights, sizes)
return batch
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
cdef Beam beam
cdef Doc doc
batch = _beam_utils.BeamBatch(
self.moves,
self.moves.init_batch(docs),
None,
beam_width,
density=beam_density
)
# This is pretty dirty, but the NER can resize itself in init_batch,
# if labels are missing. We therefore have to check whether we need to
# expand our model output.
self._resize()
model = self.model.predict(docs)
while not batch.is_done:
states = batch.get_unfinished_states()
if not states:
break
scores = model.predict(states)
batch.advance(scores)
model.clear_memory()
del model
return list(batch)
cdef void _parseC(self, StateC** states,
WeightsC weights, SizesC sizes) nogil:
cdef int i, j
cdef vector[StateC*] unfinished
cdef ActivationsC activations = alloc_activations(sizes)
while sizes.states >= 1:
predict_states(&activations,
states, &weights, sizes)
# Validate actions, argmax, take action.
self.c_transition_batch(states,
activations.scores, sizes.classes, sizes.states)
for i in range(sizes.states):
if not states[i].is_final():
unfinished.push_back(states[i])
for i in range(unfinished.size()):
states[i] = unfinished[i]
sizes.states = unfinished.size()
unfinished.clear()
free_activations(&activations)
raise NotImplementedError
def set_annotations(self, docs, states_or_beams):
cdef StateClass state
@@ -283,36 +224,6 @@ cdef class Parser(TrainablePipe):
for hook in self.postprocesses:
hook(doc)
def transition_states(self, states, float[:, ::1] scores):
cdef StateClass state
cdef float* c_scores = &scores[0, 0]
cdef vector[StateC*] c_states
for state in states:
c_states.push_back(state.c)
self.c_transition_batch(&c_states[0], c_scores, scores.shape[1], scores.shape[0])
return [state for state in states if not state.c.is_final()]
cdef void c_transition_batch(self, StateC** states, const float* scores,
int nr_class, int batch_size) nogil:
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
with gil:
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
is_valid = <int*>calloc(self.moves.n_moves, sizeof(int))
cdef int i, guess
cdef Transition action
for i in range(batch_size):
self.moves.set_valid(is_valid, states[i])
guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class)
if guess == -1:
# This shouldn't happen, but it's hard to raise an error here,
# and we don't want to infinite loop. So, force to end state.
states[i].force_final()
else:
action = self.moves.c[guess]
action.do(states[i], action.label)
states[i].history.push_back(guess)
free(is_valid)
def update(self, examples, *, drop=0., sgd=None, losses=None):
cdef StateClass state
if losses is None:
@@ -327,58 +238,48 @@ cdef class Parser(TrainablePipe):
if n_examples == 0:
return losses
set_dropout_rate(self.model, drop)
# The probability we use beam update, instead of falling back to
# a greedy update
beam_update_prob = self.cfg["beam_update_prob"]
if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob:
return self.update_beam(
examples,
beam_width=self.cfg["beam_width"],
sgd=sgd,
losses=losses,
beam_density=self.cfg["beam_density"]
)
docs = [eg.x for eg in examples]
model, backprop_tok2vec = self.model.begin_update(docs)
states = self.moves.init_batch(docs)
self._predict_states(states)
# I've separated the prediction from getting the batch because
# I like the idea of trying to store the histories or maybe compute
# them in another process or something. Just walking the states
# and transitioning isn't expensive anyway.
ids, costs = self._get_ids_and_costs_from_histories(
examples,
[list(state.history) for state in states]
)
scores, backprop_states = model.begin_update(ids)
d_scores = self.get_loss(scores, costs)
d_tokvecs = backprop_states(d_scores)
backprop_tok2vec(d_tokvecs)
(states, scores), backprop_scores = self.model.begin_update((docs, self.moves))
d_scores = self.get_loss((states, scores), examples)
backprop_scores(d_scores)
if sgd not in (None, False):
self.finish_update(sgd)
self.set_annotations(docs, states)
losses[self.name] += (d_scores**2).sum()
# Ugh, this is annoying. If we're working on GPU, we want to free the
# memory ASAP. It seems that Python doesn't necessarily get around to
# removing these in time if we don't explicitly delete? It's confusing.
del backprop_states
del backprop_tok2vec
model.clear_memory()
del model
del backprop_scores
return losses
def _get_ids_and_costs_from_histories(self, examples, histories):
def get_loss(self, states_scores, examples):
states, scores = states_scores
costs = self._get_costs_from_histories(
examples,
[list(state.history) for state in states]
)
xp = get_array_module(scores)
best_costs = costs.min(axis=1, keepdims=True)
is_gold = costs <= costs.min(axis=1, keepdims=True)
gscores = scores[is_gold]
max_ = scores.max(axis=1)
gmax = gscores.max(axis=1, keepdims=True)
exp_scores = xp.exp(scores - max_)
exp_gscores = xp.exp(gscores - gmax)
Z = exp_scores.sum(axis=1, keepdims=True)
gZ = exp_gscores.sum(axis=1, keepdims=True)
d_scores = exp_scores / Z
d_scores[is_gold] -= exp_gscores / gZ
return d_scores
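
For reference, the masking above corresponds to the gradient of the negative log of the total probability assigned to the minimal-cost ("gold") actions under a softmax. Writing G for a state's gold actions and s for its scores:

    dL/ds_j = exp(s_j) / sum_k exp(s_k)  -  [j in G] * exp(s_j) / sum_{g in G} exp(s_g)

which is the exp_scores / Z term minus the exp_gscores / gZ correction on the gold entries; subtracting max_ and gmax before exponentiating is only for numerical stability.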
def _get_costs_from_histories(self, examples, histories):
cdef StateClass state
cdef int clas
cdef int nF = self.model.state2vec.nF
cdef int nO = self.moves.n_moves
cdef int nS = sum([len(history) for history in histories])
# ids and costs have one row per state in the whole batch.
cdef np.ndarray ids = numpy.zeros((nS, nF), dtype="i")
cdef np.ndarray costs = numpy.zeros((nS, nO), dtype="f")
cdef Pool mem = Pool()
is_valid = <int*>mem.alloc(nO, sizeof(int))
c_ids = <int*>ids.data
c_costs = <float*>costs.data
states = self.moves.init_states([eg.x for eg in examples])
cdef int i = 0
@@ -394,92 +295,15 @@ cdef class Parser(TrainablePipe):
i += 1
# If the model is on GPU, copy the costs to device.
costs = self.model.ops.asarray(costs)
return ids, costs
def get_loss(self, scores, costs):
xp = get_array_module(scores)
best_costs = costs.min(axis=1, keepdims=True)
is_gold = costs <= costs.min(axis=1, keepdims=True)
gscores = scores[is_gold]
max_ = scores.max(axis=1)
gmax = gscores.max(axis=1, keepdims=True)
exp_scores = xp.exp(scores - max_)
exp_gscores = xp.exp(gscores - gmax)
Z = exp_scores.sum(axis=1, keepdims=True)
gZ = exp_gscores.sum(axis=1, keepdims=True)
d_scores = exp_scores / Z
d_scores[is_gold] -= exp_gscores / gZ
return d_scores
return costs
def rehearse(self, examples, sgd=None, losses=None, **cfg):
"""Perform a "rehearsal" update, to prevent catastrophic forgetting."""
if losses is None:
losses = {}
for multitask in self._multitasks:
if hasattr(multitask, 'rehearse'):
multitask.rehearse(examples, losses=losses, sgd=sgd)
if self._rehearsal_model is None:
return None
losses.setdefault(self.name, 0.)
validate_examples(examples, "Parser.rehearse")
docs = [eg.predicted for eg in examples]
states = self.moves.init_batch(docs)
# This is pretty dirty, but the NER can resize itself in init_batch,
# if labels are missing. We therefore have to check whether we need to
# expand our model output.
self._resize()
# Prepare the stepwise model, and get the callback for finishing the batch
set_dropout_rate(self._rehearsal_model, 0.0)
set_dropout_rate(self.model, 0.0)
tutor, _ = self._rehearsal_model.begin_update(docs)
model, backprop_tok2vec = self.model.begin_update(docs)
n_scores = 0.
loss = 0.
while states:
targets, _ = tutor.begin_update(states)
guesses, backprop = model.begin_update(states)
d_scores = (guesses - targets) / targets.shape[0]
# If all weights for an output are 0 in the original model, don't
# supervise that output. This allows us to add classes.
loss += (d_scores**2).sum()
backprop(d_scores)
# Follow the predicted action
self.transition_states(states, guesses)
states = [state for state in states if not state.is_final()]
n_scores += d_scores.size
# Do the backprop
backprop_tok2vec(docs)
if sgd is not None:
self.finish_update(sgd)
losses[self.name] += loss / n_scores
del backprop
del backprop_tok2vec
model.clear_memory()
tutor.clear_memory()
del model
del tutor
return losses
raise NotImplementedError
def update_beam(self, examples, *, beam_width,
drop=0., sgd=None, losses=None, beam_density=0.0):
states, golds, _ = self.moves.init_gold_batch(examples)
if not states:
return losses
# Prepare the stepwise model, and get the callback for finishing the batch
model, backprop_tok2vec = self.model.begin_update(
[eg.predicted for eg in examples])
loss = _beam_utils.update_beam(
self.moves,
states,
golds,
model,
beam_width,
beam_density=beam_density,
)
losses[self.name] += loss
backprop_tok2vec(golds)
if sgd is not None:
self.finish_update(sgd)
raise NotImplementedError
def set_output(self, nO):
self.model.attrs["resize_output"](self.model, nO)