mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 04:10:20 +03:00
Update parser distillation for the refactor
In the old parser, we'd iterate over the transitions in the distill function and compute the loss/gradients on the go. In the refactored parser, we first let the student model parse the inputs. Then we'll let the teacher compute the transition probabilities of the states in the student's transition sequence. We can then compute the gradients of the student given the teacher.
This commit is contained in:
parent
e4ee6a4a2c
commit
e73a591358
|
@ -247,15 +247,6 @@ class Parser(TrainablePipe):
|
|||
|
||||
student_docs = [eg.predicted for eg in examples]
|
||||
|
||||
teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples])
|
||||
student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
|
||||
|
||||
# Add softmax activation, so that we can compute student losses
|
||||
# with cross-entropy loss.
|
||||
with use_ops("numpy"):
|
||||
teacher_model = chain(teacher_step_model, softmax_activation())
|
||||
student_model = chain(student_step_model, softmax_activation())
|
||||
|
||||
max_moves = self.cfg["update_with_oracle_cut_size"]
|
||||
if max_moves >= 1:
|
||||
# Chop sequences into lengths of this many words, to make the
|
||||
|
@ -263,51 +254,39 @@ class Parser(TrainablePipe):
|
|||
# sequence, we use the teacher's predictions as the gold
|
||||
# standard.
|
||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||
states = self._init_batch(teacher_step_model, student_docs, max_moves)
|
||||
states = self._init_batch(teacher_pipe, student_docs, max_moves)
|
||||
else:
|
||||
states = self.moves.init_batch(student_docs)
|
||||
|
||||
loss = 0.0
|
||||
n_moves = 0
|
||||
while states:
|
||||
# We do distillation as follows: (1) for every state, we compute the
|
||||
# transition softmax distributions: (2) we backpropagate the error of
|
||||
# the student (compared to the teacher) into the student model; (3)
|
||||
# for all states, we move to the next state using the student's
|
||||
# predictions.
|
||||
teacher_scores = teacher_model.predict(states)
|
||||
student_scores, backprop = student_model.begin_update(states)
|
||||
state_loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
|
||||
backprop(d_scores)
|
||||
loss += state_loss
|
||||
self.transition_states(states, student_scores)
|
||||
states = [state for state in states if not state.is_final()]
|
||||
# We distill as follows: 1. we first let the student predict transition
|
||||
# sequences (and the corresponding transition probabilities); (2) we
|
||||
# let the teacher follow the student's predicted transition sequences
|
||||
# to obtain the teacher's transition probabilities; (3) we compute the
|
||||
# gradients of the student's transition distributions relative to the
|
||||
# teacher's distributions.
|
||||
|
||||
# Stop when we reach the maximum number of moves, otherwise we start
|
||||
# to process the remainder of cut sequences again.
|
||||
if max_moves >= 1 and n_moves >= max_moves:
|
||||
break
|
||||
n_moves += 1
|
||||
student_inputs = TransitionModelInputs(docs=student_docs, moves=self.moves,
|
||||
max_moves=max_moves)
|
||||
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
|
||||
actions = states2actions(student_states)
|
||||
teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples],
|
||||
moves=self.moves, actions=actions)
|
||||
(_, teacher_scores) = teacher_pipe.model.predict(teacher_inputs)
|
||||
|
||||
backprop_tok2vec(student_docs)
|
||||
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
|
||||
backprop_scores((student_states, d_scores))
|
||||
|
||||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
|
||||
losses[self.name] += loss
|
||||
|
||||
del backprop
|
||||
del backprop_tok2vec
|
||||
teacher_step_model.clear_memory()
|
||||
student_step_model.clear_memory()
|
||||
del teacher_model
|
||||
del student_model
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
def get_teacher_student_loss(
|
||||
self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
|
||||
self, teacher_scores: List[Floats2d], student_scores: List[Floats2d],
|
||||
normalize: bool=False,
|
||||
) -> Tuple[float, List[Floats2d]]:
|
||||
"""Calculate the loss and its gradient for a batch of student
|
||||
scores, relative to teacher scores.
|
||||
|
@ -319,10 +298,28 @@ class Parser(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
|
||||
"""
|
||||
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
|
||||
d_scores, loss = loss_func(student_scores, teacher_scores)
|
||||
if self.model.ops.xp.isnan(loss):
|
||||
raise ValueError(Errors.E910.format(name=self.name))
|
||||
|
||||
# We can't easily hook up a softmax layer in the parsing model, since
|
||||
# the get_loss does additional masking. So, we could apply softmax
|
||||
# manually here and use Thinc's cross-entropy loss. But it's a bit
|
||||
# suboptimal, since we can have a lot of states that would result in
|
||||
# many kernel launches. Futhermore the parsing model's backprop expects
|
||||
# a XP array, so we'd have to concat the softmaxes anyway. So, like
|
||||
# the get_loss implementation, we'll compute the loss and gradients
|
||||
# ourselves.
|
||||
|
||||
teacher_scores = self.model.ops.softmax(self.model.ops.xp.vstack(teacher_scores),
|
||||
axis=-1, inplace=True)
|
||||
student_scores = self.model.ops.softmax(self.model.ops.xp.vstack(student_scores),
|
||||
axis=-1, inplace=True)
|
||||
|
||||
assert teacher_scores.shape == student_scores.shape
|
||||
|
||||
d_scores = student_scores - teacher_scores
|
||||
if normalize:
|
||||
d_scores /= d_scores.shape[0]
|
||||
loss = (d_scores**2).sum() / d_scores.size
|
||||
|
||||
return float(loss), d_scores
|
||||
|
||||
def init_multitask_objectives(self, get_examples, pipeline, **cfg):
|
||||
|
@ -529,6 +526,8 @@ class Parser(TrainablePipe):
|
|||
teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions)
|
||||
_, teacher_scores = self._rehearsal_model.predict(teacher_inputs)
|
||||
|
||||
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores, normalize=True)
|
||||
|
||||
teacher_scores = self.model.ops.xp.vstack(teacher_scores)
|
||||
student_scores = self.model.ops.xp.vstack(student_scores)
|
||||
assert teacher_scores.shape == student_scores.shape
|
||||
|
|
Loading…
Reference in New Issue
Block a user