mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
Reimplement distillation with oracle cut size (#12214)
* Improve the correctness of _parse_patch * If there are no more actions, do not attempt to make further transitions, even if not all states are final. * Assert that the number of actions for a step is the same as the number of states. * Reimplement distillation with oracle cut size The code for distillation with an oracle cut size was not reimplemented after the parser refactor. We did not notice, because we did not have tests for this functionality. This change brings back the functionality and adds this to the parser tests. * Rename states2actions to _states_to_actions for consistency * Test distillation max cuts in NER * Mark parser/NER tests as slow * Typo * Fix invariant in _states_diff_to_actions * Rename _init_batch -> _init_batch_from_teacher * Ninja edit the ninja edit * Check that we raise an exception when we pass the incorrect number or actions * Remove unnecessary get Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Write out condition more explicitly --------- Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
parent
dd3f138830
commit
e27c60a702
|
@ -249,9 +249,11 @@ cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
|
|||
cdef np.ndarray step_actions
|
||||
|
||||
scores = []
|
||||
while sizes.states >= 1:
|
||||
while sizes.states >= 1 and (actions is None or len(actions) > 0):
|
||||
step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
|
||||
step_actions = actions[0] if actions is not None else None
|
||||
assert step_actions is None or step_actions.size == sizes.states, \
|
||||
f"number of step actions ({step_actions.size}) must equal number of states ({sizes.states})"
|
||||
with nogil:
|
||||
_predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
|
||||
if actions is None:
|
||||
|
|
|
@ -36,6 +36,11 @@ from ..errors import Errors, Warnings
|
|||
from .. import util
|
||||
|
||||
|
||||
# TODO: Remove when we switch to Cython 3.
|
||||
cdef extern from "<algorithm>" namespace "std" nogil:
|
||||
bint equal[InputIt1, InputIt2](InputIt1 first1, InputIt1 last1, InputIt2 first2) except +
|
||||
|
||||
|
||||
NUMPY_OPS = NumpyOps()
|
||||
|
||||
|
||||
|
@ -253,8 +258,8 @@ class Parser(TrainablePipe):
|
|||
# batch uniform length. Since we do not have a gold standard
|
||||
# sequence, we use the teacher's predictions as the gold
|
||||
# standard.
|
||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||
states = self._init_batch(teacher_pipe, student_docs, max_moves)
|
||||
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
|
||||
states = self._init_batch_from_teacher(teacher_pipe, student_docs, max_moves)
|
||||
else:
|
||||
states = self.moves.init_batch(student_docs)
|
||||
|
||||
|
@ -265,12 +270,12 @@ class Parser(TrainablePipe):
|
|||
# gradients of the student's transition distributions relative to the
|
||||
# teacher's distributions.
|
||||
|
||||
student_inputs = TransitionModelInputs(docs=student_docs, moves=self.moves,
|
||||
max_moves=max_moves)
|
||||
student_inputs = TransitionModelInputs(docs=student_docs,
|
||||
states=[state.copy() for state in states], moves=self.moves, max_moves=max_moves)
|
||||
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
|
||||
actions = states2actions(student_states)
|
||||
actions = _states_diff_to_actions(states, student_states)
|
||||
teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples],
|
||||
moves=self.moves, actions=actions)
|
||||
states=states, moves=teacher_pipe.moves, actions=actions)
|
||||
(_, teacher_scores) = teacher_pipe.model.predict(teacher_inputs)
|
||||
|
||||
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
|
||||
|
@ -522,7 +527,7 @@ class Parser(TrainablePipe):
|
|||
set_dropout_rate(self.model, 0.0)
|
||||
student_inputs = TransitionModelInputs(docs=docs, moves=self.moves)
|
||||
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
|
||||
actions = states2actions(student_states)
|
||||
actions = _states_to_actions(student_states)
|
||||
teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions)
|
||||
_, teacher_scores = self._rehearsal_model.predict(teacher_inputs)
|
||||
|
||||
|
@ -642,7 +647,7 @@ class Parser(TrainablePipe):
|
|||
raise ValueError(Errors.E149) from None
|
||||
return self
|
||||
|
||||
def _init_batch(self, teacher_step_model, docs, max_length):
|
||||
def _init_batch_from_teacher(self, teacher_pipe, docs, max_length):
|
||||
"""Make a square batch of length equal to the shortest transition
|
||||
sequence or a cap. A long
|
||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||
|
@ -651,10 +656,12 @@ class Parser(TrainablePipe):
|
|||
_init_gold_batch, this version uses a teacher model to generate the
|
||||
cut sequences."""
|
||||
cdef:
|
||||
StateClass start_state
|
||||
StateClass state
|
||||
Transition action
|
||||
all_states = self.moves.init_batch(docs)
|
||||
TransitionSystem moves = teacher_pipe.moves
|
||||
|
||||
# Start with the same heuristic as in supervised training: exclude
|
||||
# docs that are within the maximum length.
|
||||
all_states = moves.init_batch(docs)
|
||||
states = []
|
||||
to_cut = []
|
||||
for state, doc in zip(all_states, docs):
|
||||
|
@ -663,18 +670,28 @@ class Parser(TrainablePipe):
|
|||
states.append(state)
|
||||
else:
|
||||
to_cut.append(state)
|
||||
|
||||
if not to_cut:
|
||||
return states
|
||||
|
||||
# Parse the states that are too long with the teacher's parsing model.
|
||||
teacher_inputs = TransitionModelInputs(docs=docs, moves=moves,
|
||||
states=[state.copy() for state in to_cut])
|
||||
(teacher_states, _ ) = teacher_pipe.model.predict(teacher_inputs)
|
||||
|
||||
# Step through the teacher's actions and store every state after
|
||||
# each multiple of max_length.
|
||||
teacher_actions = _states_to_actions(teacher_states)
|
||||
while to_cut:
|
||||
states.extend(state.copy() for state in to_cut)
|
||||
# Move states forward max_length actions.
|
||||
length = 0
|
||||
while to_cut and length < max_length:
|
||||
teacher_scores = teacher_step_model.predict(to_cut)
|
||||
self.transition_states(to_cut, teacher_scores)
|
||||
# States that are completed do not need further cutting.
|
||||
to_cut = [state for state in to_cut if not state.is_final()]
|
||||
length += 1
|
||||
return states
|
||||
for step_actions in teacher_actions[:max_length]:
|
||||
to_cut = moves.apply_actions(to_cut, step_actions)
|
||||
teacher_actions = teacher_actions[max_length:]
|
||||
|
||||
if len(teacher_actions) < max_length:
|
||||
break
|
||||
|
||||
return states
|
||||
|
||||
def _init_gold_batch(self, examples, max_length):
|
||||
"""Make a square batch, of length equal to the shortest transition
|
||||
|
@ -736,7 +753,7 @@ def _change_attrs(model, **kwargs):
|
|||
model.attrs[key] = value
|
||||
|
||||
|
||||
def states2actions(states: List[StateClass]) -> List[Ints1d]:
|
||||
def _states_to_actions(states: List[StateClass]) -> List[Ints1d]:
|
||||
cdef int step
|
||||
cdef StateClass state
|
||||
cdef StateC* c_state
|
||||
|
@ -757,3 +774,45 @@ def states2actions(states: List[StateClass]) -> List[Ints1d]:
|
|||
actions.append(numpy.array(step_actions, dtype="i"))
|
||||
|
||||
return actions
|
||||
|
||||
def _states_diff_to_actions(
|
||||
before_states: List[StateClass],
|
||||
after_states: List[StateClass]
|
||||
) -> List[Ints1d]:
|
||||
"""
|
||||
Return for two sets of states the actions to go from the first set of
|
||||
states to the second set of states. The histories of the first set of
|
||||
states must be a prefix of the second set of states.
|
||||
"""
|
||||
cdef StateClass before_state, after_state
|
||||
cdef StateC* c_state_before
|
||||
cdef StateC* c_state_after
|
||||
|
||||
assert len(before_states) == len(after_states)
|
||||
|
||||
# Check invariant: before states histories must be prefixes of after states.
|
||||
for before_state, after_state in zip(before_states, after_states):
|
||||
c_state_before = before_state.c
|
||||
c_state_after = after_state.c
|
||||
|
||||
assert equal(c_state_before.history.begin(), c_state_before.history.end(),
|
||||
c_state_after.history.begin())
|
||||
|
||||
actions = []
|
||||
while True:
|
||||
step = len(actions)
|
||||
|
||||
step_actions = []
|
||||
for before_state, after_state in zip(before_states, after_states):
|
||||
c_state_before = before_state.c
|
||||
c_state_after = after_state.c
|
||||
if step < c_state_after.history.size() - c_state_before.history.size():
|
||||
step_actions.append(c_state_after.history[c_state_before.history.size() + step])
|
||||
|
||||
# We are done if we have exhausted all histories.
|
||||
if len(step_actions) == 0:
|
||||
break
|
||||
|
||||
actions.append(numpy.array(step_actions, dtype="i"))
|
||||
|
||||
return actions
|
||||
|
|
61
spacy/tests/parser/test_model.py
Normal file
61
spacy/tests/parser/test_model.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
import numpy
|
||||
import pytest
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.ml.tb_framework import TransitionModelInputs
|
||||
from spacy.training import Example
|
||||
|
||||
TRAIN_DATA = [
|
||||
(
|
||||
"They trade mortgage-backed securities.",
|
||||
{
|
||||
"heads": [1, 1, 4, 4, 5, 1, 1],
|
||||
"deps": ["nsubj", "ROOT", "compound", "punct", "nmod", "dobj", "punct"],
|
||||
},
|
||||
),
|
||||
(
|
||||
"I like London and Berlin.",
|
||||
{
|
||||
"heads": [1, 1, 1, 2, 2, 1],
|
||||
"deps": ["nsubj", "ROOT", "dobj", "cc", "conj", "punct"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nlp_parser():
|
||||
nlp = English()
|
||||
parser = nlp.add_pipe("parser")
|
||||
|
||||
train_examples = []
|
||||
for text, annotations in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||
for dep in annotations["deps"]:
|
||||
parser.add_label(dep)
|
||||
nlp.initialize()
|
||||
|
||||
return nlp, parser
|
||||
|
||||
|
||||
def test_incorrect_number_of_actions(nlp_parser):
|
||||
nlp, parser = nlp_parser
|
||||
doc = nlp.make_doc("test")
|
||||
|
||||
# Too many actions for the number of docs
|
||||
with pytest.raises(AssertionError):
|
||||
parser.model.predict(
|
||||
TransitionModelInputs(
|
||||
docs=[doc], moves=parser.moves, actions=[numpy.array([0, 0], dtype="i")]
|
||||
)
|
||||
)
|
||||
|
||||
# Too few actions for the number of docs
|
||||
with pytest.raises(AssertionError):
|
||||
parser.model.predict(
|
||||
TransitionModelInputs(
|
||||
docs=[doc, doc],
|
||||
moves=parser.moves,
|
||||
actions=[numpy.array([0], dtype="i")],
|
||||
)
|
||||
)
|
|
@ -623,7 +623,9 @@ def test_is_distillable():
|
|||
assert ner.is_distillable
|
||||
|
||||
|
||||
def test_distill():
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
|
||||
def test_distill(max_moves):
|
||||
teacher = English()
|
||||
teacher_ner = teacher.add_pipe("ner")
|
||||
train_examples = []
|
||||
|
@ -641,6 +643,7 @@ def test_distill():
|
|||
|
||||
student = English()
|
||||
student_ner = student.add_pipe("ner")
|
||||
student_ner.cfg["update_with_oracle_cut_size"] = max_moves
|
||||
student_ner.initialize(
|
||||
get_examples=lambda: train_examples, labels=teacher_ner.label_data
|
||||
)
|
||||
|
|
|
@ -463,7 +463,9 @@ def test_is_distillable():
|
|||
assert parser.is_distillable
|
||||
|
||||
|
||||
def test_distill():
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
|
||||
def test_distill(max_moves):
|
||||
teacher = English()
|
||||
teacher_parser = teacher.add_pipe("parser")
|
||||
train_examples = []
|
||||
|
@ -481,6 +483,7 @@ def test_distill():
|
|||
|
||||
student = English()
|
||||
student_parser = student.add_pipe("parser")
|
||||
student_parser.cfg["update_with_oracle_cut_size"] = max_moves
|
||||
student_parser.initialize(
|
||||
get_examples=lambda: train_examples, labels=teacher_parser.label_data
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user