Revert "Reimplement distillation with oracle cut size (#12214)"

This reverts commit e27c60a702.
This commit is contained in:
Daniël de Kok 2023-12-08 14:38:05 +01:00
parent 1b2d66f98e
commit 05803cfe76
5 changed files with 24 additions and 156 deletions

View File

@ -267,11 +267,9 @@ cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
cdef np.ndarray step_actions
scores = []
while sizes.states >= 1 and (actions is None or len(actions) > 0):
while sizes.states >= 1:
step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
step_actions = actions[0] if actions is not None else None
assert step_actions is None or step_actions.size == sizes.states, \
f"number of step actions ({step_actions.size}) must equal number of states ({sizes.states})"
with nogil:
_predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
if actions is None:

View File

@ -43,10 +43,6 @@ from ..training import (
from ._parser_internals import _beam_utils
# TODO: Remove when we switch to Cython 3.
cdef extern from "<algorithm>" namespace "std" nogil:
bint equal[InputIt1, InputIt2](InputIt1 first1, InputIt1 last1, InputIt2 first2) except +
NUMPY_OPS = NumpyOps()
@ -265,8 +261,8 @@ class Parser(TrainablePipe):
# batch uniform length. Since we do not have a gold standard
# sequence, we use the teacher's predictions as the gold
# standard.
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
states = self._init_batch_from_teacher(teacher_pipe, student_docs, max_moves)
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
states = self._init_batch(teacher_pipe, student_docs, max_moves)
else:
states = self.moves.init_batch(student_docs)
@ -277,14 +273,12 @@ class Parser(TrainablePipe):
# gradients of the student's transition distributions relative to the
# teacher's distributions.
student_inputs = TransitionModelInputs(docs=student_docs,
states=[state.copy() for state in states],
moves=self.moves,
student_inputs = TransitionModelInputs(docs=student_docs, moves=self.moves,
max_moves=max_moves)
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
actions = _states_diff_to_actions(states, student_states)
actions = states2actions(student_states)
teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples],
states=states, moves=teacher_pipe.moves, actions=actions)
moves=self.moves, actions=actions)
(_, teacher_scores) = teacher_pipe.model.predict(teacher_inputs)
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
@ -532,7 +526,7 @@ class Parser(TrainablePipe):
set_dropout_rate(self.model, 0.0)
student_inputs = TransitionModelInputs(docs=docs, moves=self.moves)
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
actions = _states_to_actions(student_states)
actions = states2actions(student_states)
teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions)
_, teacher_scores = self._rehearsal_model.predict(teacher_inputs)
@ -652,7 +646,7 @@ class Parser(TrainablePipe):
raise ValueError(Errors.E149) from None
return self
def _init_batch_from_teacher(self, teacher_pipe, docs, max_length):
def _init_batch(self, teacher_step_model, docs, max_length):
"""Make a square batch of length equal to the shortest transition
sequence or a cap. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
@ -661,12 +655,10 @@ class Parser(TrainablePipe):
_init_gold_batch, this version uses a teacher model to generate the
cut sequences."""
cdef:
StateClass start_state
StateClass state
TransitionSystem moves = teacher_pipe.moves
# Start with the same heuristic as in supervised training: exclude
# docs that are within the maximum length.
all_states = moves.init_batch(docs)
Transition action
all_states = self.moves.init_batch(docs)
states = []
to_cut = []
for state, doc in zip(all_states, docs):
@ -675,30 +667,19 @@ class Parser(TrainablePipe):
states.append(state)
else:
to_cut.append(state)
if not to_cut:
return states
# Parse the states that are too long with the teacher's parsing model.
teacher_inputs = TransitionModelInputs(docs=docs,
moves=moves,
states=[state.copy() for state in to_cut])
(teacher_states, _) = teacher_pipe.model.predict(teacher_inputs)
# Step through the teacher's actions and store every state after
# each multiple of max_length.
teacher_actions = _states_to_actions(teacher_states)
while to_cut:
states.extend(state.copy() for state in to_cut)
for step_actions in teacher_actions[:max_length]:
to_cut = moves.apply_actions(to_cut, step_actions)
teacher_actions = teacher_actions[max_length:]
if len(teacher_actions) < max_length:
break
# Move states forward max_length actions.
length = 0
while to_cut and length < max_length:
teacher_scores = teacher_step_model.predict(to_cut)
self.transition_states(to_cut, teacher_scores)
# States that are completed do not need further cutting.
to_cut = [state for state in to_cut if not state.is_final()]
length += 1
return states
def _init_gold_batch(self, examples, max_length):
"""Make a square batch, of length equal to the shortest transition
sequence or a cap. A long doc will get multiple states. Let's say we
@ -759,7 +740,7 @@ def _change_attrs(model, **kwargs):
model.attrs[key] = value
def _states_to_actions(states: List[StateClass]) -> List[Ints1d]:
def states2actions(states: List[StateClass]) -> List[Ints1d]:
cdef int step
cdef StateClass state
cdef StateC* c_state
@ -780,47 +761,3 @@ def _states_to_actions(states: List[StateClass]) -> List[Ints1d]:
actions.append(numpy.array(step_actions, dtype="i"))
return actions
def _states_diff_to_actions(
before_states: List[StateClass],
after_states: List[StateClass]
) -> List[Ints1d]:
"""
Return for two sets of states the actions to go from the first set of
states to the second set of states. The histories of the first set of
states must be a prefix of the second set of states.
"""
cdef StateClass before_state, after_state
cdef StateC* c_state_before
cdef StateC* c_state_after
assert len(before_states) == len(after_states)
# Check invariant: before states histories must be prefixes of after states.
for before_state, after_state in zip(before_states, after_states):
c_state_before = before_state.c
c_state_after = after_state.c
assert equal(c_state_before.history.begin(),
c_state_before.history.end(),
c_state_after.history.begin())
actions = []
while True:
step = len(actions)
step_actions = []
for before_state, after_state in zip(before_states, after_states):
c_state_before = before_state.c
c_state_after = after_state.c
if step < c_state_after.history.size() - c_state_before.history.size():
step_actions.append(c_state_after.history[c_state_before.history.size() + step])
# We are done if we have exhausted all histories.
if len(step_actions) == 0:
break
actions.append(numpy.array(step_actions, dtype="i"))
return actions

View File

@ -1,61 +0,0 @@
import numpy
import pytest
from spacy.lang.en import English
from spacy.ml.tb_framework import TransitionModelInputs
from spacy.training import Example
TRAIN_DATA = [
(
"They trade mortgage-backed securities.",
{
"heads": [1, 1, 4, 4, 5, 1, 1],
"deps": ["nsubj", "ROOT", "compound", "punct", "nmod", "dobj", "punct"],
},
),
(
"I like London and Berlin.",
{
"heads": [1, 1, 1, 2, 2, 1],
"deps": ["nsubj", "ROOT", "dobj", "cc", "conj", "punct"],
},
),
]
@pytest.fixture
def nlp_parser():
nlp = English()
parser = nlp.add_pipe("parser")
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
for dep in annotations["deps"]:
parser.add_label(dep)
nlp.initialize()
return nlp, parser
def test_incorrect_number_of_actions(nlp_parser):
nlp, parser = nlp_parser
doc = nlp.make_doc("test")
# Too many actions for the number of docs
with pytest.raises(AssertionError):
parser.model.predict(
TransitionModelInputs(
docs=[doc], moves=parser.moves, actions=[numpy.array([0, 0], dtype="i")]
)
)
# Too few actions for the number of docs
with pytest.raises(AssertionError):
parser.model.predict(
TransitionModelInputs(
docs=[doc, doc],
moves=parser.moves,
actions=[numpy.array([0], dtype="i")],
)
)

View File

@ -623,9 +623,7 @@ def test_is_distillable():
assert ner.is_distillable
@pytest.mark.slow
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
def test_distill(max_moves):
def test_distill():
teacher = English()
teacher_ner = teacher.add_pipe("ner")
train_examples = []
@ -643,7 +641,6 @@ def test_distill(max_moves):
student = English()
student_ner = student.add_pipe("ner")
student_ner.cfg["update_with_oracle_cut_size"] = max_moves
student_ner.initialize(
get_examples=lambda: train_examples, labels=teacher_ner.label_data
)

View File

@ -462,9 +462,7 @@ def test_is_distillable():
assert parser.is_distillable
@pytest.mark.slow
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
def test_distill(max_moves):
def test_distill():
teacher = English()
teacher_parser = teacher.add_pipe("parser")
train_examples = []
@ -482,7 +480,6 @@ def test_distill(max_moves):
student = English()
student_parser = student.add_pipe("parser")
student_parser.cfg["update_with_oracle_cut_size"] = max_moves
student_parser.initialize(
get_examples=lambda: train_examples, labels=teacher_parser.label_data
)