diff --git a/spacy/ml/tb_framework.pyx b/spacy/ml/tb_framework.pyx index b5ef2a489..1ee9716dc 100644 --- a/spacy/ml/tb_framework.pyx +++ b/spacy/ml/tb_framework.pyx @@ -148,32 +148,74 @@ def init( # model = _lsuv_init(model) return model -InWithoutActions = Tuple[List[Doc], TransitionSystem] -InWithActions = Tuple[List[Doc], TransitionSystem, List[Ints1d]] -InT = TypeVar("InT", InWithoutActions, InWithActions) -def forward(model, docs_moves: InT, is_train: bool): - if len(docs_moves) == 2: - docs, moves = docs_moves - actions = None - else: - docs, moves, actions = docs_moves +class TransitionModelInputs: + """ + Input to transition model. + """ + + # dataclass annotation is not yet supported in Cython 0.29.x, + # so, we'll do something close to it. + + actions: Optional[List[Ints1d]] + docs: List[Doc] + max_moves: int + moves: TransitionSystem + states: Optional[List[State]] + + __slots__ = [ + "actions", + "docs", + "max_moves", + "moves", + "states", + ] + + def __init__(self, docs: List[Doc], moves: TransitionSystem, + actions: Optional[List[Ints1d]]=None, max_moves: int=0, + states: Optional[List[State]]=None): + """ + actions (Optional[List[Ints1d]]): actions to apply for each Doc. + docs (List[Doc]): Docs to predict transition sequences for. + max_moves (int): the maximum number of moves to apply, + values less than 1 will apply moves to states until they are + final states. + moves (TransitionSystem): the transition system to use when predicting + the transition sequences. + states (Optional[List[State]]): the initial states to predict the + transition sequences for. When absent, the initial states are + initialized from the provided Docs. 
+ """ + self.actions = actions + self.docs = docs + self.moves = moves + self.max_moves = max_moves + self.states = states + + +def forward(model, inputs: TransitionModelInputs, is_train: bool): + docs = inputs.docs + moves = inputs.moves + actions = inputs.actions beam_width = model.attrs["beam_width"] hidden_pad = model.get_param("hidden_pad") tok2vec = model.get_ref("tok2vec") - states = moves.init_batch(docs) + states = moves.init_batch(docs) if inputs.states is None else inputs.states tokvecs, backprop_tok2vec = tok2vec(docs, is_train) tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad)) feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train) seen_mask = _get_seen_mask(model) - # Fixme: support actions in forward_cpu if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps): + # Note: max_moves is only used during training, so we don't need to + # pass it to the greedy inference path. return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions) else: - return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train, actions=actions) + return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, + feats, backprop_feats, seen_mask, is_train, actions=actions, + max_moves=inputs.max_moves) def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats, @@ -230,7 +272,7 @@ cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states, def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool, - actions: Optional[List[Ints1d]]=None): + actions: Optional[List[Ints1d]]=None, max_moves=0): nF = model.get_dim("nF") output = model.get_ref("output") hidden_b = model.get_param("hidden_b") @@ -253,6 +295,7 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC 
moves, states, None, width=beam_width, density=beam_density ) arange = ops.xp.arange(nF) + n_moves = 0 while not batch.is_done: ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i") for i, state in enumerate(batch.get_unfinished_states()): @@ -281,6 +324,9 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC all_ids.append(ids) all_statevecs.append(statevecs) all_which.append(which) + if max_moves >= 1 and n_moves >= max_moves: + break + n_moves += 1 def backprop_parser(d_states_d_scores): ids = ops.xp.vstack(all_ids) diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 75ca5902d..85fcc1ef8 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -19,6 +19,7 @@ import numpy.random import numpy import warnings +from ..ml.tb_framework import TransitionModelInputs from ._parser_internals.stateclass cimport StateC, StateClass from ._parser_internals.search cimport Beam from ..tokens.doc cimport Doc @@ -26,7 +27,7 @@ from .trainable_pipe cimport TrainablePipe from ._parser_internals cimport _beam_utils from ._parser_internals import _beam_utils from ..vocab cimport Vocab -from ._parser_internals.transition_system cimport TransitionSystem +from ._parser_internals.transition_system cimport Transition, TransitionSystem from ..typedefs cimport weight_t from ..training import validate_examples, validate_get_examples @@ -254,20 +255,23 @@ class Parser(TrainablePipe): result = self.moves.init_batch(docs) return result with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]): - states_or_beams, _ = self.model.predict((docs, self.moves)) + inputs = TransitionModelInputs(docs=docs, moves=self.moves) + states_or_beams, _ = self.model.predict(inputs) return states_or_beams def greedy_parse(self, docs, drop=0.): self._resize() self._ensure_labels_are_added(docs) with _change_attrs(self.model, beam_width=1): - 
states, _ = self.model.predict((docs, self.moves)) + inputs = TransitionModelInputs(docs=docs, moves=self.moves) + states, _ = self.model.predict(inputs) return states def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.): self._ensure_labels_are_added(docs) with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]): - beams, _ = self.model.predict((docs, self.moves)) + inputs = TransitionModelInputs(docs=docs, moves=self.moves) + beams, _ = self.model.predict(inputs) return beams def set_annotations(self, docs, states_or_beams): @@ -298,11 +302,27 @@ class Parser(TrainablePipe): return losses set_dropout_rate(self.model, drop) docs = [eg.x for eg in examples if len(eg.x)] - (states, scores), backprop_scores = self.model.begin_update((docs, self.moves)) + + max_moves = self.cfg["update_with_oracle_cut_size"] + if max_moves >= 1: + # Chop sequences into lengths of this many words, to make the + # batch uniform length. + max_moves = int(random.uniform(max_moves // 2, max_moves * 2)) + init_states, gold_states, _ = self._init_gold_batch( + examples, + max_length=max_moves + ) + else: + init_states, gold_states, _ = self.moves.init_gold_batch(examples) + + inputs = TransitionModelInputs(docs=docs, moves=self.moves, + max_moves=max_moves, states=[state.copy() for state in init_states]) + (pred_states, scores), backprop_scores = self.model.begin_update(inputs) if sum(s.shape[0] for s in scores) == 0: return losses - d_scores = self.get_loss((states, scores), examples) - backprop_scores((states, d_scores)) + d_scores = self.get_loss((gold_states, init_states, pred_states, scores), + examples, max_moves) + backprop_scores((pred_states, d_scores)) if sgd not in (None, False): self.finish_update(sgd) losses[self.name] += float((d_scores**2).sum()) @@ -312,12 +332,15 @@ class Parser(TrainablePipe): del backprop_scores return losses - def get_loss(self, states_scores, examples): - states, scores = states_scores 
+ def get_loss(self, states_scores, examples, max_moves): + gold_states, init_states, pred_states, scores = states_scores scores = self.model.ops.xp.vstack(scores) costs = self._get_costs_from_histories( examples, - [list(state.history) for state in states] + gold_states, + init_states, + [list(state.history) for state in pred_states], + max_moves ) xp = get_array_module(scores) best_costs = costs.min(axis=1, keepdims=True) @@ -335,7 +358,7 @@ class Parser(TrainablePipe): d_scores -= (costs <= best_costs) * (exp_gscores / gZ) return d_scores - def _get_costs_from_histories(self, examples, histories): + def _get_costs_from_histories(self, examples, gold_states, init_states, histories, max_moves): cdef TransitionSystem moves = self.moves cdef StateClass state cdef int clas @@ -345,16 +368,12 @@ class Parser(TrainablePipe): cdef Pool mem = Pool() cdef np.ndarray costs_i is_valid = mem.alloc(nO, sizeof(int)) - states = moves.init_batch([eg.x for eg in examples]) - batch = [] - for eg, s, h in zip(examples, states, histories): - if not s.is_final(): - gold = moves.init_gold(s, eg) - batch.append((eg, s, h, gold)) + batch = list(zip(init_states, histories, gold_states)) + n_moves = 0 output = [] while batch: costs = numpy.zeros((len(batch), nO), dtype="f") - for i, (eg, state, history, gold) in enumerate(batch): + for i, (state, history, gold) in enumerate(batch): costs_i = costs[i] clas = history.pop(0) moves.set_costs(is_valid, costs_i.data, state.c, gold) @@ -362,7 +381,11 @@ class Parser(TrainablePipe): action.do(state.c, action.label) state.c.history.push_back(clas) output.append(costs) - batch = [(eg, s, h, g) for eg, s, h, g in batch if len(h) != 0] + batch = [(s, h, g) for s, h, g in batch if len(h) != 0] + if max_moves >= 1 and n_moves >= max_moves: + break + n_moves += 1 + return self.model.ops.xp.vstack(output) def rehearse(self, examples, sgd=None, losses=None, **cfg): @@ -384,9 +407,11 @@ class Parser(TrainablePipe): # Prepare the stepwise model, and get the 
callback for finishing the batch set_dropout_rate(self._rehearsal_model, 0.0) set_dropout_rate(self.model, 0.0) - (student_states, student_scores), backprop_scores = self.model.begin_update((docs, self.moves)) + student_inputs = TransitionModelInputs(docs=docs, moves=self.moves) + (student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs) actions = states2actions(student_states) - _, teacher_scores = self._rehearsal_model.predict((docs, self.moves, actions)) + teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions) + _, teacher_scores = self._rehearsal_model.predict(teacher_inputs) teacher_scores = self.model.ops.xp.vstack(teacher_scores) student_scores = self.model.ops.xp.vstack(student_scores) @@ -502,6 +527,49 @@ class Parser(TrainablePipe): raise ValueError(Errors.E149) from None return self + def _init_gold_batch(self, examples, max_length): + """Make a square batch, of length equal to the shortest transition + sequence or a cap. A long + doc will get multiple states. Let's say we have a doc of length 2*N, + where N is the shortest doc. 
We'll make two states, one representing + long_doc[:N], and another representing long_doc[N:].""" + cdef: + StateClass start_state + StateClass state + Transition action + TransitionSystem moves = self.moves + all_states = moves.init_batch([eg.predicted for eg in examples]) + states = [] + golds = [] + to_cut = [] + for state, eg in zip(all_states, examples): + if moves.has_gold(eg) and not state.is_final(): + gold = moves.init_gold(state, eg) + if len(eg.x) < max_length: + states.append(state) + golds.append(gold) + else: + oracle_actions = moves.get_oracle_sequence_from_state( + state.copy(), gold) + to_cut.append((eg, state, gold, oracle_actions)) + if not to_cut: + return states, golds, 0 + cdef int clas + for eg, state, gold, oracle_actions in to_cut: + for i in range(0, len(oracle_actions), max_length): + start_state = state.copy() + for clas in oracle_actions[i:i+max_length]: + action = moves.c[clas] + action.do(state.c, action.label) + if state.is_final(): + break + if moves.has_gold(eg, start_state.B(0), state.B(0)): + states.append(start_state) + golds.append(gold) + if state.is_final(): + break + return states, golds, max_length + @contextlib.contextmanager def _change_attrs(model, **kwargs):