diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 9e50dd7b2..dc512a2be 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -36,6 +36,11 @@ from ..errors import Errors, Warnings from .. import util +# TODO: Remove when we switch to Cython 3. +cdef extern from "" namespace "std" nogil: + bint equal[InputIt1, InputIt2](InputIt1 first1, InputIt1 last1, InputIt2 first2) except + + + NUMPY_OPS = NumpyOps() @@ -253,7 +258,7 @@ class Parser(TrainablePipe): # batch uniform length. Since we do not have a gold standard # sequence, we use the teacher's predictions as the gold # standard. - max_moves = int(random.uniform(max_moves // 2, max_moves * 2)) + max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2)) states = self._init_batch(teacher_pipe, student_docs, max_moves) else: states = self.moves.init_batch(student_docs) @@ -265,12 +270,12 @@ class Parser(TrainablePipe): # gradients of the student's transition distributions relative to the # teacher's distributions. - student_inputs = TransitionModelInputs(docs=student_docs, moves=self.moves, - max_moves=max_moves) + student_inputs = TransitionModelInputs(docs=student_docs, + states=[state.copy() for state in states], moves=self.moves, max_moves=max_moves) (student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs) - actions = states2actions(student_states) + actions = _states_diff_to_actions(states, student_states) teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples], - moves=self.moves, actions=actions) + states=states, moves=teacher_pipe.moves, actions=actions) (_, teacher_scores) = teacher_pipe.model.predict(teacher_inputs) loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores) @@ -642,7 +647,7 @@ class Parser(TrainablePipe): raise ValueError(Errors.E149) from None return self - def _init_batch(self, teacher_step_model, docs, max_length): + def _init_batch(self, teacher_pipe, docs, max_length): """Make a square batch of length equal to the shortest transition sequence or a cap. A long doc will get multiple states. Let's say we have a doc of length 2*N, @@ -651,10 +656,12 @@ class Parser(TrainablePipe): _init_gold_batch, this version uses a teacher model to generate the cut sequences.""" cdef: - StateClass start_state StateClass state - Transition action - all_states = self.moves.init_batch(docs) + TransitionSystem moves = teacher_pipe.moves + + # Start with the same heuristic as in supervised training: exclude + # docs that are within the maximum length. + all_states = moves.init_batch(docs) states = [] to_cut = [] for state, doc in zip(all_states, docs): @@ -663,18 +670,28 @@ class Parser(TrainablePipe): states.append(state) else: to_cut.append(state) + + if not to_cut: + return states + + # Parse the states that are too long with the teacher's parsing model. + teacher_inputs = TransitionModelInputs(docs=docs, moves=moves, + states=[state.copy() for state in to_cut]) + (teacher_states, _ ) = teacher_pipe.model.predict(teacher_inputs) + + # Step through the teacher's actions and store every state after + # each multiple of max_length. + teacher_actions = states2actions(teacher_states) while to_cut: states.extend(state.copy() for state in to_cut) - # Move states forward max_length actions. - length = 0 - while to_cut and length < max_length: - teacher_scores = teacher_step_model.predict(to_cut) - self.transition_states(to_cut, teacher_scores) - # States that are completed do not need further cutting. - to_cut = [state for state in to_cut if not state.is_final()] - length += 1 - return states + for step_actions in teacher_actions[:max_length]: + to_cut = moves.apply_actions(to_cut, step_actions) + teacher_actions = teacher_actions[max_length:] + if len(teacher_actions) < max_length: + break + + return states def _init_gold_batch(self, examples, max_length): """Make a square batch, of length equal to the shortest transition @@ -757,3 +774,46 @@ def states2actions(states: List[StateClass]) -> List[Ints1d]: actions.append(numpy.array(step_actions, dtype="i")) return actions + +def _states_diff_to_actions( + before_states: List[StateClass], + after_states: List[StateClass] +) -> List[Ints1d]: + """ + Return for two sets of states the actions to go from the first set of + states to the second set of states. The histories of the first set of + states must be prefix of the second set of states. + """ + cdef StateClass before_state, after_state + cdef StateC* c_state_before + cdef StateC* c_state_after + + assert len(before_states) == len(after_states) + + # Check invariant: before states histories must be prefixes of after states. + for before_state, after_state in zip(before_states, after_states): + c_state_before = before_state.c + c_state_after = after_state.c + + assert equal(c_state_after.history.begin(), + c_state_after.history.begin() + c_state_before.history.size(), + c_state_after.history.begin()) + + actions = [] + while True: + step = len(actions) + + step_actions = [] + for before_state, after_state in zip(before_states, after_states): + c_state_before = before_state.c + c_state_after = after_state.c + if step < c_state_after.history.size() - c_state_before.history.size(): + step_actions.append(c_state_after.history[c_state_before.history.size() + step]) + + # We are done if we have exhausted all histories. + if len(step_actions) == 0: + break + + actions.append(numpy.array(step_actions, dtype="i")) + + return actions diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py index 57b6e188b..b708210e7 100644 --- a/spacy/tests/parser/test_parse.py +++ b/spacy/tests/parser/test_parse.py @@ -463,7 +463,8 @@ def test_is_distillable(): assert parser.is_distillable -def test_distill(): +@pytest.mark.parametrize("max_moves", [0, 1, 5, 100]) +def test_distill(max_moves): teacher = English() teacher_parser = teacher.add_pipe("parser") train_examples = [] @@ -481,6 +482,7 @@ def test_distill(): student = English() student_parser = student.add_pipe("parser") + student_parser.cfg["update_with_oracle_cut_size"] = max_moves student_parser.initialize( get_examples=lambda: train_examples, labels=teacher_parser.label_data )