mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 20:30:24 +03:00
Bring back support for update_with_oracle_cut_size
This option was available in the pre-refactor parser, but was never implemented in the refactored parser. The option cuts transition sequences that are longer than `update_with_oracle_cut_size` into separate sequences that have at most `update_with_oracle_cut_size` transitions. The oracle (gold-standard) transition sequence is used to determine the cuts and the initial states for the additional sequences. Applying this cut makes the transition sequence lengths within a batch more homogeneous, which makes forward passes (and, as a consequence, training) much faster.
Training time for 1000 steps on de_core_news_lg:
- Before this change: 149s
- After this change: 68s
- Pre-refactor parser: 81s
This commit is contained in:
parent
5ffee4863f
commit
29b3d19a87
|
@ -148,32 +148,74 @@ def init(
|
||||||
# model = _lsuv_init(model)
|
# model = _lsuv_init(model)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
InWithoutActions = Tuple[List[Doc], TransitionSystem]
|
|
||||||
InWithActions = Tuple[List[Doc], TransitionSystem, List[Ints1d]]
|
|
||||||
InT = TypeVar("InT", InWithoutActions, InWithActions)
|
|
||||||
|
|
||||||
def forward(model, docs_moves: InT, is_train: bool):
|
class TransitionModelInputs:
|
||||||
if len(docs_moves) == 2:
|
"""
|
||||||
docs, moves = docs_moves
|
Input to transition model.
|
||||||
actions = None
|
"""
|
||||||
else:
|
|
||||||
docs, moves, actions = docs_moves
|
# dataclass annotation is not yet supported in Cython 0.29.x,
|
||||||
|
# so, we'll do something close to it.
|
||||||
|
|
||||||
|
actions: Optional[List[Ints1d]]
|
||||||
|
docs: List[Doc]
|
||||||
|
max_moves: int
|
||||||
|
moves: TransitionSystem
|
||||||
|
states: Optional[List[State]]
|
||||||
|
|
||||||
|
__slots__ = [
|
||||||
|
"actions",
|
||||||
|
"docs",
|
||||||
|
"max_moves",
|
||||||
|
"moves",
|
||||||
|
"states",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, docs: List[Doc], moves: TransitionSystem,
|
||||||
|
actions: Optional[List[Ints1d]]=None, max_moves: int=0,
|
||||||
|
states: Optional[List[State]]=None):
|
||||||
|
"""
|
||||||
|
actions (Optional[List[Ints1d]]): actions to apply for each Doc.
|
||||||
|
docs (List[Doc]): Docs to predict transition sequences for.
|
||||||
|
max_moves (int): the maximum number of moves to apply,
|
||||||
|
values less than 1 will apply moves to states until they are
|
||||||
|
final states.
|
||||||
|
moves (TransitionSystem): the transition system to use when predicting
|
||||||
|
the transition sequences.
|
||||||
|
states (Optional[List[State]]): the initial states to predict the
|
||||||
|
transition sequences for. When absent, the initial states are
|
||||||
|
initialized from the provided Docs.
|
||||||
|
"""
|
||||||
|
self.actions = actions
|
||||||
|
self.docs = docs
|
||||||
|
self.moves = moves
|
||||||
|
self.max_moves = max_moves
|
||||||
|
self.states = states
|
||||||
|
|
||||||
|
|
||||||
|
def forward(model, inputs: TransitionModelInputs, is_train: bool):
|
||||||
|
docs = inputs.docs
|
||||||
|
moves = inputs.moves
|
||||||
|
actions = inputs.actions
|
||||||
|
|
||||||
beam_width = model.attrs["beam_width"]
|
beam_width = model.attrs["beam_width"]
|
||||||
hidden_pad = model.get_param("hidden_pad")
|
hidden_pad = model.get_param("hidden_pad")
|
||||||
tok2vec = model.get_ref("tok2vec")
|
tok2vec = model.get_ref("tok2vec")
|
||||||
|
|
||||||
states = moves.init_batch(docs)
|
states = moves.init_batch(docs) if inputs.states is None else inputs.states
|
||||||
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
|
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
|
||||||
tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad))
|
tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad))
|
||||||
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
|
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
|
||||||
seen_mask = _get_seen_mask(model)
|
seen_mask = _get_seen_mask(model)
|
||||||
|
|
||||||
# Fixme: support actions in forward_cpu
|
|
||||||
if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
|
if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
|
||||||
|
# Note: max_moves is only used during training, so we don't need to
|
||||||
|
# pass it to the greedy inference path.
|
||||||
return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions)
|
return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions)
|
||||||
else:
|
else:
|
||||||
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train, actions=actions)
|
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec,
|
||||||
|
feats, backprop_feats, seen_mask, is_train, actions=actions,
|
||||||
|
max_moves=inputs.max_moves)
|
||||||
|
|
||||||
|
|
||||||
def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats,
|
def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats,
|
||||||
|
@ -230,7 +272,7 @@ cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
|
||||||
|
|
||||||
|
|
||||||
def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool,
|
def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool,
|
||||||
actions: Optional[List[Ints1d]]=None):
|
actions: Optional[List[Ints1d]]=None, max_moves=0):
|
||||||
nF = model.get_dim("nF")
|
nF = model.get_dim("nF")
|
||||||
output = model.get_ref("output")
|
output = model.get_ref("output")
|
||||||
hidden_b = model.get_param("hidden_b")
|
hidden_b = model.get_param("hidden_b")
|
||||||
|
@ -253,6 +295,7 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
|
||||||
moves, states, None, width=beam_width, density=beam_density
|
moves, states, None, width=beam_width, density=beam_density
|
||||||
)
|
)
|
||||||
arange = ops.xp.arange(nF)
|
arange = ops.xp.arange(nF)
|
||||||
|
n_moves = 0
|
||||||
while not batch.is_done:
|
while not batch.is_done:
|
||||||
ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i")
|
ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i")
|
||||||
for i, state in enumerate(batch.get_unfinished_states()):
|
for i, state in enumerate(batch.get_unfinished_states()):
|
||||||
|
@ -281,6 +324,9 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
|
||||||
all_ids.append(ids)
|
all_ids.append(ids)
|
||||||
all_statevecs.append(statevecs)
|
all_statevecs.append(statevecs)
|
||||||
all_which.append(which)
|
all_which.append(which)
|
||||||
|
if max_moves >= 1 and n_moves >= max_moves:
|
||||||
|
break
|
||||||
|
n_moves += 1
|
||||||
|
|
||||||
def backprop_parser(d_states_d_scores):
|
def backprop_parser(d_states_d_scores):
|
||||||
ids = ops.xp.vstack(all_ids)
|
ids = ops.xp.vstack(all_ids)
|
||||||
|
|
|
@ -19,6 +19,7 @@ import numpy.random
|
||||||
import numpy
|
import numpy
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
from ..ml.tb_framework import TransitionModelInputs
|
||||||
from ._parser_internals.stateclass cimport StateC, StateClass
|
from ._parser_internals.stateclass cimport StateC, StateClass
|
||||||
from ._parser_internals.search cimport Beam
|
from ._parser_internals.search cimport Beam
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
|
@ -26,7 +27,7 @@ from .trainable_pipe cimport TrainablePipe
|
||||||
from ._parser_internals cimport _beam_utils
|
from ._parser_internals cimport _beam_utils
|
||||||
from ._parser_internals import _beam_utils
|
from ._parser_internals import _beam_utils
|
||||||
from ..vocab cimport Vocab
|
from ..vocab cimport Vocab
|
||||||
from ._parser_internals.transition_system cimport TransitionSystem
|
from ._parser_internals.transition_system cimport Transition, TransitionSystem
|
||||||
from ..typedefs cimport weight_t
|
from ..typedefs cimport weight_t
|
||||||
|
|
||||||
from ..training import validate_examples, validate_get_examples
|
from ..training import validate_examples, validate_get_examples
|
||||||
|
@ -254,20 +255,23 @@ class Parser(TrainablePipe):
|
||||||
result = self.moves.init_batch(docs)
|
result = self.moves.init_batch(docs)
|
||||||
return result
|
return result
|
||||||
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
||||||
states_or_beams, _ = self.model.predict((docs, self.moves))
|
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
|
||||||
|
states_or_beams, _ = self.model.predict(inputs)
|
||||||
return states_or_beams
|
return states_or_beams
|
||||||
|
|
||||||
def greedy_parse(self, docs, drop=0.):
|
def greedy_parse(self, docs, drop=0.):
|
||||||
self._resize()
|
self._resize()
|
||||||
self._ensure_labels_are_added(docs)
|
self._ensure_labels_are_added(docs)
|
||||||
with _change_attrs(self.model, beam_width=1):
|
with _change_attrs(self.model, beam_width=1):
|
||||||
states, _ = self.model.predict((docs, self.moves))
|
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
|
||||||
|
states, _ = self.model.predict(inputs)
|
||||||
return states
|
return states
|
||||||
|
|
||||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||||
self._ensure_labels_are_added(docs)
|
self._ensure_labels_are_added(docs)
|
||||||
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
||||||
beams, _ = self.model.predict((docs, self.moves))
|
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
|
||||||
|
beams, _ = self.model.predict(inputs)
|
||||||
return beams
|
return beams
|
||||||
|
|
||||||
def set_annotations(self, docs, states_or_beams):
|
def set_annotations(self, docs, states_or_beams):
|
||||||
|
@ -298,11 +302,27 @@ class Parser(TrainablePipe):
|
||||||
return losses
|
return losses
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
docs = [eg.x for eg in examples if len(eg.x)]
|
docs = [eg.x for eg in examples if len(eg.x)]
|
||||||
(states, scores), backprop_scores = self.model.begin_update((docs, self.moves))
|
|
||||||
|
max_moves = self.cfg["update_with_oracle_cut_size"]
|
||||||
|
if max_moves >= 1:
|
||||||
|
# Chop sequences into lengths of this many words, to make the
|
||||||
|
# batch uniform length.
|
||||||
|
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||||
|
init_states, gold_states, _ = self._init_gold_batch(
|
||||||
|
examples,
|
||||||
|
max_length=max_moves
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
init_states, gold_states, _ = self.moves.init_gold_batch(examples)
|
||||||
|
|
||||||
|
inputs = TransitionModelInputs(docs=docs, moves=self.moves,
|
||||||
|
max_moves=max_moves, states=[state.copy() for state in init_states])
|
||||||
|
(pred_states, scores), backprop_scores = self.model.begin_update(inputs)
|
||||||
if sum(s.shape[0] for s in scores) == 0:
|
if sum(s.shape[0] for s in scores) == 0:
|
||||||
return losses
|
return losses
|
||||||
d_scores = self.get_loss((states, scores), examples)
|
d_scores = self.get_loss((gold_states, init_states, pred_states, scores),
|
||||||
backprop_scores((states, d_scores))
|
examples, max_moves)
|
||||||
|
backprop_scores((pred_states, d_scores))
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
losses[self.name] += float((d_scores**2).sum())
|
losses[self.name] += float((d_scores**2).sum())
|
||||||
|
@ -312,12 +332,15 @@ class Parser(TrainablePipe):
|
||||||
del backprop_scores
|
del backprop_scores
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def get_loss(self, states_scores, examples):
|
def get_loss(self, states_scores, examples, max_moves):
|
||||||
states, scores = states_scores
|
gold_states, init_states, pred_states, scores = states_scores
|
||||||
scores = self.model.ops.xp.vstack(scores)
|
scores = self.model.ops.xp.vstack(scores)
|
||||||
costs = self._get_costs_from_histories(
|
costs = self._get_costs_from_histories(
|
||||||
examples,
|
examples,
|
||||||
[list(state.history) for state in states]
|
gold_states,
|
||||||
|
init_states,
|
||||||
|
[list(state.history) for state in pred_states],
|
||||||
|
max_moves
|
||||||
)
|
)
|
||||||
xp = get_array_module(scores)
|
xp = get_array_module(scores)
|
||||||
best_costs = costs.min(axis=1, keepdims=True)
|
best_costs = costs.min(axis=1, keepdims=True)
|
||||||
|
@ -335,7 +358,7 @@ class Parser(TrainablePipe):
|
||||||
d_scores -= (costs <= best_costs) * (exp_gscores / gZ)
|
d_scores -= (costs <= best_costs) * (exp_gscores / gZ)
|
||||||
return d_scores
|
return d_scores
|
||||||
|
|
||||||
def _get_costs_from_histories(self, examples, histories):
|
def _get_costs_from_histories(self, examples, gold_states, init_states, histories, max_moves):
|
||||||
cdef TransitionSystem moves = self.moves
|
cdef TransitionSystem moves = self.moves
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef int clas
|
cdef int clas
|
||||||
|
@ -345,16 +368,12 @@ class Parser(TrainablePipe):
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef np.ndarray costs_i
|
cdef np.ndarray costs_i
|
||||||
is_valid = <int*>mem.alloc(nO, sizeof(int))
|
is_valid = <int*>mem.alloc(nO, sizeof(int))
|
||||||
states = moves.init_batch([eg.x for eg in examples])
|
batch = list(zip(init_states, histories, gold_states))
|
||||||
batch = []
|
n_moves = 0
|
||||||
for eg, s, h in zip(examples, states, histories):
|
|
||||||
if not s.is_final():
|
|
||||||
gold = moves.init_gold(s, eg)
|
|
||||||
batch.append((eg, s, h, gold))
|
|
||||||
output = []
|
output = []
|
||||||
while batch:
|
while batch:
|
||||||
costs = numpy.zeros((len(batch), nO), dtype="f")
|
costs = numpy.zeros((len(batch), nO), dtype="f")
|
||||||
for i, (eg, state, history, gold) in enumerate(batch):
|
for i, (state, history, gold) in enumerate(batch):
|
||||||
costs_i = costs[i]
|
costs_i = costs[i]
|
||||||
clas = history.pop(0)
|
clas = history.pop(0)
|
||||||
moves.set_costs(is_valid, <weight_t*>costs_i.data, state.c, gold)
|
moves.set_costs(is_valid, <weight_t*>costs_i.data, state.c, gold)
|
||||||
|
@ -362,7 +381,11 @@ class Parser(TrainablePipe):
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
state.c.history.push_back(clas)
|
state.c.history.push_back(clas)
|
||||||
output.append(costs)
|
output.append(costs)
|
||||||
batch = [(eg, s, h, g) for eg, s, h, g in batch if len(h) != 0]
|
batch = [(s, h, g) for s, h, g in batch if len(h) != 0]
|
||||||
|
if max_moves >= 1 and n_moves >= max_moves:
|
||||||
|
break
|
||||||
|
n_moves += 1
|
||||||
|
|
||||||
return self.model.ops.xp.vstack(output)
|
return self.model.ops.xp.vstack(output)
|
||||||
|
|
||||||
def rehearse(self, examples, sgd=None, losses=None, **cfg):
|
def rehearse(self, examples, sgd=None, losses=None, **cfg):
|
||||||
|
@ -384,9 +407,11 @@ class Parser(TrainablePipe):
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||||
set_dropout_rate(self._rehearsal_model, 0.0)
|
set_dropout_rate(self._rehearsal_model, 0.0)
|
||||||
set_dropout_rate(self.model, 0.0)
|
set_dropout_rate(self.model, 0.0)
|
||||||
(student_states, student_scores), backprop_scores = self.model.begin_update((docs, self.moves))
|
student_inputs = TransitionModelInputs(docs=docs, moves=self.moves)
|
||||||
|
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
|
||||||
actions = states2actions(student_states)
|
actions = states2actions(student_states)
|
||||||
_, teacher_scores = self._rehearsal_model.predict((docs, self.moves, actions))
|
teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions)
|
||||||
|
_, teacher_scores = self._rehearsal_model.predict(teacher_inputs)
|
||||||
|
|
||||||
teacher_scores = self.model.ops.xp.vstack(teacher_scores)
|
teacher_scores = self.model.ops.xp.vstack(teacher_scores)
|
||||||
student_scores = self.model.ops.xp.vstack(student_scores)
|
student_scores = self.model.ops.xp.vstack(student_scores)
|
||||||
|
@ -502,6 +527,49 @@ class Parser(TrainablePipe):
|
||||||
raise ValueError(Errors.E149) from None
|
raise ValueError(Errors.E149) from None
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def _init_gold_batch(self, examples, max_length):
|
||||||
|
"""Make a square batch, of length equal to the shortest transition
|
||||||
|
sequence or a cap. A long
|
||||||
|
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||||
|
where N is the shortest doc. We'll make two states, one representing
|
||||||
|
long_doc[:N], and another representing long_doc[N:]."""
|
||||||
|
cdef:
|
||||||
|
StateClass start_state
|
||||||
|
StateClass state
|
||||||
|
Transition action
|
||||||
|
TransitionSystem moves = self.moves
|
||||||
|
all_states = moves.init_batch([eg.predicted for eg in examples])
|
||||||
|
states = []
|
||||||
|
golds = []
|
||||||
|
to_cut = []
|
||||||
|
for state, eg in zip(all_states, examples):
|
||||||
|
if moves.has_gold(eg) and not state.is_final():
|
||||||
|
gold = moves.init_gold(state, eg)
|
||||||
|
if len(eg.x) < max_length:
|
||||||
|
states.append(state)
|
||||||
|
golds.append(gold)
|
||||||
|
else:
|
||||||
|
oracle_actions = moves.get_oracle_sequence_from_state(
|
||||||
|
state.copy(), gold)
|
||||||
|
to_cut.append((eg, state, gold, oracle_actions))
|
||||||
|
if not to_cut:
|
||||||
|
return states, golds, 0
|
||||||
|
cdef int clas
|
||||||
|
for eg, state, gold, oracle_actions in to_cut:
|
||||||
|
for i in range(0, len(oracle_actions), max_length):
|
||||||
|
start_state = state.copy()
|
||||||
|
for clas in oracle_actions[i:i+max_length]:
|
||||||
|
action = moves.c[clas]
|
||||||
|
action.do(state.c, action.label)
|
||||||
|
if state.is_final():
|
||||||
|
break
|
||||||
|
if moves.has_gold(eg, start_state.B(0), state.B(0)):
|
||||||
|
states.append(start_state)
|
||||||
|
golds.append(gold)
|
||||||
|
if state.is_final():
|
||||||
|
break
|
||||||
|
return states, golds, max_length
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def _change_attrs(model, **kwargs):
|
def _change_attrs(model, **kwargs):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user