Bring back support for update_with_oracle_cut_size (#12086)

* Bring back support for `update_with_oracle_cut_size`

This option was available in the pre-refactor parser, but was never
implemented in the refactored parser. It cuts transition sequences that
are longer than `update_with_oracle_cut_size` into separate sequences
that have at most `update_with_oracle_cut_size`
transitions. The oracle (gold standard) transition sequence is used to
determine the cuts and the initial states for the additional sequences.

Applying this cut makes the transition sequence lengths within a batch
more homogeneous, which makes forward passes (and, as a consequence,
training) much faster.

Training time for 1000 steps on de_core_news_lg:

- Before this change: 149s
- After this change: 68s
- Pre-refactor parser: 81s
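
In other words, the oracle sequence of a long document is replayed and sliced into chunks of at most the cut size, and the state at each chunk boundary becomes the initial state of a new, shorter sequence. The sketch below is only an illustration of that idea; `state.copy` and `apply_action` are placeholders standing in for the parser's state-copying and transition-application machinery, not its actual API:

def cut_oracle_sequence(state, oracle_actions, cut_size, apply_action):
    # Split one long oracle sequence into (start_state, actions) chunks
    # of at most cut_size transitions, replaying the oracle to obtain
    # the start state of each chunk.
    chunks = []
    for i in range(0, len(oracle_actions), cut_size):
        chunk = oracle_actions[i:i + cut_size]
        chunks.append((state.copy(), chunk))
        for clas in chunk:
            apply_action(state, clas)
    return chunks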

* Fix a rename that was missed in #10878.

So that rehearsal tests pass.

* Apply suggestions from @shadeMe

* Use chained conditional (illustrated briefly below)

* Test with update_with_oracle_cut_size={0, 1, 5, 100}

And fix a bug that occurs with a cut size of 1.
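
The chained conditional referenced above doubles as the off switch for cutting: `n_moves >= max_moves >= 1` only fires once a positive cut size has been reached, so `update_with_oracle_cut_size=0` disables cutting entirely. A tiny standalone illustration (not parser code):

def reached_cut(n_moves: int, max_moves: int) -> bool:
    # Equivalent to: n_moves >= max_moves and max_moves >= 1
    return n_moves >= max_moves >= 1

assert reached_cut(5, 0) is False     # max_moves=0: cutting disabled
assert reached_cut(0, 1) is False
assert reached_cut(1, 1) is True
assert reached_cut(99, 100) is False
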
Daniël de Kok 2023-01-13 17:05:17 +01:00 committed by GitHub
parent aca4b09ac2
commit 774930c1ff
3 changed files with 168 additions and 38 deletions


@@ -148,32 +148,77 @@ def init(
# model = _lsuv_init(model)
return model
InWithoutActions = Tuple[List[Doc], TransitionSystem]
InWithActions = Tuple[List[Doc], TransitionSystem, List[Ints1d]]
InT = TypeVar("InT", InWithoutActions, InWithActions)
def forward(model, docs_moves: InT, is_train: bool):
if len(docs_moves) == 2:
docs, moves = docs_moves
actions = None
else:
docs, moves, actions = docs_moves
class TransitionModelInputs:
"""
Input to transition model.
"""
# dataclass annotation is not yet supported in Cython 0.29.x,
# so, we'll do something close to it.
actions: Optional[List[Ints1d]]
docs: List[Doc]
max_moves: int
moves: TransitionSystem
states: Optional[List[State]]
__slots__ = [
"actions",
"docs",
"max_moves",
"moves",
"states",
]
def __init__(
self,
docs: List[Doc],
moves: TransitionSystem,
actions: Optional[List[Ints1d]]=None,
max_moves: int=0,
states: Optional[List[State]]=None):
"""
actions (Optional[List[Ints1d]]): actions to apply for each Doc.
docs (List[Doc]): Docs to predict transition sequences for.
max_moves (int): the maximum number of moves to apply; values less
than 1 apply moves to states until they are final states.
moves (TransitionSystem): the transition system to use when predicting
the transition sequences.
states (Optional[List[State]]): the initial states to predict the
transition sequences for. When absent, the initial states are
initialized from the provided Docs.
"""
self.actions = actions
self.docs = docs
self.moves = moves
self.max_moves = max_moves
self.states = states
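
For reference, a hedged sketch of how a caller might use the new input container; `docs`, `moves`, and `model` are assumed to already exist (a list of Docs, the parser's TransitionSystem, and the Thinc transition model), and `max_moves=20` is an arbitrary example value:

# Hedged usage sketch, not part of the diff.
inputs = TransitionModelInputs(docs=docs, moves=moves, max_moves=20)
(states, scores), backprop_scores = model.begin_update(inputs)
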
def forward(model, inputs: TransitionModelInputs, is_train: bool):
docs = inputs.docs
moves = inputs.moves
actions = inputs.actions
beam_width = model.attrs["beam_width"]
hidden_pad = model.get_param("hidden_pad")
tok2vec = model.get_ref("tok2vec")
states = moves.init_batch(docs)
states = moves.init_batch(docs) if inputs.states is None else inputs.states
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad))
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
seen_mask = _get_seen_mask(model)
# Fixme: support actions in forward_cpu
if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
if not is_train and beam_width == 1 and isinstance(model.ops, NumpyOps):
# Note: max_moves is only used during training, so we don't need to
# pass it to the greedy inference path.
return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions)
else:
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train, actions=actions)
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec,
feats, backprop_feats, seen_mask, is_train, actions=actions,
max_moves=inputs.max_moves)
def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats,
@@ -229,8 +274,17 @@ cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
return scores
def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool,
actions: Optional[List[Ints1d]]=None):
def _forward_fallback(
model: Model,
moves: TransitionSystem,
states: List[StateClass],
tokvecs, backprop_tok2vec,
feats,
backprop_feats,
seen_mask,
is_train: bool,
actions: Optional[List[Ints1d]]=None,
max_moves: int=0):
nF = model.get_dim("nF")
output = model.get_ref("output")
hidden_b = model.get_param("hidden_b")
@@ -253,6 +307,7 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
moves, states, None, width=beam_width, density=beam_density
)
arange = ops.xp.arange(nF)
n_moves = 0
while not batch.is_done:
ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i")
for i, state in enumerate(batch.get_unfinished_states()):
@@ -281,6 +336,9 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
all_ids.append(ids)
all_statevecs.append(statevecs)
all_which.append(which)
if n_moves >= max_moves >= 1:
break
n_moves += 1
def backprop_parser(d_states_d_scores):
ids = ops.xp.vstack(all_ids)


@@ -18,6 +18,7 @@ import numpy.random
import numpy
import warnings
from ..ml.tb_framework import TransitionModelInputs
from ._parser_internals.stateclass cimport StateC, StateClass
from ._parser_internals.search cimport Beam
from ..tokens.doc cimport Doc
@@ -25,7 +26,7 @@ from .trainable_pipe cimport TrainablePipe
from ._parser_internals cimport _beam_utils
from ._parser_internals import _beam_utils
from ..vocab cimport Vocab
from ._parser_internals.transition_system cimport TransitionSystem
from ._parser_internals.transition_system cimport Transition, TransitionSystem
from ..typedefs cimport weight_t
from ..training import validate_examples, validate_get_examples
@@ -253,20 +254,23 @@ class Parser(TrainablePipe):
result = self.moves.init_batch(docs)
return result
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
states_or_beams, _ = self.model.predict((docs, self.moves))
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
states_or_beams, _ = self.model.predict(inputs)
return states_or_beams
def greedy_parse(self, docs, drop=0.):
self._resize()
self._ensure_labels_are_added(docs)
with _change_attrs(self.model, beam_width=1):
states, _ = self.model.predict((docs, self.moves))
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
states, _ = self.model.predict(inputs)
return states
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
self._ensure_labels_are_added(docs)
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
beams, _ = self.model.predict((docs, self.moves))
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
beams, _ = self.model.predict(inputs)
return beams
def set_annotations(self, docs, states_or_beams):
@@ -297,11 +301,27 @@ class Parser(TrainablePipe):
return losses
set_dropout_rate(self.model, drop)
docs = [eg.x for eg in examples if len(eg.x)]
(states, scores), backprop_scores = self.model.begin_update((docs, self.moves))
max_moves = self.cfg["update_with_oracle_cut_size"]
if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the
# batch uniform length.
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
init_states, gold_states, _ = self._init_gold_batch(
examples,
max_length=max_moves
)
else:
init_states, gold_states, _ = self.moves.init_gold_batch(examples)
inputs = TransitionModelInputs(docs=docs, moves=self.moves,
max_moves=max_moves, states=[state.copy() for state in init_states])
(pred_states, scores), backprop_scores = self.model.begin_update(inputs)
if sum(s.shape[0] for s in scores) == 0:
return losses
d_scores = self.get_loss((states, scores), examples)
backprop_scores((states, d_scores))
d_scores = self.get_loss((gold_states, init_states, pred_states, scores),
examples, max_moves)
backprop_scores((pred_states, d_scores))
if sgd not in (None, False):
self.finish_update(sgd)
losses[self.name] += float((d_scores**2).sum())
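
For context on the cut applied in `update` above: the configured `update_with_oracle_cut_size` is not used verbatim but redrawn per batch between half and twice its value. A standalone sketch of just that expression (the helper name is hypothetical):

import random

def sample_cut(max_moves: int) -> int:
    # Mirrors the expression in update(): draw an effective cut length
    # between max(max_moves // 2, 1) and max_moves * 2.
    return int(random.uniform(max(max_moves // 2, 1), max_moves * 2))

# With update_with_oracle_cut_size=100 the sampled cut lies roughly in 50..200.
print(sample_cut(100))
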
@@ -311,12 +331,15 @@ class Parser(TrainablePipe):
del backprop_scores
return losses
def get_loss(self, states_scores, examples):
states, scores = states_scores
def get_loss(self, states_scores, examples, max_moves):
gold_states, init_states, pred_states, scores = states_scores
scores = self.model.ops.xp.vstack(scores)
costs = self._get_costs_from_histories(
examples,
[list(state.history) for state in states]
gold_states,
init_states,
[list(state.history) for state in pred_states],
max_moves
)
xp = get_array_module(scores)
best_costs = costs.min(axis=1, keepdims=True)
@@ -334,7 +357,7 @@ class Parser(TrainablePipe):
d_scores -= (costs <= best_costs) * (exp_gscores / gZ)
return d_scores
def _get_costs_from_histories(self, examples, histories):
def _get_costs_from_histories(self, examples, gold_states, init_states, histories, max_moves):
cdef TransitionSystem moves = self.moves
cdef StateClass state
cdef int clas
@@ -344,16 +367,12 @@ class Parser(TrainablePipe):
cdef Pool mem = Pool()
cdef np.ndarray costs_i
is_valid = <int*>mem.alloc(nO, sizeof(int))
states = moves.init_batch([eg.x for eg in examples])
batch = []
for eg, s, h in zip(examples, states, histories):
if not s.is_final():
gold = moves.init_gold(s, eg)
batch.append((eg, s, h, gold))
batch = list(zip(init_states, histories, gold_states))
n_moves = 0
output = []
while batch:
costs = numpy.zeros((len(batch), nO), dtype="f")
for i, (eg, state, history, gold) in enumerate(batch):
for i, (state, history, gold) in enumerate(batch):
costs_i = costs[i]
clas = history.pop(0)
moves.set_costs(is_valid, <weight_t*>costs_i.data, state.c, gold)
@@ -361,7 +380,11 @@ class Parser(TrainablePipe):
action.do(state.c, action.label)
state.c.history.push_back(clas)
output.append(costs)
batch = [(eg, s, h, g) for eg, s, h, g in batch if len(h) != 0]
batch = [(s, h, g) for s, h, g in batch if len(h) != 0]
if n_moves >= max_moves >= 1:
break
n_moves += 1
return self.model.ops.xp.vstack(output)
def rehearse(self, examples, sgd=None, losses=None, **cfg):
@@ -383,9 +406,11 @@ class Parser(TrainablePipe):
# Prepare the stepwise model, and get the callback for finishing the batch
set_dropout_rate(self._rehearsal_model, 0.0)
set_dropout_rate(self.model, 0.0)
(student_states, student_scores), backprop_scores = self.model.begin_update((docs, self.moves))
student_inputs = TransitionModelInputs(docs=docs, moves=self.moves)
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
actions = states2actions(student_states)
_, teacher_scores = self._rehearsal_model.predict((docs, self.moves, actions))
teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions)
_, teacher_scores = self._rehearsal_model.predict(teacher_inputs)
teacher_scores = self.model.ops.xp.vstack(teacher_scores)
student_scores = self.model.ops.xp.vstack(student_scores)
@@ -501,6 +526,49 @@ class Parser(TrainablePipe):
raise ValueError(Errors.E149) from None
return self
def _init_gold_batch(self, examples, max_length):
"""Make a square batch, of length equal to the shortest transition
sequence or a cap. A long doc will get multiple states. Let's say we
have a doc of length 2*N, where N is the shortest doc. We'll make
two states, one representing long_doc[:N], and another representing
long_doc[N:]."""
cdef:
StateClass start_state
StateClass state
Transition action
TransitionSystem moves = self.moves
all_states = moves.init_batch([eg.predicted for eg in examples])
states = []
golds = []
to_cut = []
for state, eg in zip(all_states, examples):
if moves.has_gold(eg) and not state.is_final():
gold = moves.init_gold(state, eg)
if len(eg.x) < max_length:
states.append(state)
golds.append(gold)
else:
oracle_actions = moves.get_oracle_sequence_from_state(
state.copy(), gold)
to_cut.append((eg, state, gold, oracle_actions))
if not to_cut:
return states, golds, 0
cdef int clas
for eg, state, gold, oracle_actions in to_cut:
for i in range(0, len(oracle_actions), max_length):
start_state = state.copy()
for clas in oracle_actions[i:i+max_length]:
action = moves.c[clas]
action.do(state.c, action.label)
if state.is_final():
break
if moves.has_gold(eg, start_state.B(0), state.B(0)):
states.append(start_state)
golds.append(gold)
if state.is_final():
break
return states, golds, max_length
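
To make the chopping in `_init_gold_batch` concrete: the oracle action sequence is partitioned into windows of at most `max_length` transitions, and each window gets its own start state. A tiny standalone sketch of just the slicing (the helper name is hypothetical, not the transition-system API):

def oracle_windows(oracle_actions, max_length):
    # One window, and hence one new start state, per max_length oracle
    # transitions: the same offsets as range(0, len(oracle_actions), max_length).
    return [oracle_actions[i:i + max_length]
            for i in range(0, len(oracle_actions), max_length)]

# A doc with 2*N oracle transitions (N=4 here) yields two windows, matching
# the docstring's long_doc[:N] / long_doc[N:] example.
print(oracle_windows(list(range(8)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
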
@contextlib.contextmanager
def _change_attrs(model, **kwargs):


@@ -1,3 +1,4 @@
import itertools
import pytest
import numpy
from numpy.testing import assert_equal
@@ -401,12 +402,15 @@ def test_incomplete_data(pipe_name):
assert doc[2].head.i == 1
@pytest.mark.parametrize("pipe_name", PARSERS)
def test_overfitting_IO(pipe_name):
@pytest.mark.parametrize(
"pipe_name,max_moves", itertools.product(PARSERS, [0, 1, 5, 100])
)
def test_overfitting_IO(pipe_name, max_moves):
fix_random_seed(0)
# Simple test to try and quickly overfit the dependency parser (normal or beam)
nlp = English()
parser = nlp.add_pipe(pipe_name)
parser.cfg["update_with_oracle_cut_size"] = max_moves
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))