From 774930c1ff6bc7ab116cb20182a8e6102fe3100f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 13 Jan 2023 17:05:17 +0100 Subject: [PATCH] Bring back support for `update_with_oracle_cut_size` (#12086) * Bring back support for `update_with_oracle_cut_size` This option was available in the pre-refactor parser, but was never implemented in the refactored parser. This option cuts transition sequences that are longer than `update_with_oracle_cut` size into separate sequences that have at most `update_with_oracle_cut` transitions. The oracle (gold standard) transition sequence is used to determine the cuts and the initial states for the additional sequences. Applying this cut makes the batches more homogeneous in the transition sequence lengths, making forward passes (and as a consequence training) much faster. Training time 1000 steps on de_core_news_lg: - Before this change: 149s - After this change: 68s - Pre-refactor parser: 81s * Fix a rename that was missed in #10878. So that rehearsal tests pass. * Apply suggestions from @shadeMe * Use chained conditional * Test with update_with_oracle_cut_size={0, 1, 5, 100} And fix a git that occurs with a cut size of 1. --- spacy/ml/tb_framework.pyx | 88 +++++++++++++++++---- spacy/pipeline/transition_parser.pyx | 110 ++++++++++++++++++++++----- spacy/tests/parser/test_parse.py | 8 +- 3 files changed, 168 insertions(+), 38 deletions(-) diff --git a/spacy/ml/tb_framework.pyx b/spacy/ml/tb_framework.pyx index 372849bae..f0316d8f9 100644 --- a/spacy/ml/tb_framework.pyx +++ b/spacy/ml/tb_framework.pyx @@ -148,32 +148,77 @@ def init( # model = _lsuv_init(model) return model -InWithoutActions = Tuple[List[Doc], TransitionSystem] -InWithActions = Tuple[List[Doc], TransitionSystem, List[Ints1d]] -InT = TypeVar("InT", InWithoutActions, InWithActions) -def forward(model, docs_moves: InT, is_train: bool): - if len(docs_moves) == 2: - docs, moves = docs_moves - actions = None - else: - docs, moves, actions = docs_moves +class TransitionModelInputs: + """ + Input to transition model. + """ + + # dataclass annotation is not yet supported in Cython 0.29.x, + # so, we'll do something close to it. + + actions: Optional[List[Ints1d]] + docs: List[Doc] + max_moves: int + moves: TransitionSystem + states: Optional[List[State]] + + __slots__ = [ + "actions", + "docs", + "max_moves", + "moves", + "states", + ] + + def __init__( + self, + docs: List[Doc], + moves: TransitionSystem, + actions: Optional[List[Ints1d]]=None, + max_moves: int=0, + states: Optional[List[State]]=None): + """ + actions (Optional[List[Ints1d]]): actions to apply for each Doc. + docs (List[Doc]): Docs to predict transition sequences for. + max_moves: (int): the maximum number of moves to apply, values less + than 1 will apply moves to states until they are final states. + moves (TransitionSystem): the transition system to use when predicting + the transition sequences. + states (Optional[List[States]]): the initial states to predict the + transition sequences for. When absent, the initial states are + initialized from the provided Docs. + """ + self.actions = actions + self.docs = docs + self.moves = moves + self.max_moves = max_moves + self.states = states + + +def forward(model, inputs: TransitionModelInputs, is_train: bool): + docs = inputs.docs + moves = inputs.moves + actions = inputs.actions beam_width = model.attrs["beam_width"] hidden_pad = model.get_param("hidden_pad") tok2vec = model.get_ref("tok2vec") - states = moves.init_batch(docs) + states = moves.init_batch(docs) if inputs.states is None else inputs.states tokvecs, backprop_tok2vec = tok2vec(docs, is_train) tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad)) feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train) seen_mask = _get_seen_mask(model) - # Fixme: support actions in forward_cpu - if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps): + if not is_train and beam_width == 1 and isinstance(model.ops, NumpyOps): + # Note: max_moves is only used during training, so we don't need to + # pass it to the greedy inference path. return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions) else: - return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train, actions=actions) + return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, + feats, backprop_feats, seen_mask, is_train, actions=actions, + max_moves=inputs.max_moves) def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats, @@ -229,8 +274,17 @@ cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states, return scores -def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool, - actions: Optional[List[Ints1d]]=None): +def _forward_fallback( + model: Model, + moves: TransitionSystem, + states: List[StateClass], + tokvecs, backprop_tok2vec, + feats, + backprop_feats, + seen_mask, + is_train: bool, + actions: Optional[List[Ints1d]]=None, + max_moves: int=0): nF = model.get_dim("nF") output = model.get_ref("output") hidden_b = model.get_param("hidden_b") @@ -253,6 +307,7 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC moves, states, None, width=beam_width, density=beam_density ) arange = ops.xp.arange(nF) + n_moves = 0 while not batch.is_done: ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i") for i, state in enumerate(batch.get_unfinished_states()): @@ -281,6 +336,9 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC all_ids.append(ids) all_statevecs.append(statevecs) all_which.append(which) + if n_moves >= max_moves >= 1: + break + n_moves += 1 def backprop_parser(d_states_d_scores): ids = ops.xp.vstack(all_ids) diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 2f7acf2fc..e6119ee79 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -18,6 +18,7 @@ import numpy.random import numpy import warnings +from ..ml.tb_framework import TransitionModelInputs from ._parser_internals.stateclass cimport StateC, StateClass from ._parser_internals.search cimport Beam from ..tokens.doc cimport Doc @@ -25,7 +26,7 @@ from .trainable_pipe cimport TrainablePipe from ._parser_internals cimport _beam_utils from ._parser_internals import _beam_utils from ..vocab cimport Vocab -from ._parser_internals.transition_system cimport TransitionSystem +from ._parser_internals.transition_system cimport Transition, TransitionSystem from ..typedefs cimport weight_t from ..training import validate_examples, validate_get_examples @@ -253,20 +254,23 @@ class Parser(TrainablePipe): result = self.moves.init_batch(docs) return result with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]): - states_or_beams, _ = self.model.predict((docs, self.moves)) + inputs = TransitionModelInputs(docs=docs, moves=self.moves) + states_or_beams, _ = self.model.predict(inputs) return states_or_beams def greedy_parse(self, docs, drop=0.): self._resize() self._ensure_labels_are_added(docs) with _change_attrs(self.model, beam_width=1): - states, _ = self.model.predict((docs, self.moves)) + inputs = TransitionModelInputs(docs=docs, moves=self.moves) + states, _ = self.model.predict(inputs) return states def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.): self._ensure_labels_are_added(docs) with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]): - beams, _ = self.model.predict((docs, self.moves)) + inputs = TransitionModelInputs(docs=docs, moves=self.moves) + beams, _ = self.model.predict(inputs) return beams def set_annotations(self, docs, states_or_beams): @@ -297,11 +301,27 @@ class Parser(TrainablePipe): return losses set_dropout_rate(self.model, drop) docs = [eg.x for eg in examples if len(eg.x)] - (states, scores), backprop_scores = self.model.begin_update((docs, self.moves)) + + max_moves = self.cfg["update_with_oracle_cut_size"] + if max_moves >= 1: + # Chop sequences into lengths of this many words, to make the + # batch uniform length. + max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2)) + init_states, gold_states, _ = self._init_gold_batch( + examples, + max_length=max_moves + ) + else: + init_states, gold_states, _ = self.moves.init_gold_batch(examples) + + inputs = TransitionModelInputs(docs=docs, moves=self.moves, + max_moves=max_moves, states=[state.copy() for state in init_states]) + (pred_states, scores), backprop_scores = self.model.begin_update(inputs) if sum(s.shape[0] for s in scores) == 0: return losses - d_scores = self.get_loss((states, scores), examples) - backprop_scores((states, d_scores)) + d_scores = self.get_loss((gold_states, init_states, pred_states, scores), + examples, max_moves) + backprop_scores((pred_states, d_scores)) if sgd not in (None, False): self.finish_update(sgd) losses[self.name] += float((d_scores**2).sum()) @@ -311,12 +331,15 @@ class Parser(TrainablePipe): del backprop_scores return losses - def get_loss(self, states_scores, examples): - states, scores = states_scores + def get_loss(self, states_scores, examples, max_moves): + gold_states, init_states, pred_states, scores = states_scores scores = self.model.ops.xp.vstack(scores) costs = self._get_costs_from_histories( examples, - [list(state.history) for state in states] + gold_states, + init_states, + [list(state.history) for state in pred_states], + max_moves ) xp = get_array_module(scores) best_costs = costs.min(axis=1, keepdims=True) @@ -334,7 +357,7 @@ class Parser(TrainablePipe): d_scores -= (costs <= best_costs) * (exp_gscores / gZ) return d_scores - def _get_costs_from_histories(self, examples, histories): + def _get_costs_from_histories(self, examples, gold_states, init_states, histories, max_moves): cdef TransitionSystem moves = self.moves cdef StateClass state cdef int clas @@ -344,16 +367,12 @@ class Parser(TrainablePipe): cdef Pool mem = Pool() cdef np.ndarray costs_i is_valid = mem.alloc(nO, sizeof(int)) - states = moves.init_batch([eg.x for eg in examples]) - batch = [] - for eg, s, h in zip(examples, states, histories): - if not s.is_final(): - gold = moves.init_gold(s, eg) - batch.append((eg, s, h, gold)) + batch = list(zip(init_states, histories, gold_states)) + n_moves = 0 output = [] while batch: costs = numpy.zeros((len(batch), nO), dtype="f") - for i, (eg, state, history, gold) in enumerate(batch): + for i, (state, history, gold) in enumerate(batch): costs_i = costs[i] clas = history.pop(0) moves.set_costs(is_valid, costs_i.data, state.c, gold) @@ -361,7 +380,11 @@ class Parser(TrainablePipe): action.do(state.c, action.label) state.c.history.push_back(clas) output.append(costs) - batch = [(eg, s, h, g) for eg, s, h, g in batch if len(h) != 0] + batch = [(s, h, g) for s, h, g in batch if len(h) != 0] + if n_moves >= max_moves >= 1: + break + n_moves += 1 + return self.model.ops.xp.vstack(output) def rehearse(self, examples, sgd=None, losses=None, **cfg): @@ -383,9 +406,11 @@ class Parser(TrainablePipe): # Prepare the stepwise model, and get the callback for finishing the batch set_dropout_rate(self._rehearsal_model, 0.0) set_dropout_rate(self.model, 0.0) - (student_states, student_scores), backprop_scores = self.model.begin_update((docs, self.moves)) + student_inputs = TransitionModelInputs(docs=docs, moves=self.moves) + (student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs) actions = states2actions(student_states) - _, teacher_scores = self._rehearsal_model.predict((docs, self.moves, actions)) + teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions) + _, teacher_scores = self._rehearsal_model.predict(teacher_inputs) teacher_scores = self.model.ops.xp.vstack(teacher_scores) student_scores = self.model.ops.xp.vstack(student_scores) @@ -501,6 +526,49 @@ class Parser(TrainablePipe): raise ValueError(Errors.E149) from None return self + def _init_gold_batch(self, examples, max_length): + """Make a square batch, of length equal to the shortest transition + sequence or a cap. A long doc will get multiple states. Let's say we + have a doc of length 2*N, where N is the shortest doc. We'll make + two states, one representing long_doc[:N], and another representing + long_doc[N:].""" + cdef: + StateClass start_state + StateClass state + Transition action + TransitionSystem moves = self.moves + all_states = moves.init_batch([eg.predicted for eg in examples]) + states = [] + golds = [] + to_cut = [] + for state, eg in zip(all_states, examples): + if moves.has_gold(eg) and not state.is_final(): + gold = moves.init_gold(state, eg) + if len(eg.x) < max_length: + states.append(state) + golds.append(gold) + else: + oracle_actions = moves.get_oracle_sequence_from_state( + state.copy(), gold) + to_cut.append((eg, state, gold, oracle_actions)) + if not to_cut: + return states, golds, 0 + cdef int clas + for eg, state, gold, oracle_actions in to_cut: + for i in range(0, len(oracle_actions), max_length): + start_state = state.copy() + for clas in oracle_actions[i:i+max_length]: + action = moves.c[clas] + action.do(state.c, action.label) + if state.is_final(): + break + if moves.has_gold(eg, start_state.B(0), state.B(0)): + states.append(start_state) + golds.append(gold) + if state.is_final(): + break + return states, golds, max_length + @contextlib.contextmanager def _change_attrs(model, **kwargs): diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py index af33dcf5f..df463b700 100644 --- a/spacy/tests/parser/test_parse.py +++ b/spacy/tests/parser/test_parse.py @@ -1,3 +1,4 @@ +import itertools import pytest import numpy from numpy.testing import assert_equal @@ -401,12 +402,15 @@ def test_incomplete_data(pipe_name): assert doc[2].head.i == 1 -@pytest.mark.parametrize("pipe_name", PARSERS) -def test_overfitting_IO(pipe_name): +@pytest.mark.parametrize( + "pipe_name,max_moves", itertools.product(PARSERS, [0, 1, 5, 100]) +) +def test_overfitting_IO(pipe_name, max_moves): fix_random_seed(0) # Simple test to try and quickly overfit the dependency parser (normal or beam) nlp = English() parser = nlp.add_pipe(pipe_name) + parser.cfg["update_with_oracle_cut_size"] = max_moves train_examples = [] for text, annotations in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))