Update parser distillation for the refactor

In the old parser, we iterated over the transitions in the distill
function and computed the loss and gradients on the fly. In the
refactored parser, we first let the student model parse the inputs.
Then we let the teacher compute the transition probabilities for the
states in the student's transition sequence. Finally, we compute the
student's gradients relative to the teacher's distributions.
Daniël de Kok 2023-01-16 20:01:15 +01:00
parent e4ee6a4a2c
commit e73a591358


@@ -247,15 +247,6 @@ class Parser(TrainablePipe):
         student_docs = [eg.predicted for eg in examples]
 
-        teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples])
-        student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
-
-        # Add softmax activation, so that we can compute student losses
-        # with cross-entropy loss.
-        with use_ops("numpy"):
-            teacher_model = chain(teacher_step_model, softmax_activation())
-            student_model = chain(student_step_model, softmax_activation())
-
         max_moves = self.cfg["update_with_oracle_cut_size"]
         if max_moves >= 1:
             # Chop sequences into lengths of this many words, to make the
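(The softmax chaining removed here is superseded by the explicit softmax inside
get_teacher_student_loss; see the third hunk below.)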
@@ -263,51 +254,39 @@ class Parser(TrainablePipe):
             # sequence, we use the teacher's predictions as the gold
             # standard.
             max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
-            states = self._init_batch(teacher_step_model, student_docs, max_moves)
+            states = self._init_batch(teacher_pipe, student_docs, max_moves)
         else:
             states = self.moves.init_batch(student_docs)
 
-        loss = 0.0
-        n_moves = 0
-        while states:
-            # We do distillation as follows: (1) for every state, we compute the
-            # transition softmax distributions; (2) we backpropagate the error of
-            # the student (compared to the teacher) into the student model; (3)
-            # for all states, we move to the next state using the student's
-            # predictions.
-            teacher_scores = teacher_model.predict(states)
-            student_scores, backprop = student_model.begin_update(states)
-            state_loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
-            backprop(d_scores)
-            loss += state_loss
-            self.transition_states(states, student_scores)
-            states = [state for state in states if not state.is_final()]
-
-            # Stop when we reach the maximum number of moves, otherwise we start
-            # to process the remainder of cut sequences again.
-            if max_moves >= 1 and n_moves >= max_moves:
-                break
-            n_moves += 1
-
-        backprop_tok2vec(student_docs)
+        # We distill as follows: (1) we first let the student predict transition
+        # sequences (and the corresponding transition probabilities); (2) we
+        # let the teacher follow the student's predicted transition sequences
+        # to obtain the teacher's transition probabilities; (3) we compute the
+        # gradients of the student's transition distributions relative to the
+        # teacher's distributions.
+        student_inputs = TransitionModelInputs(docs=student_docs, moves=self.moves,
+                                               max_moves=max_moves)
+        (student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
+        actions = states2actions(student_states)
+        teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples],
+                                               moves=self.moves, actions=actions)
+        (_, teacher_scores) = teacher_pipe.model.predict(teacher_inputs)
+
+        loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
+        backprop_scores((student_states, d_scores))
 
         if sgd is not None:
             self.finish_update(sgd)
 
         losses[self.name] += loss
 
-        del backprop
-        del backprop_tok2vec
-        teacher_step_model.clear_memory()
-        student_step_model.clear_memory()
-        del teacher_model
-        del student_model
-
         return losses
 
     def get_teacher_student_loss(
-        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d],
+        normalize: bool=False,
     ) -> Tuple[float, List[Floats2d]]:
         """Calculate the loss and its gradient for a batch of student
         scores, relative to teacher scores.
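Condensed, the new distill flow in the hunk above amounts to the following. This
is a minimal illustrative sketch only, not the actual spaCy code: the function
name and arguments are hypothetical, and TransitionModelInputs and
states2actions are the spaCy-internal helpers used in the diff.

def distill_sketch(student_model, teacher_model, moves, examples, max_moves):
    # (1) The student parses the inputs, yielding its transition sequences
    #     (student_states) and the per-state transition scores.
    student_docs = [eg.predicted for eg in examples]
    student_inputs = TransitionModelInputs(docs=student_docs, moves=moves,
                                           max_moves=max_moves)
    (student_states, student_scores), backprop_scores = student_model.begin_update(student_inputs)

    # (2) The teacher replays exactly the actions the student took on the
    #     reference docs, so its scores line up state-for-state with the
    #     student's.
    actions = states2actions(student_states)
    teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples],
                                           moves=moves, actions=actions)
    _, teacher_scores = teacher_model.predict(teacher_inputs)

    # (3) Gradients of the student's transition distributions relative to
    #     the teacher's, backpropagated into the student.
    loss, d_scores = get_teacher_student_loss(teacher_scores, student_scores)
    backprop_scores((student_states, d_scores))
    return loss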
@@ -319,10 +298,28 @@ class Parser(TrainablePipe):
         DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
         """
-        loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
-        d_scores, loss = loss_func(student_scores, teacher_scores)
-        if self.model.ops.xp.isnan(loss):
-            raise ValueError(Errors.E910.format(name=self.name))
+        # We can't easily hook up a softmax layer in the parsing model, since
+        # the get_loss does additional masking. So, we could apply softmax
+        # manually here and use Thinc's cross-entropy loss. But that is a bit
+        # suboptimal, since we can have a lot of states, which would result in
+        # many kernel launches. Furthermore, the parsing model's backprop
+        # expects an XP array, so we'd have to concatenate the softmaxes
+        # anyway. So, like the get_loss implementation, we compute the loss
+        # and gradients ourselves.
+        teacher_scores = self.model.ops.softmax(self.model.ops.xp.vstack(teacher_scores),
+                                                axis=-1, inplace=True)
+        student_scores = self.model.ops.softmax(self.model.ops.xp.vstack(student_scores),
+                                                axis=-1, inplace=True)
+        assert teacher_scores.shape == student_scores.shape
+
+        d_scores = student_scores - teacher_scores
+        if normalize:
+            d_scores /= d_scores.shape[0]
+        loss = (d_scores**2).sum() / d_scores.size
 
         return float(loss), d_scores
 
     def init_multitask_objectives(self, get_examples, pipeline, **cfg):
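For intuition on the gradient computed above: softmax(student) - softmax(teacher)
is exactly the gradient of the soft-target cross-entropy loss with respect to
the student's logits. A quick self-contained numpy check (illustrative only,
not part of the commit):

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
student_logits = rng.normal(size=(4, 7))  # 4 states, 7 possible transitions
teacher_logits = rng.normal(size=(4, 7))

p_teacher = softmax(teacher_logits)
d_scores = softmax(student_logits) - p_teacher  # the gradient from the diff

# Central finite differences on the soft-target cross-entropy
# L = -sum(p_teacher * log softmax(student_logits)).
eps = 1e-6
num_grad = np.zeros_like(student_logits)
for idx in np.ndindex(*student_logits.shape):
    bump = np.zeros_like(student_logits)
    bump[idx] = eps
    f_plus = -(p_teacher * np.log(softmax(student_logits + bump))).sum()
    f_minus = -(p_teacher * np.log(softmax(student_logits - bump))).sum()
    num_grad[idx] = (f_plus - f_minus) / (2 * eps)

assert np.allclose(num_grad, d_scores, atol=1e-5)

Note that the scalar loss returned by get_teacher_student_loss,
(d_scores**2).sum() / d_scores.size, is the mean squared difference between the
two distributions; it is what gets accumulated into losses, while d_scores is
what drives the update.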
@@ -529,6 +526,8 @@ class Parser(TrainablePipe):
         teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions)
         _, teacher_scores = self._rehearsal_model.predict(teacher_inputs)
 
+        loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores, normalize=True)
+
         teacher_scores = self.model.ops.xp.vstack(teacher_scores)
         student_scores = self.model.ops.xp.vstack(student_scores)
         assert teacher_scores.shape == student_scores.shape
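Note that the rehearse path above passes normalize=True, so d_scores is divided
by the total number of states in the batch (d_scores.shape[0] after stacking)
and the gradient scale stays independent of batch size. The distill path leaves
it unnormalized, matching the normalize=False of the old
LegacySequenceCategoricalCrossentropy call.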