mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 20:30:24 +03:00
Reimplement parser rehearsal function (#10878)
* Reimplement parser rehearsal function Before the parser refactor, rehearsal was driven by a loop in the `rehearse` method itself. For each parsing step, the loops would: 1. Get the predictions of the teacher. 2. Get the predictions and backprop function of the student. 3. Compute the loss and backprop into the student. 4. Move the teacher and student forward with the predictions of the student. In the refactored parser, we cannot perform search stepwise rehearsal anymore, since the model now predicts all parsing steps at once. Therefore, rehearsal is performed in the following steps: 1. Get the predictions of all parsing steps from the student, along with its backprop function. 2. Get the predictions from the teacher, but use the predictions of the student to advance the parser while doing so. 3. Compute the loss and backprop into the student. To support the second step a new method, `advance_with_actions` is added to `GreedyBatch`, which performs the provided parsing steps. * tb_framework: wrap upper_W and upper_b in Linear Thinc's Optimizer cannot handle resizing of existing parameters. Until it does, we work around this by wrapping the weights/biases of the upper layer of the parser model in Linear. When the upper layer is resized, we copy over the existing parameters into a new Linear instance. This does not trigger an error in Optimizer, because it sees the resized layer as a new set of parameters. * Add test for TransitionSystem.apply_actions * Better FIXME marker Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Fixes from Madeesh * Apply suggestions from Sofie Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Remove useless assignment Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
aad38972cb
commit
63e90dd6a1
|
@ -1,5 +1,5 @@
|
|||
# cython: infer_types=True, cdivision=True, boundscheck=False
|
||||
from typing import List, Tuple, Any, Optional, cast
|
||||
from typing import List, Tuple, Any, Optional, TypeVar, cast
|
||||
from libc.string cimport memset, memcpy
|
||||
from libc.stdlib cimport calloc, free, realloc
|
||||
from libcpp.vector cimport vector
|
||||
|
@ -10,12 +10,14 @@ from thinc.api import uniform_init, glorot_uniform_init, zero_init
|
|||
from thinc.api import NumpyOps
|
||||
from thinc.backends.linalg cimport Vec, VecVec
|
||||
from thinc.backends.cblas cimport CBlas
|
||||
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
|
||||
from thinc.types import Floats1d, Floats2d, Floats3d, Floats4d
|
||||
from thinc.types import Ints1d, Ints2d
|
||||
|
||||
from ..errors import Errors
|
||||
from ..pipeline._parser_internals import _beam_utils
|
||||
from ..pipeline._parser_internals.batch import GreedyBatch
|
||||
from ..pipeline._parser_internals.transition_system cimport c_transition_batch, TransitionSystem
|
||||
from ..pipeline._parser_internals.transition_system cimport c_transition_batch, c_apply_actions
|
||||
from ..pipeline._parser_internals.transition_system cimport TransitionSystem
|
||||
from ..pipeline._parser_internals.stateclass cimport StateC, StateClass
|
||||
from ..tokens.doc import Doc
|
||||
from ..util import registry
|
||||
|
@ -43,18 +45,28 @@ def TransitionModel(
|
|||
tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width)) # type: ignore
|
||||
tok2vec_projected.set_dim("nO", hidden_width)
|
||||
|
||||
# FIXME: we use `upper` as a container for the upper layer's
|
||||
# weights and biases. Thinc optimizers cannot handle resizing
|
||||
# of parameters. So, when the parser model is resized, we
|
||||
# construct a new `upper` layer, which has a different key in
|
||||
# the optimizer. Once the optimizer supports parameter resizing,
|
||||
# we can replace the `upper` layer by `upper_W` and `upper_b`
|
||||
# parameters in this model.
|
||||
upper = Linear(nO=None, nI=hidden_width, init_W=zero_init)
|
||||
|
||||
return Model(
|
||||
name="parser_model",
|
||||
forward=forward,
|
||||
init=init,
|
||||
layers=[tok2vec_projected],
|
||||
refs={"tok2vec": tok2vec_projected},
|
||||
layers=[tok2vec_projected, upper],
|
||||
refs={
|
||||
"tok2vec": tok2vec_projected,
|
||||
"upper": upper,
|
||||
},
|
||||
params={
|
||||
"lower_W": None, # Floats2d W for the hidden layer
|
||||
"lower_b": None, # Floats1d bias for the hidden layer
|
||||
"lower_pad": None, # Floats1d padding for the hidden layer
|
||||
"upper_W": None, # Floats2d W for the output layer
|
||||
"upper_b": None, # Floats1d bias for the output layer
|
||||
},
|
||||
dims={
|
||||
"nO": None, # Output size
|
||||
|
@ -74,29 +86,30 @@ def TransitionModel(
|
|||
|
||||
def resize_output(model: Model, new_nO: int) -> Model:
|
||||
old_nO = model.maybe_get_dim("nO")
|
||||
upper = model.get_ref("upper")
|
||||
if old_nO is None:
|
||||
model.set_dim("nO", new_nO)
|
||||
upper.set_dim("nO", new_nO)
|
||||
upper.initialize()
|
||||
return model
|
||||
elif new_nO <= old_nO:
|
||||
return model
|
||||
elif model.has_param("upper_W"):
|
||||
elif upper.has_param("W"):
|
||||
nH = model.get_dim("nH")
|
||||
new_W = model.ops.alloc2f(new_nO, nH)
|
||||
new_b = model.ops.alloc1f(new_nO)
|
||||
old_W = model.get_param("upper_W")
|
||||
old_b = model.get_param("upper_b")
|
||||
new_upper = Linear(nO=new_nO, nI=nH, init_W=zero_init)
|
||||
new_upper.initialize()
|
||||
new_W = new_upper.get_param("W")
|
||||
new_b = new_upper.get_param("b")
|
||||
old_W = upper.get_param("W")
|
||||
old_b = upper.get_param("b")
|
||||
new_W[:old_nO] = old_W # type: ignore
|
||||
new_b[:old_nO] = old_b # type: ignore
|
||||
for i in range(old_nO, new_nO):
|
||||
model.attrs["unseen_classes"].add(i)
|
||||
model.set_param("upper_W", new_W)
|
||||
model.set_param("upper_b", new_b)
|
||||
model.layers[-1] = new_upper
|
||||
model.set_ref("upper", new_upper)
|
||||
# TODO: Avoid this private intrusion
|
||||
model._dims["nO"] = new_nO
|
||||
if model.has_grad("upper_W"):
|
||||
model.set_grad("upper_W", model.get_param("upper_W") * 0)
|
||||
if model.has_grad("upper_b"):
|
||||
model.set_grad("upper_b", model.get_param("upper_b") * 0)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -113,9 +126,7 @@ def init(
|
|||
inferred_nO = _infer_nO(Y)
|
||||
if inferred_nO is not None:
|
||||
current_nO = model.maybe_get_dim("nO")
|
||||
if current_nO is None:
|
||||
model.set_dim("nO", inferred_nO)
|
||||
elif current_nO != inferred_nO:
|
||||
if current_nO is None or current_nO != inferred_nO:
|
||||
model.attrs["resize_output"](model, inferred_nO)
|
||||
nO = model.get_dim("nO")
|
||||
nP = model.get_dim("nP")
|
||||
|
@ -127,9 +138,6 @@ def init(
|
|||
Wl = ops.alloc2f(nH * nP, nF * nI)
|
||||
bl = ops.alloc1f(nH * nP)
|
||||
padl = ops.alloc1f(nI)
|
||||
Wu = ops.alloc2f(nO, nH)
|
||||
bu = ops.alloc1f(nO)
|
||||
Wu = zero_init(ops, Wu.shape)
|
||||
# Wl = zero_init(ops, Wl.shape)
|
||||
Wl = glorot_uniform_init(ops, Wl.shape)
|
||||
padl = uniform_init(ops, padl.shape) # type: ignore
|
||||
|
@ -137,30 +145,39 @@ def init(
|
|||
model.set_param("lower_W", Wl)
|
||||
model.set_param("lower_b", bl)
|
||||
model.set_param("lower_pad", padl)
|
||||
model.set_param("upper_W", Wu)
|
||||
model.set_param("upper_b", bu)
|
||||
# model = _lsuv_init(model)
|
||||
return model
|
||||
|
||||
def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
|
||||
InWithoutActions = Tuple[List[Doc], TransitionSystem]
|
||||
InWithActions = Tuple[List[Doc], TransitionSystem, List[Ints1d]]
|
||||
InT = TypeVar("InT", InWithoutActions, InWithActions)
|
||||
|
||||
def forward(model, docs_moves: InT, is_train: bool):
|
||||
if len(docs_moves) == 2:
|
||||
docs, moves = docs_moves
|
||||
actions = None
|
||||
else:
|
||||
docs, moves, actions = docs_moves
|
||||
|
||||
beam_width = model.attrs["beam_width"]
|
||||
lower_pad = model.get_param("lower_pad")
|
||||
tok2vec = model.get_ref("tok2vec")
|
||||
|
||||
docs, moves = docs_moves
|
||||
states = moves.init_batch(docs)
|
||||
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
|
||||
tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
|
||||
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
|
||||
seen_mask = _get_seen_mask(model)
|
||||
|
||||
# Fixme: support actions in forward_cpu
|
||||
if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
|
||||
return _forward_greedy_cpu(model, moves, states, feats, seen_mask)
|
||||
return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions)
|
||||
else:
|
||||
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train)
|
||||
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train, actions=actions)
|
||||
|
||||
|
||||
def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats,
|
||||
np.ndarray[np.npy_bool, ndim=1] seen_mask):
|
||||
np.ndarray[np.npy_bool, ndim=1] seen_mask, actions: Optional[List[Ints1d]]=None):
|
||||
cdef vector[StateC*] c_states
|
||||
cdef StateClass state
|
||||
for state in states:
|
||||
|
@ -171,7 +188,7 @@ def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[State
|
|||
cdef int n_tokens = feats.shape[0] - 1
|
||||
sizes = get_c_sizes(model, c_states.size(), n_tokens)
|
||||
cdef CBlas cblas = model.ops.cblas()
|
||||
scores = _parseC(cblas, moves, &c_states[0], weights, sizes)
|
||||
scores = _parseC(cblas, moves, &c_states[0], weights, sizes, actions=actions)
|
||||
|
||||
def backprop(dY):
|
||||
raise ValueError(Errors.E1038)
|
||||
|
@ -179,20 +196,25 @@ def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[State
|
|||
return (states, scores), backprop
|
||||
|
||||
cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
|
||||
WeightsC weights, SizesC sizes):
|
||||
WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None):
|
||||
cdef int i, j
|
||||
cdef vector[StateC *] unfinished
|
||||
cdef ActivationsC activations = alloc_activations(sizes)
|
||||
cdef np.ndarray step_scores
|
||||
cdef np.ndarray step_actions
|
||||
|
||||
scores = []
|
||||
while sizes.states >= 1:
|
||||
step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
|
||||
step_actions = actions[0] if actions is not None else None
|
||||
with nogil:
|
||||
predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
|
||||
if actions is None:
|
||||
# Validate actions, argmax, take action.
|
||||
c_transition_batch(moves, states, <const float*>step_scores.data, sizes.classes,
|
||||
sizes.states)
|
||||
else:
|
||||
c_apply_actions(moves, states, <const int*>step_actions.data, sizes.states)
|
||||
for i in range(sizes.states):
|
||||
if not states[i].is_final():
|
||||
unfinished.push_back(states[i])
|
||||
|
@ -201,15 +223,17 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
|
|||
sizes.states = unfinished.size()
|
||||
scores.append(step_scores)
|
||||
unfinished.clear()
|
||||
actions = actions[1:] if actions is not None else None
|
||||
free_activations(&activations)
|
||||
|
||||
return scores
|
||||
|
||||
def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool):
|
||||
|
||||
def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool,
|
||||
actions: Optional[List[Ints1d]]=None):
|
||||
nF = model.get_dim("nF")
|
||||
upper = model.get_ref("upper")
|
||||
lower_b = model.get_param("lower_b")
|
||||
upper_W = model.get_param("upper_W")
|
||||
upper_b = model.get_param("upper_b")
|
||||
nH = model.get_dim("nH")
|
||||
nP = model.get_dim("nP")
|
||||
|
||||
|
@ -240,14 +264,17 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
|
|||
preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
|
||||
assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
|
||||
statevecs, which = ops.maxout(preacts)
|
||||
# Multiply the state-vector by the scores weights and add the bias,
|
||||
# to get the logits.
|
||||
scores = ops.gemm(statevecs, upper_W, trans2=True)
|
||||
scores += upper_b
|
||||
# We don't use upper's backprop, since we want to backprop for
|
||||
# all states at once, rather than a single state.
|
||||
scores = upper.predict(statevecs)
|
||||
scores[:, seen_mask] = ops.xp.nanmin(scores)
|
||||
# Transition the states, filtering out any that are finished.
|
||||
cpu_scores = ops.to_numpy(scores)
|
||||
if actions is None:
|
||||
batch.advance(cpu_scores)
|
||||
else:
|
||||
batch.advance_with_actions(actions[0])
|
||||
actions = actions[1:]
|
||||
all_scores.append(scores)
|
||||
if is_train:
|
||||
# Remember intermediate results for the backprop.
|
||||
|
@ -270,10 +297,11 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
|
|||
d_scores *= seen_mask == False
|
||||
# Calculate the gradients for the parameters of the upper layer.
|
||||
# The weight gemm is (nS, nO) @ (nS, nH).T
|
||||
model.inc_grad("upper_b", d_scores.sum(axis=0))
|
||||
model.inc_grad("upper_W", ops.gemm(d_scores, statevecs, trans1=True))
|
||||
upper.inc_grad("b", d_scores.sum(axis=0))
|
||||
upper.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True))
|
||||
# Now calculate d_statevecs, by backproping through the upper linear layer.
|
||||
# This gemm is (nS, nO) @ (nO, nH)
|
||||
upper_W = upper.get_param("W")
|
||||
d_statevecs = ops.gemm(d_scores, upper_W)
|
||||
# Backprop through the maxout activation
|
||||
d_preacts = ops.backprop_maxout(d_statevecs, which, nP)
|
||||
|
@ -295,11 +323,10 @@ def _forward_reference(
|
|||
"""Slow reference implementation, without the precomputation"""
|
||||
nF = model.get_dim("nF")
|
||||
tok2vec = model.get_ref("tok2vec")
|
||||
upper = model.get_ref("upper")
|
||||
lower_pad = model.get_param("lower_pad")
|
||||
lower_W = model.get_param("lower_W")
|
||||
lower_b = model.get_param("lower_b")
|
||||
upper_W = model.get_param("upper_W")
|
||||
upper_b = model.get_param("upper_b")
|
||||
nH = model.get_dim("nH")
|
||||
nP = model.get_dim("nP")
|
||||
nO = model.get_dim("nO")
|
||||
|
@ -330,10 +357,9 @@ def _forward_reference(
|
|||
preacts2f += lower_b
|
||||
preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
|
||||
statevecs, which = ops.maxout(preacts)
|
||||
# Multiply the state-vector by the scores weights and add the bias,
|
||||
# to get the logits.
|
||||
scores = model.ops.gemm(statevecs, upper_W, trans2=True)
|
||||
scores += upper_b
|
||||
# We don't use upper's backprop, since we want to backprop for
|
||||
# all states at once, rather than a single state.
|
||||
scores = upper.predict(statevecs)
|
||||
scores[:, seen_mask] = model.ops.xp.nanmin(scores)
|
||||
# Transition the states, filtering out any that are finished.
|
||||
next_states = moves.transition_states(next_states, scores)
|
||||
|
@ -366,10 +392,11 @@ def _forward_reference(
|
|||
assert d_scores.shape == (nS, nO), d_scores.shape
|
||||
# Calculate the gradients for the parameters of the upper layer.
|
||||
# The weight gemm is (nS, nO) @ (nS, nH).T
|
||||
model.inc_grad("upper_b", d_scores.sum(axis=0))
|
||||
model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
|
||||
upper.inc_grad("b", d_scores.sum(axis=0))
|
||||
upper.inc_grad("W", model.ops.gemm(d_scores, statevecs, trans1=True))
|
||||
# Now calculate d_statevecs, by backproping through the upper linear layer.
|
||||
# This gemm is (nS, nO) @ (nO, nH)
|
||||
upper_W = upper.get_param("W")
|
||||
d_statevecs = model.ops.gemm(d_scores, upper_W)
|
||||
# Backprop through the maxout activation
|
||||
d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP)
|
||||
|
@ -514,9 +541,10 @@ def _lsuv_init(model: Model):
|
|||
|
||||
|
||||
cdef WeightsC get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
|
||||
upper = model.get_ref("upper")
|
||||
cdef np.ndarray lower_b = model.get_param("lower_b")
|
||||
cdef np.ndarray upper_W = model.get_param("upper_W")
|
||||
cdef np.ndarray upper_b = model.get_param("upper_b")
|
||||
cdef np.ndarray upper_W = upper.get_param("W")
|
||||
cdef np.ndarray upper_b = upper.get_param("b")
|
||||
|
||||
cdef WeightsC output
|
||||
output.feat_weights = feats
|
||||
|
|
|
@ -32,6 +32,9 @@ class GreedyBatch(Batch):
|
|||
def advance(self, scores):
|
||||
self._next_states = self._moves.transition_states(self._next_states, scores)
|
||||
|
||||
def advance_with_actions(self, actions):
|
||||
self._next_states = self._moves.apply_transitions(self._next_states, actions)
|
||||
|
||||
def get_states(self):
|
||||
return self._states
|
||||
|
||||
|
|
|
@ -54,5 +54,9 @@ cdef class TransitionSystem:
|
|||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||
const StateC* state, gold) except -1
|
||||
|
||||
|
||||
cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions,
|
||||
int batch_size) nogil
|
||||
|
||||
cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
|
||||
int nr_class, int batch_size) nogil
|
||||
|
|
|
@ -155,6 +155,17 @@ cdef class TransitionSystem:
|
|||
action.do(state.c, action.label)
|
||||
state.c.history.push_back(action.clas)
|
||||
|
||||
def apply_actions(self, states, const int[::1] actions):
|
||||
assert len(states) == actions.shape[0]
|
||||
cdef StateClass state
|
||||
cdef vector[StateC*] c_states
|
||||
c_states.resize(len(states))
|
||||
cdef int i
|
||||
for (i, state) in enumerate(states):
|
||||
c_states[i] = state.c
|
||||
c_apply_actions(self, &c_states[0], &actions[0], actions.shape[0])
|
||||
return [state for state in states if not state.c.is_final()]
|
||||
|
||||
def transition_states(self, states, float[:, ::1] scores):
|
||||
assert len(states) == scores.shape[0]
|
||||
cdef StateClass state
|
||||
|
@ -279,6 +290,18 @@ cdef class TransitionSystem:
|
|||
return self
|
||||
|
||||
|
||||
cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions,
|
||||
int batch_size) nogil:
|
||||
cdef int i
|
||||
cdef Transition action
|
||||
cdef StateC* state
|
||||
for i in range(batch_size):
|
||||
state = states[i]
|
||||
action = moves.c[actions[i]]
|
||||
action.do(state, action.label)
|
||||
state.history.push_back(action.clas)
|
||||
|
||||
|
||||
cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
|
||||
int nr_class, int batch_size) nogil:
|
||||
is_valid = <int*>calloc(moves.n_moves, sizeof(int))
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
|
||||
from __future__ import print_function
|
||||
from typing import List
|
||||
from cymem.cymem cimport Pool
|
||||
cimport numpy as np
|
||||
from itertools import islice
|
||||
|
@ -12,11 +13,12 @@ import contextlib
|
|||
import srsly
|
||||
from thinc.api import set_dropout_rate, CupyOps, get_array_module
|
||||
from thinc.extra.search cimport Beam
|
||||
from thinc.types import Ints1d
|
||||
import numpy.random
|
||||
import numpy
|
||||
import warnings
|
||||
|
||||
from ._parser_internals.stateclass cimport StateClass
|
||||
from ._parser_internals.stateclass cimport StateC, StateClass
|
||||
from ..tokens.doc cimport Doc
|
||||
from .trainable_pipe import TrainablePipe
|
||||
from ._parser_internals cimport _beam_utils
|
||||
|
@ -359,7 +361,42 @@ class Parser(TrainablePipe):
|
|||
|
||||
def rehearse(self, examples, sgd=None, losses=None, **cfg):
|
||||
"""Perform a "rehearsal" update, to prevent catastrophic forgetting."""
|
||||
raise NotImplementedError
|
||||
if losses is None:
|
||||
losses = {}
|
||||
for multitask in self._multitasks:
|
||||
if hasattr(multitask, 'rehearse'):
|
||||
multitask.rehearse(examples, losses=losses, sgd=sgd)
|
||||
if self._rehearsal_model is None:
|
||||
return None
|
||||
losses.setdefault(self.name, 0.0)
|
||||
validate_examples(examples, "Parser.rehearse")
|
||||
docs = [eg.predicted for eg in examples]
|
||||
# This is pretty dirty, but the NER can resize itself in init_batch,
|
||||
# if labels are missing. We therefore have to check whether we need to
|
||||
# expand our model output.
|
||||
self._resize()
|
||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||
set_dropout_rate(self._rehearsal_model, 0.0)
|
||||
set_dropout_rate(self.model, 0.0)
|
||||
(student_states, student_scores), backprop_scores = self.model.begin_update((docs, self.moves))
|
||||
actions = states2actions(student_states)
|
||||
_, teacher_scores = self._rehearsal_model.predict((docs, self.moves, actions))
|
||||
|
||||
teacher_scores = self.model.ops.xp.vstack(teacher_scores)
|
||||
student_scores = self.model.ops.xp.vstack(student_scores)
|
||||
assert teacher_scores.shape == student_scores.shape
|
||||
|
||||
d_scores = (student_scores - teacher_scores) / teacher_scores.shape[0]
|
||||
# If all weights for an output are 0 in the original model, don't
|
||||
# supervise that output. This allows us to add classes.
|
||||
loss = (d_scores**2).sum() / d_scores.size
|
||||
backprop_scores((student_states, d_scores))
|
||||
|
||||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
losses[self.name] += loss
|
||||
|
||||
return losses
|
||||
|
||||
def update_beam(self, examples, *, beam_width,
|
||||
drop=0., sgd=None, losses=None, beam_density=0.0):
|
||||
|
@ -474,3 +511,26 @@ def _change_attrs(model, **kwargs):
|
|||
model.attrs.pop(key)
|
||||
else:
|
||||
model.attrs[key] = value
|
||||
|
||||
|
||||
def states2actions(states: List[StateClass]) -> List[Ints1d]:
|
||||
cdef int step
|
||||
cdef StateClass state
|
||||
cdef StateC* c_state
|
||||
actions = []
|
||||
while True:
|
||||
step = len(actions)
|
||||
|
||||
step_actions = []
|
||||
for state in states:
|
||||
c_state = state.c
|
||||
if step < c_state.history.size():
|
||||
step_actions.append(c_state.history[step])
|
||||
|
||||
# We are done if we have exhausted all histories.
|
||||
if len(step_actions) == 0:
|
||||
break
|
||||
|
||||
actions.append(numpy.array(step_actions, dtype="i"))
|
||||
|
||||
return actions
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
import numpy
|
||||
from numpy.testing import assert_equal
|
||||
from thinc.api import Adam
|
||||
|
||||
|
@ -175,6 +176,57 @@ def test_parser_parse_one_word_sentence(en_vocab, en_parser, words):
|
|||
assert doc[0].dep != 0
|
||||
|
||||
|
||||
def test_parser_apply_actions(en_vocab, en_parser):
|
||||
words = ["I", "ate", "pizza"]
|
||||
words2 = ["Eat", "more", "pizza", "!"]
|
||||
doc1 = Doc(en_vocab, words=words)
|
||||
doc2 = Doc(en_vocab, words=words2)
|
||||
docs = [doc1, doc2]
|
||||
|
||||
moves = en_parser.moves
|
||||
moves.add_action(0, "")
|
||||
moves.add_action(1, "")
|
||||
moves.add_action(2, "nsubj")
|
||||
moves.add_action(3, "obj")
|
||||
moves.add_action(2, "amod")
|
||||
|
||||
actions = [
|
||||
numpy.array([0, 0], dtype="i"),
|
||||
numpy.array([2, 0], dtype="i"),
|
||||
numpy.array([0, 4], dtype="i"),
|
||||
numpy.array([3, 3], dtype="i"),
|
||||
numpy.array([1, 1], dtype="i"),
|
||||
numpy.array([1, 1], dtype="i"),
|
||||
numpy.array([0], dtype="i"),
|
||||
numpy.array([1], dtype="i"),
|
||||
]
|
||||
|
||||
states = moves.init_batch(docs)
|
||||
active_states = states
|
||||
|
||||
for step_actions in actions:
|
||||
active_states = moves.apply_actions(active_states, step_actions)
|
||||
|
||||
assert len(active_states) == 0
|
||||
|
||||
for (state, doc) in zip(states, docs):
|
||||
moves.set_annotations(state, doc)
|
||||
|
||||
assert docs[0][0].head.i == 1
|
||||
assert docs[0][0].dep_ == "nsubj"
|
||||
assert docs[0][1].head.i == 1
|
||||
assert docs[0][1].dep_ == "ROOT"
|
||||
assert docs[0][2].head.i == 1
|
||||
assert docs[0][2].dep_ == "obj"
|
||||
|
||||
assert docs[1][0].head.i == 0
|
||||
assert docs[1][0].dep_ == "ROOT"
|
||||
assert docs[1][1].head.i == 2
|
||||
assert docs[1][1].dep_ == "amod"
|
||||
assert docs[1][2].head.i == 0
|
||||
assert docs[1][2].dep_ == "obj"
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="The step_through API was removed (but should be brought back)"
|
||||
)
|
||||
|
|
|
@ -264,9 +264,11 @@ def test_serialize_custom_nlp():
|
|||
model = nlp2.get_pipe("parser").model
|
||||
assert model.get_ref("tok2vec") is not None
|
||||
assert model.has_param("lower_W")
|
||||
assert model.has_param("upper_W")
|
||||
assert model.has_param("lower_b")
|
||||
assert model.has_param("upper_b")
|
||||
upper = model.get_ref("upper")
|
||||
assert upper is not None
|
||||
assert upper.has_param("W")
|
||||
assert upper.has_param("b")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("parser_config_string", [parser_config_string_upper])
|
||||
|
@ -284,9 +286,11 @@ def test_serialize_parser(parser_config_string):
|
|||
model = nlp2.get_pipe("parser").model
|
||||
assert model.get_ref("tok2vec") is not None
|
||||
assert model.has_param("lower_W")
|
||||
assert model.has_param("upper_W")
|
||||
assert model.has_param("lower_b")
|
||||
assert model.has_param("upper_b")
|
||||
upper = model.get_ref("upper")
|
||||
assert upper is not None
|
||||
assert upper.has_param("b")
|
||||
assert upper.has_param("W")
|
||||
|
||||
|
||||
def test_config_nlp_roundtrip():
|
||||
|
|
|
@ -203,8 +203,7 @@ def _optimize(nlp, component: str, data: List, rehearse: bool):
|
|||
return nlp
|
||||
|
||||
|
||||
# Fixme: reenable ner and parser when rehearsal is implemented.
|
||||
@pytest.mark.parametrize("component", ["tagger", "textcat_multilabel"])
|
||||
@pytest.mark.parametrize("component", ["ner", "tagger", "parser", "textcat_multilabel"])
|
||||
def test_rehearse(component):
|
||||
nlp = spacy.blank("en")
|
||||
nlp.add_pipe(component)
|
||||
|
|
Loading…
Reference in New Issue
Block a user