diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 09b5f6181..a2b6c167f 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -247,15 +247,6 @@ class Parser(TrainablePipe):
 
         student_docs = [eg.predicted for eg in examples]
 
-        teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples])
-        student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
-
-        # Add softmax activation, so that we can compute student losses
-        # with cross-entropy loss.
-        with use_ops("numpy"):
-            teacher_model = chain(teacher_step_model, softmax_activation())
-            student_model = chain(student_step_model, softmax_activation())
-
         max_moves = self.cfg["update_with_oracle_cut_size"]
         if max_moves >= 1:
             # Chop sequences into lengths of this many words, to make the
@@ -263,51 +254,39 @@ class Parser(TrainablePipe):
             # sequence, we use the teacher's predictions as the gold
             # standard.
             max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
-            states = self._init_batch(teacher_step_model, student_docs, max_moves)
+            states = self._init_batch(teacher_pipe, student_docs, max_moves)
         else:
             states = self.moves.init_batch(student_docs)
 
-        loss = 0.0
-        n_moves = 0
-        while states:
-            # We do distillation as follows: (1) for every state, we compute the
-            # transition softmax distributions: (2) we backpropagate the error of
-            # the student (compared to the teacher) into the student model; (3)
-            # for all states, we move to the next state using the student's
-            # predictions.
-            teacher_scores = teacher_model.predict(states)
-            student_scores, backprop = student_model.begin_update(states)
-            state_loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
-            backprop(d_scores)
-            loss += state_loss
-            self.transition_states(states, student_scores)
-            states = [state for state in states if not state.is_final()]
+        # We distill as follows: (1) we first let the student predict transition
+        # sequences (and the corresponding transition probabilities); (2) we
+        # let the teacher follow the student's predicted transition sequences
+        # to obtain the teacher's transition probabilities; (3) we compute the
+        # gradients of the student's transition distributions relative to the
+        # teacher's distributions.
 
-            # Stop when we reach the maximum number of moves, otherwise we start
-            # to process the remainder of cut sequences again.
-            if max_moves >= 1 and n_moves >= max_moves:
-                break
-            n_moves += 1
+        student_inputs = TransitionModelInputs(docs=student_docs, moves=self.moves,
+                                               max_moves=max_moves)
+        (student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
+        actions = states2actions(student_states)
+        teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples],
+                                               moves=self.moves, actions=actions)
+        (_, teacher_scores) = teacher_pipe.model.predict(teacher_inputs)
 
-        backprop_tok2vec(student_docs)
+        loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
+        backprop_scores((student_states, d_scores))
 
         if sgd is not None:
             self.finish_update(sgd)
 
         losses[self.name] += loss
 
-        del backprop
-        del backprop_tok2vec
-        teacher_step_model.clear_memory()
-        student_step_model.clear_memory()
-        del teacher_model
-        del student_model
-
         return losses
 
     def get_teacher_student_loss(
-        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d],
+        normalize: bool=False,
     ) -> Tuple[float, List[Floats2d]]:
         """Calculate the loss and its gradient for a batch of student
         scores, relative to teacher scores.
 
@@ -319,10 +298,28 @@ class Parser(TrainablePipe):
 
         DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
        """
-        loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
-        d_scores, loss = loss_func(student_scores, teacher_scores)
-        if self.model.ops.xp.isnan(loss):
-            raise ValueError(Errors.E910.format(name=self.name))
+
+        # We can't easily hook up a softmax layer in the parsing model, since
+        # the get_loss implementation does additional masking. So, we could
+        # apply softmax manually here and use Thinc's cross-entropy loss. But
+        # that is a bit suboptimal, since we can have a lot of states, which
+        # would result in many kernel launches. Furthermore, the parsing
+        # model's backprop expects an XP array, so we'd have to concatenate
+        # the softmaxes anyway. So, like the get_loss implementation, we
+        # compute the loss and gradients ourselves.
+
+        teacher_scores = self.model.ops.softmax(self.model.ops.xp.vstack(teacher_scores),
+                                                axis=-1, inplace=True)
+        student_scores = self.model.ops.softmax(self.model.ops.xp.vstack(student_scores),
+                                                axis=-1, inplace=True)
+
+        assert teacher_scores.shape == student_scores.shape
+
+        d_scores = student_scores - teacher_scores
+        if normalize:
+            d_scores /= d_scores.shape[0]
+        loss = (d_scores**2).sum() / d_scores.size
+
         return float(loss), d_scores
 
     def init_multitask_objectives(self, get_examples, pipeline, **cfg):
@@ -529,6 +526,8 @@ class Parser(TrainablePipe):
         teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions)
         _, teacher_scores = self._rehearsal_model.predict(teacher_inputs)
 
+        loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores, normalize=True)
+
         teacher_scores = self.model.ops.xp.vstack(teacher_scores)
         student_scores = self.model.ops.xp.vstack(student_scores)
         assert teacher_scores.shape == student_scores.shape
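
Note (editor's sketch, not part of the patch): the gradient computed in the new get_teacher_student_loss is the standard softmax/cross-entropy distillation gradient, i.e. student probabilities minus teacher probabilities, with the mean squared gradient reported as the loss. A minimal plain-NumPy rendering of that computation follows; it assumes NumPy in place of the model's Ops (which keep the arrays on CPU or GPU as appropriate), and the function names are made up for illustration.

    from typing import List, Tuple

    import numpy as np


    def _softmax(scores: np.ndarray) -> np.ndarray:
        # Row-wise softmax with max-subtraction for numerical stability.
        exps = np.exp(scores - scores.max(axis=-1, keepdims=True))
        return exps / exps.sum(axis=-1, keepdims=True)


    def teacher_student_gradient(
        teacher_scores: List[np.ndarray],
        student_scores: List[np.ndarray],
        normalize: bool = False,
    ) -> Tuple[float, np.ndarray]:
        # Stack the per-state score matrices and turn them into probabilities.
        teacher = _softmax(np.vstack(teacher_scores))
        student = _softmax(np.vstack(student_scores))
        assert teacher.shape == student.shape
        # The derivative of the cross-entropy against the teacher's distribution,
        # taken w.r.t. the student logits, is student probs minus teacher probs.
        d_scores = student - teacher
        if normalize:
            d_scores /= d_scores.shape[0]
        # The patch reports the mean squared gradient as the loss value.
        loss = (d_scores ** 2).sum() / d_scores.size
        return float(loss), d_scores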