mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
Revert "Reimplement distillation with oracle cut size (#12214)"
This reverts commit e27c60a702
.
This commit is contained in:
parent
1b2d66f98e
commit
05803cfe76
|
@ -267,11 +267,9 @@ cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
|
|||
cdef np.ndarray step_actions
|
||||
|
||||
scores = []
|
||||
while sizes.states >= 1 and (actions is None or len(actions) > 0):
|
||||
while sizes.states >= 1:
|
||||
step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
|
||||
step_actions = actions[0] if actions is not None else None
|
||||
assert step_actions is None or step_actions.size == sizes.states, \
|
||||
f"number of step actions ({step_actions.size}) must equal number of states ({sizes.states})"
|
||||
with nogil:
|
||||
_predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
|
||||
if actions is None:
|
||||
|
|
|
@ -43,10 +43,6 @@ from ..training import (
|
|||
from ._parser_internals import _beam_utils
|
||||
|
||||
|
||||
# TODO: Remove when we switch to Cython 3.
|
||||
cdef extern from "<algorithm>" namespace "std" nogil:
|
||||
bint equal[InputIt1, InputIt2](InputIt1 first1, InputIt1 last1, InputIt2 first2) except +
|
||||
|
||||
NUMPY_OPS = NumpyOps()
|
||||
|
||||
|
||||
|
@ -265,8 +261,8 @@ class Parser(TrainablePipe):
|
|||
# batch uniform length. Since we do not have a gold standard
|
||||
# sequence, we use the teacher's predictions as the gold
|
||||
# standard.
|
||||
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
|
||||
states = self._init_batch_from_teacher(teacher_pipe, student_docs, max_moves)
|
||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||
states = self._init_batch(teacher_pipe, student_docs, max_moves)
|
||||
else:
|
||||
states = self.moves.init_batch(student_docs)
|
||||
|
||||
|
@ -277,14 +273,12 @@ class Parser(TrainablePipe):
|
|||
# gradients of the student's transition distributions relative to the
|
||||
# teacher's distributions.
|
||||
|
||||
student_inputs = TransitionModelInputs(docs=student_docs,
|
||||
states=[state.copy() for state in states],
|
||||
moves=self.moves,
|
||||
max_moves=max_moves)
|
||||
student_inputs = TransitionModelInputs(docs=student_docs, moves=self.moves,
|
||||
max_moves=max_moves)
|
||||
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
|
||||
actions = _states_diff_to_actions(states, student_states)
|
||||
actions = states2actions(student_states)
|
||||
teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples],
|
||||
states=states, moves=teacher_pipe.moves, actions=actions)
|
||||
moves=self.moves, actions=actions)
|
||||
(_, teacher_scores) = teacher_pipe.model.predict(teacher_inputs)
|
||||
|
||||
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
|
||||
|
@ -532,7 +526,7 @@ class Parser(TrainablePipe):
|
|||
set_dropout_rate(self.model, 0.0)
|
||||
student_inputs = TransitionModelInputs(docs=docs, moves=self.moves)
|
||||
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
|
||||
actions = _states_to_actions(student_states)
|
||||
actions = states2actions(student_states)
|
||||
teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions)
|
||||
_, teacher_scores = self._rehearsal_model.predict(teacher_inputs)
|
||||
|
||||
|
@ -652,7 +646,7 @@ class Parser(TrainablePipe):
|
|||
raise ValueError(Errors.E149) from None
|
||||
return self
|
||||
|
||||
def _init_batch_from_teacher(self, teacher_pipe, docs, max_length):
|
||||
def _init_batch(self, teacher_step_model, docs, max_length):
|
||||
"""Make a square batch of length equal to the shortest transition
|
||||
sequence or a cap. A long
|
||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||
|
@ -661,12 +655,10 @@ class Parser(TrainablePipe):
|
|||
_init_gold_batch, this version uses a teacher model to generate the
|
||||
cut sequences."""
|
||||
cdef:
|
||||
StateClass start_state
|
||||
StateClass state
|
||||
TransitionSystem moves = teacher_pipe.moves
|
||||
|
||||
# Start with the same heuristic as in supervised training: exclude
|
||||
# docs that are within the maximum length.
|
||||
all_states = moves.init_batch(docs)
|
||||
Transition action
|
||||
all_states = self.moves.init_batch(docs)
|
||||
states = []
|
||||
to_cut = []
|
||||
for state, doc in zip(all_states, docs):
|
||||
|
@ -675,30 +667,19 @@ class Parser(TrainablePipe):
|
|||
states.append(state)
|
||||
else:
|
||||
to_cut.append(state)
|
||||
|
||||
if not to_cut:
|
||||
return states
|
||||
|
||||
# Parse the states that are too long with the teacher's parsing model.
|
||||
teacher_inputs = TransitionModelInputs(docs=docs,
|
||||
moves=moves,
|
||||
states=[state.copy() for state in to_cut])
|
||||
(teacher_states, _) = teacher_pipe.model.predict(teacher_inputs)
|
||||
|
||||
# Step through the teacher's actions and store every state after
|
||||
# each multiple of max_length.
|
||||
teacher_actions = _states_to_actions(teacher_states)
|
||||
while to_cut:
|
||||
states.extend(state.copy() for state in to_cut)
|
||||
for step_actions in teacher_actions[:max_length]:
|
||||
to_cut = moves.apply_actions(to_cut, step_actions)
|
||||
teacher_actions = teacher_actions[max_length:]
|
||||
|
||||
if len(teacher_actions) < max_length:
|
||||
break
|
||||
|
||||
# Move states forward max_length actions.
|
||||
length = 0
|
||||
while to_cut and length < max_length:
|
||||
teacher_scores = teacher_step_model.predict(to_cut)
|
||||
self.transition_states(to_cut, teacher_scores)
|
||||
# States that are completed do not need further cutting.
|
||||
to_cut = [state for state in to_cut if not state.is_final()]
|
||||
length += 1
|
||||
return states
|
||||
|
||||
|
||||
def _init_gold_batch(self, examples, max_length):
|
||||
"""Make a square batch, of length equal to the shortest transition
|
||||
sequence or a cap. A long doc will get multiple states. Let's say we
|
||||
|
@ -759,7 +740,7 @@ def _change_attrs(model, **kwargs):
|
|||
model.attrs[key] = value
|
||||
|
||||
|
||||
def _states_to_actions(states: List[StateClass]) -> List[Ints1d]:
|
||||
def states2actions(states: List[StateClass]) -> List[Ints1d]:
|
||||
cdef int step
|
||||
cdef StateClass state
|
||||
cdef StateC* c_state
|
||||
|
@ -780,47 +761,3 @@ def _states_to_actions(states: List[StateClass]) -> List[Ints1d]:
|
|||
actions.append(numpy.array(step_actions, dtype="i"))
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
def _states_diff_to_actions(
|
||||
before_states: List[StateClass],
|
||||
after_states: List[StateClass]
|
||||
) -> List[Ints1d]:
|
||||
"""
|
||||
Return for two sets of states the actions to go from the first set of
|
||||
states to the second set of states. The histories of the first set of
|
||||
states must be a prefix of the second set of states.
|
||||
"""
|
||||
cdef StateClass before_state, after_state
|
||||
cdef StateC* c_state_before
|
||||
cdef StateC* c_state_after
|
||||
|
||||
assert len(before_states) == len(after_states)
|
||||
|
||||
# Check invariant: before states histories must be prefixes of after states.
|
||||
for before_state, after_state in zip(before_states, after_states):
|
||||
c_state_before = before_state.c
|
||||
c_state_after = after_state.c
|
||||
|
||||
assert equal(c_state_before.history.begin(),
|
||||
c_state_before.history.end(),
|
||||
c_state_after.history.begin())
|
||||
|
||||
actions = []
|
||||
while True:
|
||||
step = len(actions)
|
||||
|
||||
step_actions = []
|
||||
for before_state, after_state in zip(before_states, after_states):
|
||||
c_state_before = before_state.c
|
||||
c_state_after = after_state.c
|
||||
if step < c_state_after.history.size() - c_state_before.history.size():
|
||||
step_actions.append(c_state_after.history[c_state_before.history.size() + step])
|
||||
|
||||
# We are done if we have exhausted all histories.
|
||||
if len(step_actions) == 0:
|
||||
break
|
||||
|
||||
actions.append(numpy.array(step_actions, dtype="i"))
|
||||
|
||||
return actions
|
||||
|
|
|
@ -1,61 +0,0 @@
|
|||
import numpy
|
||||
import pytest
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.ml.tb_framework import TransitionModelInputs
|
||||
from spacy.training import Example
|
||||
|
||||
TRAIN_DATA = [
|
||||
(
|
||||
"They trade mortgage-backed securities.",
|
||||
{
|
||||
"heads": [1, 1, 4, 4, 5, 1, 1],
|
||||
"deps": ["nsubj", "ROOT", "compound", "punct", "nmod", "dobj", "punct"],
|
||||
},
|
||||
),
|
||||
(
|
||||
"I like London and Berlin.",
|
||||
{
|
||||
"heads": [1, 1, 1, 2, 2, 1],
|
||||
"deps": ["nsubj", "ROOT", "dobj", "cc", "conj", "punct"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nlp_parser():
|
||||
nlp = English()
|
||||
parser = nlp.add_pipe("parser")
|
||||
|
||||
train_examples = []
|
||||
for text, annotations in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||
for dep in annotations["deps"]:
|
||||
parser.add_label(dep)
|
||||
nlp.initialize()
|
||||
|
||||
return nlp, parser
|
||||
|
||||
|
||||
def test_incorrect_number_of_actions(nlp_parser):
|
||||
nlp, parser = nlp_parser
|
||||
doc = nlp.make_doc("test")
|
||||
|
||||
# Too many actions for the number of docs
|
||||
with pytest.raises(AssertionError):
|
||||
parser.model.predict(
|
||||
TransitionModelInputs(
|
||||
docs=[doc], moves=parser.moves, actions=[numpy.array([0, 0], dtype="i")]
|
||||
)
|
||||
)
|
||||
|
||||
# Too few actions for the number of docs
|
||||
with pytest.raises(AssertionError):
|
||||
parser.model.predict(
|
||||
TransitionModelInputs(
|
||||
docs=[doc, doc],
|
||||
moves=parser.moves,
|
||||
actions=[numpy.array([0], dtype="i")],
|
||||
)
|
||||
)
|
|
@ -623,9 +623,7 @@ def test_is_distillable():
|
|||
assert ner.is_distillable
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
|
||||
def test_distill(max_moves):
|
||||
def test_distill():
|
||||
teacher = English()
|
||||
teacher_ner = teacher.add_pipe("ner")
|
||||
train_examples = []
|
||||
|
@ -643,7 +641,6 @@ def test_distill(max_moves):
|
|||
|
||||
student = English()
|
||||
student_ner = student.add_pipe("ner")
|
||||
student_ner.cfg["update_with_oracle_cut_size"] = max_moves
|
||||
student_ner.initialize(
|
||||
get_examples=lambda: train_examples, labels=teacher_ner.label_data
|
||||
)
|
||||
|
|
|
@ -462,9 +462,7 @@ def test_is_distillable():
|
|||
assert parser.is_distillable
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
|
||||
def test_distill(max_moves):
|
||||
def test_distill():
|
||||
teacher = English()
|
||||
teacher_parser = teacher.add_pipe("parser")
|
||||
train_examples = []
|
||||
|
@ -482,7 +480,6 @@ def test_distill(max_moves):
|
|||
|
||||
student = English()
|
||||
student_parser = student.add_pipe("parser")
|
||||
student_parser.cfg["update_with_oracle_cut_size"] = max_moves
|
||||
student_parser.initialize(
|
||||
get_examples=lambda: train_examples, labels=teacher_parser.label_data
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user