Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-28 02:46:35 +03:00)

Revert "Reimplement distillation with oracle cut size (#12214)"

This reverts commit e27c60a702.

Parent: 1b2d66f98e
Commit: 05803cfe76
@@ -267,11 +267,9 @@ cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
     cdef np.ndarray step_actions
 
     scores = []
-    while sizes.states >= 1 and (actions is None or len(actions) > 0):
+    while sizes.states >= 1:
         step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
         step_actions = actions[0] if actions is not None else None
-        assert step_actions is None or step_actions.size == sizes.states, \
-            f"number of step actions ({step_actions.size}) must equal number of states ({sizes.states})"
         with nogil:
             _predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
         if actions is None:
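Note: this hunk sits in the scoring loop that advances a whole batch of parser states one transition at a time. During distillation, the #12214 code let a precomputed action sequence override the model's own predictions, and the removed assert checked that each step supplied exactly one action per live state. A pure-Python sketch of that control flow, with predict_scores and apply_step as hypothetical stand-ins for the Cython internals:

    def parse_batch(states, predict_scores, apply_step, actions=None):
        # `actions` is a list of per-step int arrays, supplied only during
        # distillation; otherwise the model's argmax drives the transitions.
        scores = []
        while len(states) >= 1 and (actions is None or len(actions) > 0):
            step_scores = predict_scores(states)  # shape: (n_states, n_classes)
            step_actions = actions[0] if actions is not None else None
            if step_actions is not None:
                # The invariant enforced by the removed assert:
                # one action per still-active state.
                assert step_actions.size == len(states)
            states = apply_step(states, step_scores, step_actions)
            scores.append(step_scores)
            if actions is not None:
                actions = actions[1:]
        return scores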
@@ -43,10 +43,6 @@ from ..training import (
 from ._parser_internals import _beam_utils
 
-
-# TODO: Remove when we switch to Cython 3.
-cdef extern from "<algorithm>" namespace "std" nogil:
-    bint equal[InputIt1, InputIt2](InputIt1 first1, InputIt1 last1, InputIt2 first2) except +
 
 NUMPY_OPS = NumpyOps()
 
 
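Note: the removed extern block bound C++ std::equal so that nogil Cython code could compare two states' transition-history vectors element by element (used further down to check that one history is a prefix of another). Outside Cython the same prefix check is a one-liner; a sketch with hypothetical list-valued histories:

    def history_is_prefix(before_history, after_history):
        # True if `before_history` is an exact prefix of `after_history`,
        # mirroring the std::equal call over the begin/end iterators.
        return list(after_history[:len(before_history)]) == list(before_history)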
@@ -265,8 +261,8 @@ class Parser(TrainablePipe):
             # batch uniform length. Since we do not have a gold standard
             # sequence, we use the teacher's predictions as the gold
             # standard.
-            max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
-            states = self._init_batch_from_teacher(teacher_pipe, student_docs, max_moves)
+            max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
+            states = self._init_batch(teacher_pipe, student_docs, max_moves)
         else:
             states = self.moves.init_batch(student_docs)
 
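Note: the max(max_moves // 2, 1) lower bound that this revert drops guards the degenerate case max_moves == 1: random.uniform(0, 2) can land below 1, and int() then truncates the cut size to 0, which disables cutting for that batch. A quick illustration:

    import random

    max_moves = 1
    # With the clamp (reverted-from code): uniform(1, 2) always truncates to 1.
    cut = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
    assert cut >= 1
    # Without the clamp (restored code): uniform(0, 2) may truncate to 0.
    cut = int(random.uniform(max_moves // 2, max_moves * 2))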
@@ -277,14 +273,12 @@ class Parser(TrainablePipe):
         # gradients of the student's transition distributions relative to the
         # teacher's distributions.
 
-        student_inputs = TransitionModelInputs(docs=student_docs,
-            states=[state.copy() for state in states],
-            moves=self.moves,
+        student_inputs = TransitionModelInputs(docs=student_docs, moves=self.moves,
             max_moves=max_moves)
         (student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
-        actions = _states_diff_to_actions(states, student_states)
+        actions = states2actions(student_states)
         teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples],
-            states=states, moves=teacher_pipe.moves, actions=actions)
+            moves=self.moves, actions=actions)
         (_, teacher_scores) = teacher_pipe.model.predict(teacher_inputs)
 
         loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
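Note: the surrounding method follows the standard teacher-student pattern: run the student with begin_update, replay the same transition sequence through the teacher with predict, and backpropagate a divergence between the two score distributions. The actual loss lives in get_teacher_student_loss; a schematic sketch under the assumption of a softmax cross-entropy:

    import numpy

    def softmax(x):
        e = numpy.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    def teacher_student_loss(teacher_scores, student_scores):
        # Cross-entropy of the student's distribution against the teacher's;
        # d_scores is the gradient w.r.t. the student logits under this loss.
        t = softmax(teacher_scores)
        s = softmax(student_scores)
        loss = -(t * numpy.log(s + 1e-12)).sum(axis=-1).mean()
        d_scores = (s - t) / teacher_scores.shape[0]
        return loss, d_scores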
@@ -532,7 +526,7 @@ class Parser(TrainablePipe):
         set_dropout_rate(self.model, 0.0)
         student_inputs = TransitionModelInputs(docs=docs, moves=self.moves)
         (student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
-        actions = _states_to_actions(student_states)
+        actions = states2actions(student_states)
         teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions)
         _, teacher_scores = self._rehearsal_model.predict(teacher_inputs)
 
@@ -652,7 +646,7 @@ class Parser(TrainablePipe):
             raise ValueError(Errors.E149) from None
         return self
 
-    def _init_batch_from_teacher(self, teacher_pipe, docs, max_length):
+    def _init_batch(self, teacher_step_model, docs, max_length):
         """Make a square batch of length equal to the shortest transition
         sequence or a cap. A long
         doc will get multiple states. Let's say we have a doc of length 2*N,
@@ -661,12 +655,10 @@ class Parser(TrainablePipe):
         _init_gold_batch, this version uses a teacher model to generate the
         cut sequences."""
         cdef:
+            StateClass start_state
             StateClass state
-            TransitionSystem moves = teacher_pipe.moves
-
-        # Start with the same heuristic as in supervised training: exclude
-        # docs that are within the maximum length.
-        all_states = moves.init_batch(docs)
+            Transition action
+        all_states = self.moves.init_batch(docs)
         states = []
         to_cut = []
         for state, doc in zip(all_states, docs):
@@ -675,30 +667,19 @@ class Parser(TrainablePipe):
                 states.append(state)
             else:
                 to_cut.append(state)
-
-        if not to_cut:
-            return states
-
-        # Parse the states that are too long with the teacher's parsing model.
-        teacher_inputs = TransitionModelInputs(docs=docs,
-            moves=moves,
-            states=[state.copy() for state in to_cut])
-        (teacher_states, _) = teacher_pipe.model.predict(teacher_inputs)
-
-        # Step through the teacher's actions and store every state after
-        # each multiple of max_length.
-        teacher_actions = _states_to_actions(teacher_states)
         while to_cut:
             states.extend(state.copy() for state in to_cut)
-            for step_actions in teacher_actions[:max_length]:
-                to_cut = moves.apply_actions(to_cut, step_actions)
-            teacher_actions = teacher_actions[max_length:]
-
-            if len(teacher_actions) < max_length:
-                break
+            # Move states forward max_length actions.
+            length = 0
+            while to_cut and length < max_length:
+                teacher_scores = teacher_step_model.predict(to_cut)
+                self.transition_states(to_cut, teacher_scores)
+                # States that are completed do not need further cutting.
+                to_cut = [state for state in to_cut if not state.is_final()]
+                length += 1
         return states
 
 
     def _init_gold_batch(self, examples, max_length):
         """Make a square batch, of length equal to the shortest transition
         sequence or a cap. A long doc will get multiple states. Let's say we
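Note: this hunk is the core of the revert. The removed (#12214) version parses the over-long states once with the teacher's model, converts the resulting histories into per-step action arrays, and fast-forwards the student states in max_length-sized slices via apply_actions; the restored version interleaves a teacher predict with a greedy transition at every single step. A pure-Python sketch of the two strategies (helper names follow the diff; treat this as pseudocode over the real Cython API):

    def cut_with_actions(to_cut, teacher_actions, max_length, moves):
        # Removed strategy: replay precomputed teacher actions slice by slice.
        states = []
        while to_cut:
            states.extend(state.copy() for state in to_cut)
            for step_actions in teacher_actions[:max_length]:
                to_cut = moves.apply_actions(to_cut, step_actions)
            teacher_actions = teacher_actions[max_length:]
            if len(teacher_actions) < max_length:
                break
        return states

    def cut_stepwise(to_cut, teacher_step_model, transition_states, max_length):
        # Restored strategy: query the teacher for scores at every step,
        # transition greedily, and drop finished states as we go.
        states = []
        while to_cut:
            states.extend(state.copy() for state in to_cut)
            length = 0
            while to_cut and length < max_length:
                scores = teacher_step_model.predict(to_cut)
                transition_states(to_cut, scores)
                to_cut = [state for state in to_cut if not state.is_final()]
                length += 1
        return states

The removed variant needs only one teacher forward pass per batch, which was the motivation for #12214; the restored one is simpler but calls the teacher once per transition.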
@@ -759,7 +740,7 @@ def _change_attrs(model, **kwargs):
         model.attrs[key] = value
 
 
-def _states_to_actions(states: List[StateClass]) -> List[Ints1d]:
+def states2actions(states: List[StateClass]) -> List[Ints1d]:
     cdef int step
     cdef StateClass state
     cdef StateC* c_state
@@ -780,47 +761,3 @@ def _states_to_actions(states: List[StateClass]) -> List[Ints1d]:
         actions.append(numpy.array(step_actions, dtype="i"))
 
     return actions
-
-
-def _states_diff_to_actions(
-    before_states: List[StateClass],
-    after_states: List[StateClass]
-) -> List[Ints1d]:
-    """
-    Return for two sets of states the actions to go from the first set of
-    states to the second set of states. The histories of the first set of
-    states must be a prefix of the second set of states.
-    """
-    cdef StateClass before_state, after_state
-    cdef StateC* c_state_before
-    cdef StateC* c_state_after
-
-    assert len(before_states) == len(after_states)
-
-    # Check invariant: before states histories must be prefixes of after states.
-    for before_state, after_state in zip(before_states, after_states):
-        c_state_before = before_state.c
-        c_state_after = after_state.c
-
-        assert equal(c_state_before.history.begin(),
-                     c_state_before.history.end(),
-                     c_state_after.history.begin())
-
-    actions = []
-    while True:
-        step = len(actions)
-
-        step_actions = []
-        for before_state, after_state in zip(before_states, after_states):
-            c_state_before = before_state.c
-            c_state_after = after_state.c
-            if step < c_state_after.history.size() - c_state_before.history.size():
-                step_actions.append(c_state_after.history[c_state_before.history.size() + step])
-
-        # We are done if we have exhausted all histories.
-        if len(step_actions) == 0:
-            break
-
-        actions.append(numpy.array(step_actions, dtype="i"))
-
-    return actions
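Note: the deleted _states_diff_to_actions transposed per-state history suffixes into per-step action arrays, i.e. the actions needed to get from the `before` states to the `after` states. The same transposition over plain Python lists (standing in for the C++ history vectors):

    import numpy

    def states_diff_to_actions(before_histories, after_histories):
        assert len(before_histories) == len(after_histories)
        # Invariant from the deleted code: each before-history must be a
        # prefix of the corresponding after-history.
        for before, after in zip(before_histories, after_histories):
            assert after[:len(before)] == before

        actions = []
        while True:
            step = len(actions)
            # One action per state that still has unconsumed history.
            step_actions = [
                after[len(before) + step]
                for before, after in zip(before_histories, after_histories)
                if step < len(after) - len(before)
            ]
            if not step_actions:
                break
            actions.append(numpy.array(step_actions, dtype="i"))
        return actions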
@@ -1,61 +0,0 @@
-import numpy
-import pytest
-
-from spacy.lang.en import English
-from spacy.ml.tb_framework import TransitionModelInputs
-from spacy.training import Example
-
-TRAIN_DATA = [
-    (
-        "They trade mortgage-backed securities.",
-        {
-            "heads": [1, 1, 4, 4, 5, 1, 1],
-            "deps": ["nsubj", "ROOT", "compound", "punct", "nmod", "dobj", "punct"],
-        },
-    ),
-    (
-        "I like London and Berlin.",
-        {
-            "heads": [1, 1, 1, 2, 2, 1],
-            "deps": ["nsubj", "ROOT", "dobj", "cc", "conj", "punct"],
-        },
-    ),
-]
-
-
-@pytest.fixture
-def nlp_parser():
-    nlp = English()
-    parser = nlp.add_pipe("parser")
-
-    train_examples = []
-    for text, annotations in TRAIN_DATA:
-        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
-        for dep in annotations["deps"]:
-            parser.add_label(dep)
-    nlp.initialize()
-
-    return nlp, parser
-
-
-def test_incorrect_number_of_actions(nlp_parser):
-    nlp, parser = nlp_parser
-    doc = nlp.make_doc("test")
-
-    # Too many actions for the number of docs
-    with pytest.raises(AssertionError):
-        parser.model.predict(
-            TransitionModelInputs(
-                docs=[doc], moves=parser.moves, actions=[numpy.array([0, 0], dtype="i")]
-            )
-        )
-
-    # Too few actions for the number of docs
-    with pytest.raises(AssertionError):
-        parser.model.predict(
-            TransitionModelInputs(
-                docs=[doc, doc],
-                moves=parser.moves,
-                actions=[numpy.array([0], dtype="i")],
-            )
-        )
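Note: the deleted test module exercised exactly the guard removed from _parse_batch above: TransitionModelInputs with a per-step action array whose width does not match the number of docs must fail. If that validation were wanted outside the model, a standalone check might look like this (hypothetical helper, not part of spaCy):

    import numpy

    def check_actions(docs, actions):
        # Each step's action array must supply one action per doc/state.
        for step, step_actions in enumerate(actions):
            if step_actions.size != len(docs):
                raise AssertionError(
                    f"step {step}: got {step_actions.size} actions "
                    f"for {len(docs)} docs"
                )

    check_actions(["doc"], [numpy.array([0], dtype="i")])  # passes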
@@ -623,9 +623,7 @@ def test_is_distillable():
     assert ner.is_distillable
 
 
-@pytest.mark.slow
-@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
-def test_distill(max_moves):
+def test_distill():
     teacher = English()
     teacher_ner = teacher.add_pipe("ner")
     train_examples = []
@@ -643,7 +641,6 @@ def test_distill(max_moves):
 
     student = English()
     student_ner = student.add_pipe("ner")
-    student_ner.cfg["update_with_oracle_cut_size"] = max_moves
     student_ner.initialize(
         get_examples=lambda: train_examples, labels=teacher_ner.label_data
     )
@@ -462,9 +462,7 @@ def test_is_distillable():
     assert parser.is_distillable
 
 
-@pytest.mark.slow
-@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
-def test_distill(max_moves):
+def test_distill():
     teacher = English()
     teacher_parser = teacher.add_pipe("parser")
     train_examples = []
@@ -482,7 +480,6 @@ def test_distill(max_moves):
 
     student = English()
    student_parser = student.add_pipe("parser")
-    student_parser.cfg["update_with_oracle_cut_size"] = max_moves
     student_parser.initialize(
         get_examples=lambda: train_examples, labels=teacher_parser.label_data
     )