Add TrainablePipe.{distill,get_teacher_student_loss} (#12016)
* Add `TrainablePipe.{distill,get_teacher_student_loss}`

  This change adds two methods:

  - `TrainablePipe::distill`, which performs a training step of a student pipe
    on a teacher pipe, given a batch of `Doc`s.
  - `TrainablePipe::get_teacher_student_loss`, which computes the loss of a
    student relative to the teacher.

  The `distill` or `get_teacher_student_loss` methods are also implemented in
  the tagger, edit tree lemmatizer, and parser pipes, to enable distillation in
  those pipes and as an example for other pipes.

* Fix stray `Beam` import
* Fix incorrect import
* Apply suggestions from code review

  Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Apply suggestions from code review

  Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* TrainablePipe.distill: use `Iterable[Example]`
* Add Pipe.is_distillable method
* Add `validate_distillation_examples`

  This first calls `validate_examples` and then checks that the student/teacher
  tokens are the same.

* Update distill documentation
* Add distill documentation for all pipes that support distillation
* Fix incorrect identifier
* Apply suggestions from code review

  Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Add comment to explain `is_distillable`

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
Parent: c2f3e699ca
Commit: 5e297aa20e
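As an illustration of how the new API fits together, the sketch below is modeled on the tagger distillation test added in this commit. It is not part of the commit itself; the toy data, iteration counts, and the tiny teacher trained in place of a real pretrained pipeline are assumptions for illustration only.

```python
from spacy.lang.en import English
from spacy.training import Example

# Toy data (illustrative only).
TRAIN_DATA = [
    ("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
    ("Eat blue ham", {"tags": ["V", "J", "N"]}),
]

# 1. Train a small teacher tagger (stand-in for a large trained pipeline).
teacher = English()
teacher_tagger = teacher.add_pipe("tagger")
train_examples = [
    Example.from_dict(teacher.make_doc(text), annots) for text, annots in TRAIN_DATA
]
optimizer = teacher.initialize(get_examples=lambda: train_examples)
for _ in range(50):
    teacher.update(train_examples, sgd=optimizer, losses={})

# 2. Distill a freshly initialized student from the teacher. Only raw,
#    unannotated docs are needed: the teacher's predictions are the signal.
student = English()
student_tagger = student.add_pipe("tagger")
student_tagger.initialize(
    get_examples=lambda: train_examples, labels=teacher_tagger.label_data
)
distill_examples = [
    Example.from_dict(teacher.make_doc(text), {}) for text, _ in TRAIN_DATA
]
for _ in range(50):
    losses = {}
    student_tagger.distill(
        teacher_tagger, distill_examples, sgd=optimizer, losses=losses
    )
```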
@@ -955,6 +955,9 @@ class Errors(metaclass=ErrorsWithCodes):
    E4000 = ("Expected a Doc as input, but got: '{type}'")
    E4001 = ("Expected input to be one of the following types: ({expected_types}), "
             "but got '{received_type}'")
    E4002 = ("Pipe '{name}' requires a teacher pipe for distillation.")
    E4003 = ("Training examples for distillation must have the exact same tokens in the "
             "reference and predicted docs.")

    # fmt: on
@@ -23,6 +23,7 @@ DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS = [
    "update",
    "rehearse",
    "get_loss",
    "get_teacher_student_loss",
    "initialize",
    "begin_update",
    "finish_update",
@@ -155,6 +155,25 @@ class EditTreeLemmatizer(TrainablePipe):

        return float(loss), d_scores

    def get_teacher_student_loss(
        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
    ) -> Tuple[float, List[Floats2d]]:
        """Calculate the loss and its gradient for a batch of student
        scores, relative to teacher scores.

        teacher_scores: Scores representing the teacher model's predictions.
        student_scores: Scores representing the student model's predictions.

        RETURNS (Tuple[float, float]): The loss and the gradient.

        DOCS: https://spacy.io/api/edittreelemmatizer#get_teacher_student_loss
        """
        loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
        d_scores, loss = loss_func(student_scores, teacher_scores)
        if self.model.ops.xp.isnan(loss):
            raise ValueError(Errors.E910.format(name=self.name))
        return float(loss), d_scores

    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
        n_docs = len(list(docs))
        if not any(len(doc) for doc in docs):
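In the `get_teacher_student_loss` implementations added here, the teacher's softmax-normalized scores act as soft targets for the student via thinc's `LegacySequenceCategoricalCrossentropy`, as used in this diff. As a conceptual aid only (this is not the thinc implementation), the underlying math looks roughly like this in plain numpy:

```python
import numpy as np


def teacher_student_loss(teacher_probs, student_probs, eps=1e-8):
    """Conceptual distillation loss: cross-entropy of the student's
    distribution against the teacher's distribution, summed over tokens.

    Both inputs have shape (n_tokens, n_classes) and rows sum to 1.
    Returns the loss and the gradient with respect to the student's
    pre-softmax scores, which is simply student_probs - teacher_probs.
    """
    loss = -np.sum(teacher_probs * np.log(student_probs + eps))
    d_scores = student_probs - teacher_probs
    return float(loss), d_scores


# Toy example: two tokens, three classes.
teacher = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
student = np.array([[0.5, 0.3, 0.2], [0.2, 0.6, 0.2]])
loss, d_scores = teacher_student_loss(teacher, student)
```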
@@ -87,6 +87,10 @@ cdef class Pipe:
            return self.scorer(examples, **scorer_kwargs)
        return {}

    @property
    def is_distillable(self) -> bool:
        return False

    @property
    def is_trainable(self) -> bool:
        return False
@@ -1,5 +1,6 @@
# cython: infer_types=True, profile=True, binding=True
from typing import Callable, Dict, Iterable, List, Optional, Union
from typing import Tuple
import numpy
import srsly
from thinc.api import Model, set_dropout_rate, Config

@@ -245,7 +246,6 @@ class Tagger(TrainablePipe):

        DOCS: https://spacy.io/api/tagger#rehearse
        """
-        loss_func = LegacySequenceCategoricalCrossentropy()
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)

@@ -259,12 +259,32 @@ class Tagger(TrainablePipe):
        set_dropout_rate(self.model, drop)
        tag_scores, bp_tag_scores = self.model.begin_update(docs)
        tutor_tag_scores, _ = self._rehearsal_model.begin_update(docs)
-        grads, loss = loss_func(tag_scores, tutor_tag_scores)
+        loss, grads = self.get_teacher_student_loss(tutor_tag_scores, tag_scores)
        bp_tag_scores(grads)
-        self.finish_update(sgd)
+        if sgd is not None:
+            self.finish_update(sgd)
        losses[self.name] += loss
        return losses

    def get_teacher_student_loss(
        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
    ) -> Tuple[float, List[Floats2d]]:
        """Calculate the loss and its gradient for a batch of student
        scores, relative to teacher scores.

        teacher_scores: Scores representing the teacher model's predictions.
        student_scores: Scores representing the student model's predictions.

        RETURNS (Tuple[float, float]): The loss and the gradient.

        DOCS: https://spacy.io/api/tagger#get_teacher_student_loss
        """
        loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
        d_scores, loss = loss_func(student_scores, teacher_scores)
        if self.model.ops.xp.isnan(loss):
            raise ValueError(Errors.E910.format(name=self.name))
        return float(loss), d_scores

    def get_loss(self, examples, scores):
        """Find the loss and gradient of loss for the batch of documents and
        their predicted scores.
@@ -6,7 +6,7 @@ import warnings

from ..tokens.doc cimport Doc

-from ..training import validate_examples
+from ..training import validate_examples, validate_distillation_examples
from ..errors import Errors, Warnings
from .pipe import Pipe, deserialize_config
from .. import util

@@ -56,6 +56,53 @@ cdef class TrainablePipe(Pipe):
            except Exception as e:
                error_handler(self.name, self, [doc], e)

    def distill(self,
                teacher_pipe: Optional["TrainablePipe"],
                examples: Iterable["Example"],
                *,
                drop: float=0.0,
                sgd: Optional[Optimizer]=None,
                losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
        """Train a pipe (the student) on the predictions of another pipe
        (the teacher). The student is typically trained on the probability
        distribution of the teacher, but details may differ per pipe.

        teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
            from.
        examples (Iterable[Example]): Distillation examples. The reference
            and predicted docs must have the same number of tokens and the
            same orthography.
        drop (float): dropout rate.
        sgd (Optional[Optimizer]): An optimizer. Will be created via
            create_optimizer if not set.
        losses (Optional[Dict[str, float]]): Optional record of loss during
            distillation.
        RETURNS: The updated losses dictionary.

        DOCS: https://spacy.io/api/pipe#distill
        """
        # By default we require a teacher pipe, but there are downstream
        # implementations that don't require a pipe.
        if teacher_pipe is None:
            raise ValueError(Errors.E4002.format(name=self.name))
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)
        validate_distillation_examples(examples, "TrainablePipe.distill")
        set_dropout_rate(self.model, drop)
        for node in teacher_pipe.model.walk():
            if node.name == "softmax":
                node.attrs["softmax_normalize"] = True
        teacher_scores = teacher_pipe.model.predict([eg.reference for eg in examples])
        student_scores, bp_student_scores = self.model.begin_update([eg.predicted for eg in examples])
        loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
        bp_student_scores(d_scores)
        if sgd is not None:
            self.finish_update(sgd)
        losses[self.name] += loss
        return losses

    def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
        """Apply the pipe to a stream of documents. This usually happens under
        the hood when the nlp object is called on a text and all components are

@@ -169,6 +216,19 @@ cdef class TrainablePipe(Pipe):
        """
        raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_loss", name=self.name))

    def get_teacher_student_loss(self, teacher_scores, student_scores):
        """Calculate the loss and its gradient for a batch of student
        scores, relative to teacher scores.

        teacher_scores: Scores representing the teacher model's predictions.
        student_scores: Scores representing the student model's predictions.

        RETURNS (Tuple[float, float]): The loss and the gradient.

        DOCS: https://spacy.io/api/pipe#get_teacher_student_loss
        """
        raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_teacher_student_loss", name=self.name))

    def create_optimizer(self) -> Optimizer:
        """Create an optimizer for the pipeline component.

@@ -205,6 +265,14 @@ cdef class TrainablePipe(Pipe):
        """
        raise NotImplementedError(Errors.E931.format(parent="Pipe", method="add_label", name=self.name))

    @property
    def is_distillable(self) -> bool:
        # Normally a pipe overrides `get_teacher_student_loss` to implement
        # distillation. In more exceptional cases, a pipe can provide its
        # own `distill` implementation. If neither of these methods is
        # overridden, the pipe does not implement distillation.
        return not (self.__class__.distill is TrainablePipe.distill and self.__class__.get_teacher_student_loss is TrainablePipe.get_teacher_student_loss)

    @property
    def is_trainable(self) -> bool:
        return True
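As the `is_distillable` property above spells out, a component opts into distillation by overriding `get_teacher_student_loss` (or, in exceptional cases, `distill` itself); the inherited `TrainablePipe.distill` then drives the update. Below is a hypothetical minimal override, modeled on the tagger and lemmatizer implementations in this commit; the component class, its factory registration, and its model wiring are assumptions and are not shown.

```python
from spacy.pipeline import TrainablePipe
from thinc.legacy import LegacySequenceCategoricalCrossentropy


class MyDistillablePipe(TrainablePipe):
    """Hypothetical custom component; factory, model and vocab wiring omitted."""

    def get_teacher_student_loss(self, teacher_scores, student_scores):
        # Treat the teacher's normalized scores as soft targets for the
        # student, as the tagger and edit tree lemmatizer do in this commit.
        loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
        d_scores, loss = loss_func(student_scores, teacher_scores)
        return float(loss), d_scores


# Because get_teacher_student_loss is overridden, is_distillable is True and
# the default TrainablePipe.distill implementation can be used as-is.
```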
@@ -1,5 +1,6 @@
# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
from __future__ import print_function
from typing import Dict, Iterable, List, Optional, Tuple
from cymem.cymem cimport Pool
cimport numpy as np
from itertools import islice

@@ -9,7 +10,10 @@ from libc.stdlib cimport calloc, free
import random

import srsly
-from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps
+from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps, Optimizer
+from thinc.api import chain, softmax_activation, use_ops
+from thinc.legacy import LegacySequenceCategoricalCrossentropy
+from thinc.types import Floats2d
import numpy.random
import numpy
import warnings

@@ -26,6 +30,7 @@ from ._parser_internals cimport _beam_utils
from ._parser_internals import _beam_utils

from ..training import validate_examples, validate_get_examples
from ..training import validate_distillation_examples
from ..errors import Errors, Warnings
from .. import util

@@ -203,6 +208,121 @@ cdef class Parser(TrainablePipe):
        # Defined in subclasses, to avoid circular import
        raise NotImplementedError

    def distill(self,
                teacher_pipe: Optional[TrainablePipe],
                examples: Iterable["Example"],
                *,
                drop: float=0.0,
                sgd: Optional[Optimizer]=None,
                losses: Optional[Dict[str, float]]=None):
        """Train a pipe (the student) on the predictions of another pipe
        (the teacher). The student is trained on the transition probabilities
        of the teacher.

        teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
            from.
        examples (Iterable[Example]): Distillation examples. The reference
            and predicted docs must have the same number of tokens and the
            same orthography.
        drop (float): dropout rate.
        sgd (Optional[Optimizer]): An optimizer. Will be created via
            create_optimizer if not set.
        losses (Optional[Dict[str, float]]): Optional record of loss during
            distillation.
        RETURNS: The updated losses dictionary.

        DOCS: https://spacy.io/api/dependencyparser#distill
        """
        if teacher_pipe is None:
            raise ValueError(Errors.E4002.format(name=self.name))
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)

        validate_distillation_examples(examples, "TransitionParser.distill")

        set_dropout_rate(self.model, drop)

        student_docs = [eg.predicted for eg in examples]

        teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples])
        student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)

        # Add softmax activation, so that we can compute student losses
        # with cross-entropy loss.
        with use_ops("numpy"):
            teacher_model = chain(teacher_step_model, softmax_activation())
            student_model = chain(student_step_model, softmax_activation())

        max_moves = self.cfg["update_with_oracle_cut_size"]
        if max_moves >= 1:
            # Chop sequences into lengths of this many words, to make the
            # batch uniform length. Since we do not have a gold standard
            # sequence, we use the teacher's predictions as the gold
            # standard.
            max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
            states = self._init_batch(teacher_step_model, student_docs, max_moves)
        else:
            states = self.moves.init_batch(student_docs)

        loss = 0.0
        n_moves = 0
        while states:
            # We do distillation as follows: (1) for every state, we compute the
            # transition softmax distributions: (2) we backpropagate the error of
            # the student (compared to the teacher) into the student model; (3)
            # for all states, we move to the next state using the student's
            # predictions.
            teacher_scores = teacher_model.predict(states)
            student_scores, backprop = student_model.begin_update(states)
            state_loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
            backprop(d_scores)
            loss += state_loss
            self.transition_states(states, student_scores)
            states = [state for state in states if not state.is_final()]

            # Stop when we reach the maximum number of moves, otherwise we start
            # to process the remainder of cut sequences again.
            if max_moves >= 1 and n_moves >= max_moves:
                break
            n_moves += 1

        backprop_tok2vec(student_docs)

        if sgd is not None:
            self.finish_update(sgd)

        losses[self.name] += loss

        del backprop
        del backprop_tok2vec
        teacher_step_model.clear_memory()
        student_step_model.clear_memory()
        del teacher_model
        del student_model

        return losses

    def get_teacher_student_loss(
        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
    ) -> Tuple[float, List[Floats2d]]:
        """Calculate the loss and its gradient for a batch of student
        scores, relative to teacher scores.

        teacher_scores: Scores representing the teacher model's predictions.
        student_scores: Scores representing the student model's predictions.

        RETURNS (Tuple[float, float]): The loss and the gradient.

        DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
        """
        loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
        d_scores, loss = loss_func(student_scores, teacher_scores)
        if self.model.ops.xp.isnan(loss):
            raise ValueError(Errors.E910.format(name=self.name))
        return float(loss), d_scores

    def init_multitask_objectives(self, get_examples, pipeline, **cfg):
        """Setup models for secondary objectives, to benefit from multi-task
        learning. This method is intended to be overridden by subclasses.

@@ -625,6 +745,40 @@ cdef class Parser(TrainablePipe):
            raise ValueError(Errors.E149) from None
        return self

    def _init_batch(self, teacher_step_model, docs, max_length):
        """Make a square batch of length equal to the shortest transition
        sequence or a cap. A long
        doc will get multiple states. Let's say we have a doc of length 2*N,
        where N is the shortest doc. We'll make two states, one representing
        long_doc[:N], and another representing long_doc[N:]. In contrast to
        _init_gold_batch, this version uses a teacher model to generate the
        cut sequences."""
        cdef:
            StateClass start_state
            StateClass state
            Transition action
        all_states = self.moves.init_batch(docs)
        states = []
        to_cut = []
        for state, doc in zip(all_states, docs):
            if not state.is_final():
                if len(doc) < max_length:
                    states.append(state)
                else:
                    to_cut.append(state)
        while to_cut:
            states.extend(state.copy() for state in to_cut)
            # Move states forward max_length actions.
            length = 0
            while to_cut and length < max_length:
                teacher_scores = teacher_step_model.predict(to_cut)
                self.transition_states(to_cut, teacher_scores)
                # States that are completed do not need further cutting.
                to_cut = [state for state in to_cut if not state.is_final()]
                length += 1
        return states

    def _init_gold_batch(self, examples, max_length):
        """Make a square batch, of length equal to the shortest transition
        sequence or a cap. A long
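The `_init_batch` helper above is easiest to read as a "square batching" step: long transition sequences are cut into segments of at most `max_length` moves, with the teacher model choosing the transitions inside each cut. The snippet below is a rough, framework-free sketch of that cutting idea over plain lists; it is an illustration only, since the real code operates on parser states rather than lists.

```python
def square_batch(sequences, max_length):
    """Cut every long sequence into consecutive segments of at most
    max_length items, so all batch elements have a bounded length.
    The real _init_batch works on parser states and lets the teacher
    model pick the transition at each step; here we just slice lists.
    """
    batch = []
    for seq in sequences:
        if len(seq) < max_length:
            batch.append(seq)
        else:
            for start in range(0, len(seq), max_length):
                batch.append(seq[start:start + max_length])
    return batch


# A 7-step sequence with max_length=3 becomes segments of 3, 3 and 1 steps.
print(square_batch([[1, 2, 3, 4, 5, 6, 7], [1, 2]], max_length=3))
# [[1, 2, 3], [4, 5, 6], [7], [1, 2]]
```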
@@ -617,6 +617,52 @@ def test_overfitting_IO(use_upper):
    assert ents[1].kb_id == 0


def test_is_distillable():
    nlp = English()
    ner = nlp.add_pipe("ner")
    assert ner.is_distillable


def test_distill():
    teacher = English()
    teacher_ner = teacher.add_pipe("ner")
    train_examples = []
    for text, annotations in TRAIN_DATA:
        train_examples.append(Example.from_dict(teacher.make_doc(text), annotations))
        for ent in annotations.get("entities"):
            teacher_ner.add_label(ent[2])

    optimizer = teacher.initialize(get_examples=lambda: train_examples)

    for i in range(50):
        losses = {}
        teacher.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["ner"] < 0.00001

    student = English()
    student_ner = student.add_pipe("ner")
    student_ner.initialize(
        get_examples=lambda: train_examples, labels=teacher_ner.label_data
    )

    distill_examples = [
        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
    ]

    for i in range(100):
        losses = {}
        student_ner.distill(teacher_ner, distill_examples, sgd=optimizer, losses=losses)
    assert losses["ner"] < 0.0001

    # test the trained model
    test_text = "I like London."
    doc = student(test_text)
    ents = doc.ents
    assert len(ents) == 1
    assert ents[0].text == "London"
    assert ents[0].label_ == "LOC"


def test_beam_ner_scores():
    # Test that we can get confidence values out of the beam_ner pipe
    beam_width = 16
@@ -396,6 +396,55 @@ def test_overfitting_IO(pipe_name):
    assert_equal(batch_deps_1, no_batch_deps)


def test_is_distillable():
    nlp = English()
    parser = nlp.add_pipe("parser")
    assert parser.is_distillable


def test_distill():
    teacher = English()
    teacher_parser = teacher.add_pipe("parser")
    train_examples = []
    for text, annotations in TRAIN_DATA:
        train_examples.append(Example.from_dict(teacher.make_doc(text), annotations))
        for dep in annotations.get("deps", []):
            teacher_parser.add_label(dep)

    optimizer = teacher.initialize(get_examples=lambda: train_examples)

    for i in range(200):
        losses = {}
        teacher.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["parser"] < 0.0001

    student = English()
    student_parser = student.add_pipe("parser")
    student_parser.initialize(
        get_examples=lambda: train_examples, labels=teacher_parser.label_data
    )

    distill_examples = [
        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
    ]

    for i in range(200):
        losses = {}
        student_parser.distill(
            teacher_parser, distill_examples, sgd=optimizer, losses=losses
        )
    assert losses["parser"] < 0.0001

    test_text = "I like securities."
    doc = student(test_text)
    assert doc[0].dep_ == "nsubj"
    assert doc[2].dep_ == "dobj"
    assert doc[3].dep_ == "punct"
    assert doc[0].head.i == 1
    assert doc[2].head.i == 1
    assert doc[3].head.i == 1


# fmt: off
@pytest.mark.slow
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
@@ -195,6 +195,53 @@ def test_overfitting_IO():
    assert doc4[3].lemma_ == "egg"


def test_is_distillable():
    nlp = English()
    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
    assert lemmatizer.is_distillable


def test_distill():
    teacher = English()
    teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer")
    teacher_lemmatizer.min_tree_freq = 1
    train_examples = []
    for t in TRAIN_DATA:
        train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))

    optimizer = teacher.initialize(get_examples=lambda: train_examples)

    for i in range(50):
        losses = {}
        teacher.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["trainable_lemmatizer"] < 0.00001

    student = English()
    student_lemmatizer = student.add_pipe("trainable_lemmatizer")
    student_lemmatizer.min_tree_freq = 1
    student_lemmatizer.initialize(
        get_examples=lambda: train_examples, labels=teacher_lemmatizer.label_data
    )

    distill_examples = [
        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
    ]

    for i in range(50):
        losses = {}
        student_lemmatizer.distill(
            teacher_lemmatizer, distill_examples, sgd=optimizer, losses=losses
        )
    assert losses["trainable_lemmatizer"] < 0.00001

    test_text = "She likes blue eggs"
    doc = student(test_text)
    assert doc[0].lemma_ == "she"
    assert doc[1].lemma_ == "like"
    assert doc[2].lemma_ == "blue"
    assert doc[3].lemma_ == "egg"


def test_lemmatizer_requires_labels():
    nlp = English()
    nlp.add_pipe("trainable_lemmatizer")
@@ -50,6 +50,12 @@ def test_implicit_label():
    nlp.initialize(get_examples=lambda: train_examples)


def test_is_distillable():
    nlp = English()
    morphologizer = nlp.add_pipe("morphologizer")
    assert morphologizer.is_distillable


def test_no_resize():
    nlp = Language()
    morphologizer = nlp.add_pipe("morphologizer")
@@ -11,6 +11,12 @@ from spacy.pipeline import TrainablePipe
from spacy.tests.util import make_tempdir


def test_is_distillable():
    nlp = English()
    senter = nlp.add_pipe("senter")
    assert senter.is_distillable


def test_label_types():
    nlp = Language()
    senter = nlp.add_pipe("senter")
@@ -213,6 +213,52 @@ def test_overfitting_IO():
    assert doc3[0].tag_ != "N"


def test_is_distillable():
    nlp = English()
    tagger = nlp.add_pipe("tagger")
    assert tagger.is_distillable


def test_distill():
    teacher = English()
    teacher_tagger = teacher.add_pipe("tagger")
    train_examples = []
    for t in TRAIN_DATA:
        train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))

    optimizer = teacher.initialize(get_examples=lambda: train_examples)

    for i in range(50):
        losses = {}
        teacher.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["tagger"] < 0.00001

    student = English()
    student_tagger = student.add_pipe("tagger")
    student_tagger.min_tree_freq = 1
    student_tagger.initialize(
        get_examples=lambda: train_examples, labels=teacher_tagger.label_data
    )

    distill_examples = [
        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
    ]

    for i in range(50):
        losses = {}
        student_tagger.distill(
            teacher_tagger, distill_examples, sgd=optimizer, losses=losses
        )
    assert losses["tagger"] < 0.00001

    test_text = "I like blue eggs"
    doc = student(test_text)
    assert doc[0].tag_ == "N"
    assert doc[1].tag_ == "V"
    assert doc[2].tag_ == "J"
    assert doc[3].tag_ == "N"


def test_save_activations():
    # Test if activations are correctly added to Doc when requested.
    nlp = English()
@@ -565,6 +565,12 @@ def test_initialize_examples(name, get_examples, train_data):
    nlp.initialize(get_examples=get_examples())


def test_is_distillable():
    nlp = English()
    textcat = nlp.add_pipe("textcat")
    assert not textcat.is_distillable


def test_overfitting_IO():
    # Simple test to try and quickly overfit the single-label textcat component - ensuring the ML models work correctly
    fix_random_seed(0)
@@ -8,7 +8,7 @@ from spacy.lang.en import English
from spacy.tokens import Doc, DocBin
from spacy.training import Alignment, Corpus, Example, biluo_tags_to_offsets
from spacy.training import biluo_tags_to_spans, docs_to_json, iob_to_biluo
-from spacy.training import offsets_to_biluo_tags
+from spacy.training import offsets_to_biluo_tags, validate_distillation_examples
from spacy.training.alignment_array import AlignmentArray
from spacy.training.align import get_alignments
from spacy.training.converters import json_to_docs

@@ -365,6 +365,19 @@ def test_example_from_dict_some_ner(en_vocab):
    assert ner_tags == ["U-LOC", None, None, None]


def test_validate_distillation_examples(en_vocab):
    words = ["a", "b", "c", "d"]
    spaces = [True, True, False, True]
    predicted = Doc(en_vocab, words=words, spaces=spaces)

    example = Example.from_dict(predicted, {})
    validate_distillation_examples([example], "test_validate_distillation_examples")

    example = Example.from_dict(predicted, {"words": words + ["e"]})
    with pytest.raises(ValueError, match=r"distillation"):
        validate_distillation_examples([example], "test_validate_distillation_examples")


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_json_to_docs_no_ner(en_vocab):
    data = [
@@ -1,5 +1,6 @@
from .corpus import Corpus, JsonlCorpus  # noqa: F401
from .example import Example, validate_examples, validate_get_examples  # noqa: F401
from .example import validate_distillation_examples  # noqa: F401
from .alignment import Alignment  # noqa: F401
from .augment import dont_augment, orth_variants_augmenter  # noqa: F401
from .iob_utils import iob_to_biluo, biluo_to_iob  # noqa: F401
@@ -47,6 +47,13 @@ def validate_examples(examples, method):
        raise TypeError(err)


def validate_distillation_examples(examples, method):
    validate_examples(examples, method)
    for eg in examples:
        if [token.text for token in eg.reference] != [token.text for token in eg.predicted]:
            raise ValueError(Errors.E4003)


def validate_get_examples(get_examples, method):
    """Check that a generator of a batch of examples received during processing is valid:
    the callable produces a non-empty list of Example objects.
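The new `validate_distillation_examples` helper is intended to be called at the start of every `distill` implementation, as the pipes above do. A small usage sketch follows; it is an illustration only, and the method-name string and texts are assumptions, not taken from the commit.

```python
from spacy.lang.en import English
from spacy.training import Example, validate_distillation_examples

nlp = English()

# Distillation examples are typically built from raw text; with an empty
# annotation dict, the reference and predicted docs share the same tokens.
texts = ["I like London.", "Berlin is great."]
examples = [Example.from_dict(nlp.make_doc(text), {}) for text in texts]
validate_distillation_examples(examples, "my_component.distill")  # passes

# A reference whose tokens differ from the predicted doc fails validation
# with error E4003 (mirroring the test added in this commit).
words = ["a", "b", "c", "d"]
bad = Example.from_dict(nlp.make_doc(" ".join(words)), {"words": words + ["e"]})
# validate_distillation_examples([bad], "my_component.distill")  # raises ValueError
```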
@@ -131,6 +131,39 @@ and all pipeline components are applied to the `Doc` in order. Both
| `doc`       | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~  |

## DependencyParser.distill {id="distill", tag="method,experimental", version="4"}

Train a pipe (the student) on the predictions of another pipe (the teacher). The
student is typically trained on the probability distribution of the teacher, but
details may differ per pipe. The goal of distillation is to transfer knowledge
from the teacher to the student.

The distillation is performed on ~~Example~~ objects. The `Example.reference`
and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
same orthography. Even though the reference does not need to have gold
annotations, the teacher could add its own annotations when necessary.

This feature is experimental.

> #### Example
>
> ```python
> teacher_pipe = teacher.add_pipe("parser")
> student_pipe = student.add_pipe("parser")
> optimizer = nlp.resume_training()
> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
> ```

| Name           | Description                                                                                                                                  |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~                                                                                  |
| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
| _keyword-only_ |                                                                                                                                              |
| `drop`         | Dropout rate. ~~float~~                                                                                                                      |
| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~                                |
| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~                |
| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                        |

## DependencyParser.pipe {id="pipe",tag="method"}

Apply the pipe to a stream of documents. This usually happens under the hood

@@ -268,6 +301,27 @@ predicted scores.
| `scores`    | Scores representing the model's predictions. ~~StateClass~~                 |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## DependencyParser.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}

Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.

> #### Example
>
> ```python
> teacher_parser = teacher.get_pipe("parser")
> student_parser = student.add_pipe("parser")
> student_scores = student_parser.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_parser.predict([eg.predicted for eg in examples])
> loss, d_loss = student_parser.get_teacher_student_loss(teacher_scores, student_scores)
> ```

| Name             | Description                                                                  |
| ---------------- | ---------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions.                        |
| `student_scores` | Scores representing the student model's predictions.                        |
| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## DependencyParser.create_optimizer {id="create_optimizer",tag="method"}

Create an [`Optimizer`](https://thinc.ai/docs/api-optimizers) for the pipeline
@@ -115,6 +115,39 @@ and all pipeline components are applied to the `Doc` in order. Both
| `doc`       | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~  |

## EditTreeLemmatizer.distill {id="distill", tag="method,experimental", version="4"}

Train a pipe (the student) on the predictions of another pipe (the teacher). The
student is typically trained on the probability distribution of the teacher, but
details may differ per pipe. The goal of distillation is to transfer knowledge
from the teacher to the student.

The distillation is performed on ~~Example~~ objects. The `Example.reference`
and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
same orthography. Even though the reference does not need to have gold
annotations, the teacher could add its own annotations when necessary.

This feature is experimental.

> #### Example
>
> ```python
> teacher_pipe = teacher.add_pipe("trainable_lemmatizer")
> student_pipe = student.add_pipe("trainable_lemmatizer")
> optimizer = nlp.resume_training()
> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
> ```

| Name           | Description                                                                                                                                  |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~                                                                                  |
| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
| _keyword-only_ |                                                                                                                                              |
| `drop`         | Dropout rate. ~~float~~                                                                                                                      |
| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~                                |
| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~                |
| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                        |

## EditTreeLemmatizer.pipe {id="pipe",tag="method"}

Apply the pipe to a stream of documents. This usually happens under the hood

@@ -269,6 +302,27 @@ Create an optimizer for the pipeline component.
| ----------- | ---------------------------- |
| **RETURNS** | The optimizer. ~~Optimizer~~ |

## EditTreeLemmatizer.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}

Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.

> #### Example
>
> ```python
> teacher_lemmatizer = teacher.get_pipe("trainable_lemmatizer")
> student_lemmatizer = student.add_pipe("trainable_lemmatizer")
> student_scores = student_lemmatizer.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_lemmatizer.predict([eg.predicted for eg in examples])
> loss, d_loss = student_lemmatizer.get_teacher_student_loss(teacher_scores, student_scores)
> ```

| Name             | Description                                                                  |
| ---------------- | ---------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions.                        |
| `student_scores` | Scores representing the student model's predictions.                        |
| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## EditTreeLemmatizer.use_params {id="use_params",tag="method, contextmanager"}

Modify the pipe's model, to use the given parameter values. At the end of the
@@ -127,6 +127,39 @@ and all pipeline components are applied to the `Doc` in order. Both
| `doc`       | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~  |

## EntityRecognizer.distill {id="distill", tag="method,experimental", version="4"}

Train a pipe (the student) on the predictions of another pipe (the teacher). The
student is typically trained on the probability distribution of the teacher, but
details may differ per pipe. The goal of distillation is to transfer knowledge
from the teacher to the student.

The distillation is performed on ~~Example~~ objects. The `Example.reference`
and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
same orthography. Even though the reference does not need to have gold
annotations, the teacher could add its own annotations when necessary.

This feature is experimental.

> #### Example
>
> ```python
> teacher_pipe = teacher.add_pipe("ner")
> student_pipe = student.add_pipe("ner")
> optimizer = nlp.resume_training()
> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
> ```

| Name           | Description                                                                                                                                  |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~                                                                                  |
| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
| _keyword-only_ |                                                                                                                                              |
| `drop`         | Dropout rate. ~~float~~                                                                                                                      |
| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~                                |
| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~                |
| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                        |

## EntityRecognizer.pipe {id="pipe",tag="method"}

Apply the pipe to a stream of documents. This usually happens under the hood

@@ -264,6 +297,27 @@ predicted scores.
| `scores`    | Scores representing the model's predictions. ~~StateClass~~                 |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## EntityRecognizer.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}

Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.

> #### Example
>
> ```python
> teacher_ner = teacher.get_pipe("ner")
> student_ner = student.add_pipe("ner")
> student_scores = student_ner.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_ner.predict([eg.predicted for eg in examples])
> loss, d_loss = student_ner.get_teacher_student_loss(teacher_scores, student_scores)
> ```

| Name             | Description                                                                  |
| ---------------- | ---------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions.                        |
| `student_scores` | Scores representing the student model's predictions.                        |
| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## EntityRecognizer.create_optimizer {id="create_optimizer",tag="method"}

Create an optimizer for the pipeline component.
@@ -121,6 +121,39 @@ delegate to the [`predict`](/api/morphologizer#predict) and
| `doc`       | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~  |

## Morphologizer.distill {id="distill", tag="method,experimental", version="4"}

Train a pipe (the student) on the predictions of another pipe (the teacher). The
student is typically trained on the probability distribution of the teacher, but
details may differ per pipe. The goal of distillation is to transfer knowledge
from the teacher to the student.

The distillation is performed on ~~Example~~ objects. The `Example.reference`
and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
same orthography. Even though the reference does not need to have gold
annotations, the teacher could add its own annotations when necessary.

This feature is experimental.

> #### Example
>
> ```python
> teacher_pipe = teacher.add_pipe("morphologizer")
> student_pipe = student.add_pipe("morphologizer")
> optimizer = nlp.resume_training()
> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
> ```

| Name           | Description                                                                                                                                  |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~                                                                                  |
| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
| _keyword-only_ |                                                                                                                                              |
| `drop`         | Dropout rate. ~~float~~                                                                                                                      |
| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~                                |
| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~                |
| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                        |

## Morphologizer.pipe {id="pipe",tag="method"}

Apply the pipe to a stream of documents. This usually happens under the hood

@@ -259,6 +292,27 @@ predicted scores.
| `scores`    | Scores representing the model's predictions.                                 |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## Morphologizer.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}

Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.

> #### Example
>
> ```python
> teacher_morphologizer = teacher.get_pipe("morphologizer")
> student_morphologizer = student.add_pipe("morphologizer")
> student_scores = student_morphologizer.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_morphologizer.predict([eg.predicted for eg in examples])
> loss, d_loss = student_morphologizer.get_teacher_student_loss(teacher_scores, student_scores)
> ```

| Name             | Description                                                                  |
| ---------------- | ---------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions.                        |
| `student_scores` | Scores representing the student model's predictions.                        |
| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## Morphologizer.create_optimizer {id="create_optimizer",tag="method"}

Create an optimizer for the pipeline component.
@@ -234,6 +234,39 @@ predictions and gold-standard annotations, and update the component's model.
| `losses`    | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                     |

## TrainablePipe.distill {id="distill", tag="method,experimental", version="4"}

Train a pipe (the student) on the predictions of another pipe (the teacher). The
student is typically trained on the probability distribution of the teacher, but
details may differ per pipe. The goal of distillation is to transfer knowledge
from the teacher to the student.

The distillation is performed on ~~Example~~ objects. The `Example.reference`
and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
same orthography. Even though the reference does not need to have gold
annotations, the teacher could add its own annotations when necessary.

This feature is experimental.

> #### Example
>
> ```python
> teacher_pipe = teacher.add_pipe("your_custom_pipe")
> student_pipe = student.add_pipe("your_custom_pipe")
> optimizer = nlp.resume_training()
> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
> ```

| Name           | Description                                                                                                                                  |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~                                                                                  |
| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
| _keyword-only_ |                                                                                                                                              |
| `drop`         | Dropout rate. ~~float~~                                                                                                                      |
| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~                                |
| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~                |
| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                        |

## TrainablePipe.rehearse {id="rehearse",tag="method,experimental",version="3"}

Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the

@@ -281,6 +314,34 @@ This method needs to be overwritten with your own custom `get_loss` method.
| `scores`    | Scores representing the model's predictions.                                 |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## TrainablePipe.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}

Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.

<Infobox variant="danger">

This method needs to be overwritten with your own custom
`get_teacher_student_loss` method.

</Infobox>

> #### Example
>
> ```python
> teacher_pipe = teacher.get_pipe("your_custom_pipe")
> student_pipe = student.add_pipe("your_custom_pipe")
> student_scores = student_pipe.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_pipe.predict([eg.predicted for eg in examples])
> loss, d_loss = student_pipe.get_teacher_student_loss(teacher_scores, student_scores)
> ```

| Name             | Description                                                                  |
| ---------------- | ---------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions.                        |
| `student_scores` | Scores representing the student model's predictions.                        |
| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## TrainablePipe.score {id="score",tag="method",version="3"}

Score a batch of examples.
@@ -106,6 +106,39 @@ and all pipeline components are applied to the `Doc` in order. Both
| `doc`       | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~  |

## SentenceRecognizer.distill {id="distill", tag="method,experimental", version="4"}

Train a pipe (the student) on the predictions of another pipe (the teacher). The
student is typically trained on the probability distribution of the teacher, but
details may differ per pipe. The goal of distillation is to transfer knowledge
from the teacher to the student.

The distillation is performed on ~~Example~~ objects. The `Example.reference`
and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
same orthography. Even though the reference does not need to have gold
annotations, the teacher could add its own annotations when necessary.

This feature is experimental.

> #### Example
>
> ```python
> teacher_pipe = teacher.add_pipe("senter")
> student_pipe = student.add_pipe("senter")
> optimizer = nlp.resume_training()
> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
> ```

| Name           | Description                                                                                                                                  |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~                                                                                  |
| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
| _keyword-only_ |                                                                                                                                              |
| `drop`         | Dropout rate. ~~float~~                                                                                                                      |
| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~                                |
| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~                |
| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                        |

## SentenceRecognizer.pipe {id="pipe",tag="method"}

Apply the pipe to a stream of documents. This usually happens under the hood

@@ -254,6 +287,27 @@ predicted scores.
| `scores`    | Scores representing the model's predictions.                                 |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## SentenceRecognizer.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}

Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.

> #### Example
>
> ```python
> teacher_senter = teacher.get_pipe("senter")
> student_senter = student.add_pipe("senter")
> student_scores = student_senter.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_senter.predict([eg.predicted for eg in examples])
> loss, d_loss = student_senter.get_teacher_student_loss(teacher_scores, student_scores)
> ```

| Name             | Description                                                                  |
| ---------------- | ---------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions.                        |
| `student_scores` | Scores representing the student model's predictions.                        |
| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## SentenceRecognizer.create_optimizer {id="create_optimizer",tag="method"}

Create an optimizer for the pipeline component.
@@ -105,6 +105,39 @@ and all pipeline components are applied to the `Doc` in order. Both
| `doc`       | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~  |

## Tagger.distill {id="distill", tag="method,experimental", version="4"}

Train a pipe (the student) on the predictions of another pipe (the teacher). The
student is typically trained on the probability distribution of the teacher, but
details may differ per pipe. The goal of distillation is to transfer knowledge
from the teacher to the student.

The distillation is performed on ~~Example~~ objects. The `Example.reference`
and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
same orthography. Even though the reference does not need to have gold
annotations, the teacher could add its own annotations when necessary.

This feature is experimental.

> #### Example
>
> ```python
> teacher_pipe = teacher.add_pipe("tagger")
> student_pipe = student.add_pipe("tagger")
> optimizer = nlp.resume_training()
> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
> ```

| Name           | Description                                                                                                                                  |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~                                                                                  |
| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
| _keyword-only_ |                                                                                                                                              |
| `drop`         | Dropout rate. ~~float~~                                                                                                                      |
| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~                                |
| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~                |
| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                        |

## Tagger.pipe {id="pipe",tag="method"}

Apply the pipe to a stream of documents. This usually happens under the hood

@@ -265,6 +298,27 @@ predicted scores.
| `scores`    | Scores representing the model's predictions.                                 |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## Tagger.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}

Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.

> #### Example
>
> ```python
> teacher_tagger = teacher.get_pipe("tagger")
> student_tagger = student.add_pipe("tagger")
> student_scores = student_tagger.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_tagger.predict([eg.predicted for eg in examples])
> loss, d_loss = student_tagger.get_teacher_student_loss(teacher_scores, student_scores)
> ```

| Name             | Description                                                                  |
| ---------------- | ---------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions.                        |
| `student_scores` | Scores representing the student model's predictions.                        |
| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

## Tagger.create_optimizer {id="create_optimizer",tag="method"}

Create an optimizer for the pipeline component.
@@ -921,7 +921,8 @@ backprop passes.
Recursively wrap both the models and methods of each pipe using
[NVTX](https://nvidia.github.io/NVTX/) range markers. By default, the following
methods are wrapped: `pipe`, `predict`, `set_annotations`, `update`, `rehearse`,
-`get_loss`, `initialize`, `begin_update`, `finish_update`, `update`.
+`get_loss`, `get_teacher_student_loss`, `initialize`, `begin_update`,
+`finish_update`, `update`.

| Name                        | Description |
| --------------------------- | ----------- |
@@ -1354,12 +1354,14 @@ For some use cases, it makes sense to also overwrite additional methods to
customize how the model is updated from examples, how it's initialized, how the
loss is calculated and to add evaluation scores to the training output.

-| Name                                 | Description |
-| ------------------------------------ | ----------- |
-| [`update`](/api/pipe#update)         | Learn from a batch of [`Example`](/api/example) objects containing the predictions and gold-standard annotations, and update the component's model. |
-| [`initialize`](/api/pipe#initialize) | Initialize the model. Typically calls into [`Model.initialize`](https://thinc.ai/docs/api-model#initialize) and can be passed custom arguments via the [`[initialize]`](/api/data-formats#config-initialize) config block that are only loaded during training or when you call [`nlp.initialize`](/api/language#initialize), not at runtime. |
-| [`get_loss`](/api/pipe#get_loss)     | Return a tuple of the loss and the gradient for a batch of [`Example`](/api/example) objects. |
-| [`score`](/api/pipe#score)           | Score a batch of [`Example`](/api/example) objects and return a dictionary of scores. The [`@Language.factory`](/api/language#factory) decorator can define the `default_score_weights` of the component to decide which keys of the scores to display during training and how they count towards the final score. |
+| Name                                                              | Description |
+| ----------------------------------------------------------------- | ----------- |
+| [`update`](/api/pipe#update)                                       | Learn from a batch of [`Example`](/api/example) objects containing the predictions and gold-standard annotations, and update the component's model. |
+| [`distill`](/api/pipe#distill)                                     | Learn from a teacher pipeline using a batch of [`Doc`](/api/doc) objects and update the component's model. |
+| [`initialize`](/api/pipe#initialize)                               | Initialize the model. Typically calls into [`Model.initialize`](https://thinc.ai/docs/api-model#initialize) and can be passed custom arguments via the [`[initialize]`](/api/data-formats#config-initialize) config block that are only loaded during training or when you call [`nlp.initialize`](/api/language#initialize), not at runtime. |
+| [`get_loss`](/api/pipe#get_loss)                                   | Return a tuple of the loss and the gradient for a batch of [`Example`](/api/example) objects. |
+| [`get_teacher_student_loss`](/api/pipe#get_teacher_student_loss)   | Return a tuple of the loss and the gradient for the student scores relative to the teacher scores. |
+| [`score`](/api/pipe#score)                                         | Score a batch of [`Example`](/api/example) objects and return a dictionary of scores. The [`@Language.factory`](/api/language#factory) decorator can define the `default_score_weights` of the component to decide which keys of the scores to display during training and how they count towards the final score. |

<Infobox title="Custom trainable components and models" emoji="📖">