Add TrainablePipe.{distill,get_teacher_student_loss}

This change adds two methods:

- `TrainablePipe.distill`, which performs a training step of a
  student pipe on a teacher pipe, given a batch of `Doc`s.
- `TrainablePipe.get_teacher_student_loss`, which computes the loss
  of a student relative to the teacher.

The `distill` and `get_teacher_student_loss` methods are also implemented
in the tagger, edit tree lemmatizer, and parser pipes, to enable
distillation in those pipes and as an example for other pipes.
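
Taken together, the two methods support a workflow along the following lines. This is a minimal sketch that mirrors the tagger test added in this commit; the training data, the choice of the tagger pipe, and the iteration counts are illustrative only and not part of the change itself.

```python
# Minimal sketch of the distillation workflow, mirroring the tests in this
# commit. Data, pipe choice and loop sizes are illustrative only.
import spacy
from spacy.training import Example

TRAIN_DATA = [
    ("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
    ("Eat blue ham", {"tags": ["V", "J", "N"]}),
]

# Train a small teacher tagger from scratch.
teacher = spacy.blank("en")
teacher_tagger = teacher.add_pipe("tagger")
train_examples = [
    Example.from_dict(teacher.make_doc(text), annots) for text, annots in TRAIN_DATA
]
optimizer = teacher.initialize(get_examples=lambda: train_examples)
for _ in range(50):
    teacher.update(train_examples, sgd=optimizer, losses={})

# Initialize a student with the teacher's labels, then distill: the student
# learns from the teacher's predicted distributions, so only Docs (no gold
# annotations) are passed to `distill`.
student = spacy.blank("en")
student_tagger = student.add_pipe("tagger")
student_tagger.initialize(
    get_examples=lambda: train_examples, labels=teacher_tagger.label_data
)
docs = [eg.predicted for eg in train_examples]
for _ in range(50):
    losses = {}
    student_tagger.distill(teacher_tagger, docs, docs, sgd=optimizer, losses=losses)
```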
Daniël de Kok 2022-12-22 12:12:44 +01:00
parent d30ba9b7b8
commit adead2a104
18 changed files with 575 additions and 11 deletions

View File

@@ -949,6 +949,7 @@ class Errors(metaclass=ErrorsWithCodes):
E4000 = ("Expected a Doc as input, but got: '{type}'")
E4001 = ("Expected input to be one of the following types: ({expected_types}), "
"but got '{received_type}'")
E4002 = ("Pipe '{name}' requires teacher pipe for distillation.")
# fmt: on

View File

@@ -23,6 +23,7 @@ DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS = [
"update",
"rehearse",
"get_loss",
"get_student_teacher_loss",
"initialize", "initialize",
"begin_update", "begin_update",
"finish_update", "finish_update",

View File

@@ -155,6 +155,23 @@ class EditTreeLemmatizer(TrainablePipe):
return float(loss), d_scores
def get_teacher_student_loss(
self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
) -> Tuple[float, List[Floats2d]]:
"""Calculate the loss and its gradient for a batch of student
scores, relative to teacher scores.
teacher_scores: Scores representing the teacher model's predictions.
student_scores: Scores representing the student model's predictions.
DOCS: https://spacy.io/api/edittreelemmatizer#get_teacher_student_loss
"""
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
d_scores, loss = loss_func(student_scores, teacher_scores)
if self.model.ops.xp.isnan(loss):
raise ValueError(Errors.E910.format(name=self.name))
return float(loss), d_scores
def predict(self, docs: Iterable[Doc]) -> ActivationsT:
n_docs = len(list(docs))
if not any(len(doc) for doc in docs):

View File

@@ -1,5 +1,6 @@
# cython: infer_types=True, profile=True, binding=True
from typing import Callable, Dict, Iterable, List, Optional, Union
from typing import Tuple
import numpy
import srsly
from thinc.api import Model, set_dropout_rate, Config
@@ -245,7 +246,6 @@ class Tagger(TrainablePipe):
DOCS: https://spacy.io/api/tagger#rehearse
"""
loss_func = LegacySequenceCategoricalCrossentropy()
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
@@ -259,12 +259,30 @@ class Tagger(TrainablePipe):
set_dropout_rate(self.model, drop)
tag_scores, bp_tag_scores = self.model.begin_update(docs)
tutor_tag_scores, _ = self._rehearsal_model.begin_update(docs)
grads, loss = loss_func(tag_scores, tutor_tag_scores)
loss, grads = self.get_teacher_student_loss(tutor_tag_scores, tag_scores)
bp_tag_scores(grads)
if sgd is not None:
self.finish_update(sgd)
losses[self.name] += loss
return losses
def get_teacher_student_loss(
self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
) -> Tuple[float, List[Floats2d]]:
"""Calculate the loss and its gradient for a batch of student
scores, relative to teacher scores.
teacher_scores: Scores representing the teacher model's predictions.
student_scores: Scores representing the student model's predictions.
DOCS: https://spacy.io/api/tagger#get_teacher_student_loss
"""
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
d_scores, loss = loss_func(student_scores, teacher_scores)
if self.model.ops.xp.isnan(loss):
raise ValueError(Errors.E910.format(name=self.name))
return float(loss), d_scores
def get_loss(self, examples, scores):
"""Find the loss and gradient of loss for the batch of documents and
their predicted scores.

View File

@@ -56,6 +56,56 @@ cdef class TrainablePipe(Pipe):
except Exception as e:
error_handler(self.name, self, [doc], e)
def distill(self,
teacher_pipe: Optional["TrainablePipe"],
teacher_docs: Iterable["Doc"],
student_docs: Iterable["Doc"],
*,
drop: float=0.0,
sgd: Optional[Optimizer]=None,
losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
"""Train a pipe (the student) on the predictions of another pipe
(the teacher). The student is typically trained on the probability
distribution of the teacher, but details may differ per pipe.
teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
from.
teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
student_docs (Iterable[Doc]): Documents passed through student pipes.
Must contain the same tokens as `teacher_docs` but may have
different annotations.
drop (float): dropout rate.
sgd (Optional[Optimizer]): An optimizer. Will be created via
create_optimizer if not set.
losses (Optional[Dict[str, float]]): Optional record of loss during
distillation.
RETURNS: The updated losses dictionary.
"""
# By default we require a teacher pipe, but there are downstream
# implementations that don't require a pipe.
if teacher_pipe is None:
raise ValueError(Errors.E4002.format(name=self.name))
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
if not any(len(doc) for doc in teacher_docs):
return losses
if not any(len(doc) for doc in student_docs):
return losses
set_dropout_rate(self.model, drop)
for node in teacher_pipe.model.walk():
if node.name == "softmax":
node.attrs["softmax_normalize"] = True
teacher_scores = teacher_pipe.model.predict(teacher_docs)
student_scores, bp_student_scores = self.model.begin_update(student_docs)
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
bp_student_scores(d_scores)
if sgd is not None:
self.finish_update(sgd)
losses[self.name] += loss
return losses
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
@@ -169,6 +219,17 @@ cdef class TrainablePipe(Pipe):
"""
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_loss", name=self.name))
def get_teacher_student_loss(self, teacher_scores, student_scores):
"""Calculate the loss and its gradient for a batch of student
scores, relative to teacher scores.
teacher_scores: Scores representing the teacher model's predictions.
student_scores: Scores representing the student model's predictions.
DOCS: https://spacy.io/api/pipe#get_teacher_student_loss
"""
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_teacher_student_loss", name=self.name))
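
For a custom pipe whose model produces one score matrix per `Doc`, the default `distill` above can usually be reused as-is; only this loss hook needs to be provided. The following is a hedged sketch: the class name is hypothetical, the loss and the `thinc.api` import path mirror the tagger, edit tree lemmatizer, and parser changes in this commit (and may differ in other spaCy/thinc versions), and the remaining pipe methods are omitted.

```python
# Hypothetical custom pipe that reuses the inherited TrainablePipe.distill by
# implementing only get_teacher_student_loss. Other required methods
# (__init__, predict, set_annotations, get_loss, ...) are omitted here.
from typing import List, Tuple

from thinc.api import LegacySequenceCategoricalCrossentropy
from thinc.types import Floats2d

from spacy.errors import Errors
from spacy.pipeline import TrainablePipe


class DistillableCustomPipe(TrainablePipe):
    def get_teacher_student_loss(
        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
    ) -> Tuple[float, List[Floats2d]]:
        """Cross-entropy of the student's score distributions against the
        teacher's, returning the loss and the gradient w.r.t. the student."""
        loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
        d_scores, loss = loss_func(student_scores, teacher_scores)
        if self.model.ops.xp.isnan(loss):
            raise ValueError(Errors.E910.format(name=self.name))
        return float(loss), d_scores
```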
def create_optimizer(self) -> Optimizer:
"""Create an optimizer for the pipeline component.

View File

@@ -1,5 +1,6 @@
# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
from __future__ import print_function
from typing import Dict, Iterable, List, Optional, Tuple
from cymem.cymem cimport Pool
cimport numpy as np
from itertools import islice
@@ -9,7 +10,10 @@ from libc.stdlib cimport calloc, free
import random
import srsly
from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps
from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps, Optimizer
from thinc.api import LegacySequenceCategoricalCrossentropy, chain, softmax_activation, use_ops
from thinc.extra.search cimport Beam
from thinc.types import Floats2d
import numpy.random
import numpy
import warnings
@@ -203,6 +207,120 @@ cdef class Parser(TrainablePipe):
# Defined in subclasses, to avoid circular import
raise NotImplementedError
def distill(self,
teacher_pipe: Optional[TrainablePipe],
teacher_docs: Iterable[Doc],
student_docs: Iterable[Doc],
*,
drop: float=0.0,
sgd: Optional[Optimizer]=None,
losses: Optional[Dict[str, float]]=None):
"""Train a pipe (the student) on the predictions of another pipe
(the teacher). The student is trained on the transition probabilities
of the teacher.
teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
from.
teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
student_docs (Iterable[Doc]): Documents passed through student pipes.
Must contain the same tokens as `teacher_docs` but may have
different annotations.
drop (float): dropout rate.
sgd (Optional[Optimizer]): An optimizer. Will be created via
create_optimizer if not set.
losses (Optional[Dict[str, float]]): Optional record of loss during
distillation.
RETURNS: The updated losses dictionary.
"""
if teacher_pipe is None:
raise ValueError(Errors.E4002.format(name=self.name))
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
if not any(len(doc) for doc in teacher_docs):
return losses
if not any(len(doc) for doc in student_docs):
return losses
set_dropout_rate(self.model, drop)
teacher_step_model = teacher_pipe.model.predict(teacher_docs)
student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
# Add softmax activation, so that we can compute student losses
# with cross-entropy loss.
with use_ops("numpy"):
teacher_model = chain(teacher_step_model, softmax_activation())
student_model = chain(student_step_model, softmax_activation())
max_moves = self.cfg["update_with_oracle_cut_size"]
if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the
# batch uniform length. Since we do not have a gold standard
# sequence, we use the teacher's predictions as the gold
# standard.
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
states = self._init_batch(teacher_step_model, student_docs, max_moves)
else:
states = self.moves.init_batch(student_docs)
loss = 0.0
n_moves = 0
while states:
# We do distillation as follows: (1) for every state, we compute the
# transition softmax distributions; (2) we backpropagate the error of
# the student (compared to the teacher) into the student model; (3)
# for all states, we move to the next state using the student's
# predictions.
teacher_scores = teacher_model.predict(states)
student_scores, backprop = student_model.begin_update(states)
state_loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
backprop(d_scores)
loss += state_loss
self.transition_states(states, student_scores)
states = [state for state in states if not state.is_final()]
# Stop when we reach the maximum number of moves, otherwise we start
# to process the remainder of cut sequences again.
if max_moves >= 1 and n_moves >= max_moves:
break
n_moves += 1
backprop_tok2vec(student_docs)
if sgd is not None:
self.finish_update(sgd)
losses[self.name] += loss
del backprop
del backprop_tok2vec
teacher_step_model.clear_memory()
student_step_model.clear_memory()
del teacher_model
del student_model
return losses
def get_teacher_student_loss(
self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
) -> Tuple[float, List[Floats2d]]:
"""Calculate the loss and its gradient for a batch of student
scores, relative to teacher scores.
teacher_scores: Scores representing the teacher model's predictions.
student_scores: Scores representing the student model's predictions.
DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
"""
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
d_scores, loss = loss_func(student_scores, teacher_scores)
if self.model.ops.xp.isnan(loss):
raise ValueError(Errors.E910.format(name=self.name))
return float(loss), d_scores
def init_multitask_objectives(self, get_examples, pipeline, **cfg):
"""Setup models for secondary objectives, to benefit from multi-task
learning. This method is intended to be overridden by subclasses.
@@ -625,6 +743,40 @@ cdef class Parser(TrainablePipe):
raise ValueError(Errors.E149) from None
return self
def _init_batch(self, teacher_step_model, docs, max_length):
"""Make a square batch of length equal to the shortest transition
sequence or a cap. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
where N is the shortest doc. We'll make two states, one representing
long_doc[:N], and another representing long_doc[N:]. In contrast to
_init_gold_batch, this version uses a teacher model to generate the
cut sequences."""
cdef:
StateClass start_state
StateClass state
Transition action
all_states = self.moves.init_batch(docs)
states = []
to_cut = []
for state, doc in zip(all_states, docs):
if not state.is_final():
if len(doc) < max_length:
states.append(state)
else:
to_cut.append(state)
while to_cut:
states.extend(state.copy() for state in to_cut)
# Move states forward max_length actions.
length = 0
while to_cut and length < max_length:
teacher_scores = teacher_step_model.predict(to_cut)
self.transition_states(to_cut, teacher_scores)
# States that are completed do not need further cutting.
to_cut = [state for state in to_cut if not state.is_final()]
length += 1
return states
def _init_gold_batch(self, examples, max_length):
"""Make a square batch, of length equal to the shortest transition
sequence or a cap. A long

View File

@@ -617,6 +617,44 @@ def test_overfitting_IO(use_upper):
assert ents[1].kb_id == 0
def test_distill():
teacher = English()
teacher_ner = teacher.add_pipe("ner")
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(teacher.make_doc(text), annotations))
for ent in annotations.get("entities"):
teacher_ner.add_label(ent[2])
optimizer = teacher.initialize(get_examples=lambda: train_examples)
for i in range(50):
losses = {}
teacher.update(train_examples, sgd=optimizer, losses=losses)
assert losses["ner"] < 0.00001
student = English()
student_ner = student.add_pipe("ner")
student_ner.initialize(
get_examples=lambda: train_examples, labels=teacher_ner.label_data
)
docs = [eg.predicted for eg in train_examples]
for i in range(100):
losses = {}
student_ner.distill(teacher_ner, docs, docs, sgd=optimizer, losses=losses)
assert losses["ner"] < 0.0001
# test the trained model
test_text = "I like London."
doc = student(test_text)
ents = doc.ents
assert len(ents) == 1
assert ents[0].text == "London"
assert ents[0].label_ == "LOC"
def test_beam_ner_scores():
# Test that we can get confidence values out of the beam_ner pipe
beam_width = 16

View File

@@ -396,6 +396,45 @@ def test_overfitting_IO(pipe_name):
assert_equal(batch_deps_1, no_batch_deps)
def test_distill():
teacher = English()
teacher_parser = teacher.add_pipe("parser")
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(teacher.make_doc(text), annotations))
for dep in annotations.get("deps", []):
teacher_parser.add_label(dep)
optimizer = teacher.initialize(get_examples=lambda: train_examples)
for i in range(200):
losses = {}
teacher.update(train_examples, sgd=optimizer, losses=losses)
assert losses["parser"] < 0.0001
student = English()
student_parser = student.add_pipe("parser")
student_parser.initialize(
get_examples=lambda: train_examples, labels=teacher_parser.label_data
)
docs = [eg.predicted for eg in train_examples]
for i in range(200):
losses = {}
student_parser.distill(teacher_parser, docs, docs, sgd=optimizer, losses=losses)
assert losses["parser"] < 0.0001
test_text = "I like securities."
doc = student(test_text)
assert doc[0].dep_ == "nsubj"
assert doc[2].dep_ == "dobj"
assert doc[3].dep_ == "punct"
assert doc[0].head.i == 1
assert doc[2].head.i == 1
assert doc[3].head.i == 1
# fmt: off
@pytest.mark.slow
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])

View File

@@ -195,6 +195,45 @@ def test_overfitting_IO():
assert doc4[3].lemma_ == "egg"
def test_distill():
teacher = English()
teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer")
teacher_lemmatizer.min_tree_freq = 1
train_examples = []
for t in TRAIN_DATA:
train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))
optimizer = teacher.initialize(get_examples=lambda: train_examples)
for i in range(50):
losses = {}
teacher.update(train_examples, sgd=optimizer, losses=losses)
assert losses["trainable_lemmatizer"] < 0.00001
student = English()
student_lemmatizer = student.add_pipe("trainable_lemmatizer")
student_lemmatizer.min_tree_freq = 1
student_lemmatizer.initialize(
get_examples=lambda: train_examples, labels=teacher_lemmatizer.label_data
)
docs = [eg.predicted for eg in train_examples]
for i in range(50):
losses = {}
student_lemmatizer.distill(
teacher_lemmatizer, docs, docs, sgd=optimizer, losses=losses
)
assert losses["trainable_lemmatizer"] < 0.00001
test_text = "She likes blue eggs"
doc = student(test_text)
assert doc[0].lemma_ == "she"
assert doc[1].lemma_ == "like"
assert doc[2].lemma_ == "blue"
assert doc[3].lemma_ == "egg"
def test_lemmatizer_requires_labels():
nlp = English()
nlp.add_pipe("trainable_lemmatizer")

View File

@@ -213,6 +213,42 @@ def test_overfitting_IO():
assert doc3[0].tag_ != "N"
def test_distill():
teacher = English()
teacher_tagger = teacher.add_pipe("tagger")
train_examples = []
for t in TRAIN_DATA:
train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))
optimizer = teacher.initialize(get_examples=lambda: train_examples)
for i in range(50):
losses = {}
teacher.update(train_examples, sgd=optimizer, losses=losses)
assert losses["tagger"] < 0.00001
student = English()
student_tagger = student.add_pipe("tagger")
student_tagger.initialize(
get_examples=lambda: train_examples, labels=teacher_tagger.label_data
)
docs = [eg.predicted for eg in train_examples]
for i in range(50):
losses = {}
student_tagger.distill(teacher_tagger, docs, docs, sgd=optimizer, losses=losses)
assert losses["tagger"] < 0.00001
test_text = "I like blue eggs"
doc = student(test_text)
assert doc[0].tag_ == "N"
assert doc[1].tag_ == "V"
assert doc[2].tag_ == "J"
assert doc[3].tag_ == "N"
def test_save_activations():
# Test if activations are correctly added to Doc when requested.
nlp = English()

View File

@@ -268,6 +268,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. ~~StateClass~~ |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## DependencyParser.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.
> #### Example
>
> ```python
> teacher_parser = teacher.get_pipe("parser")
> student_parser = student.add_pipe("parser")
> student_scores = student_parser.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_parser.predict([eg.predicted for eg in examples])
> loss, d_loss = student_parser.get_teacher_student_loss(teacher_scores, student_scores)
> ```
| Name | Description |
| ---------------- | --------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions. |
| `student_scores` | Scores representing the student model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## DependencyParser.create_optimizer {#create_optimizer tag="method"}
Create an [`Optimizer`](https://thinc.ai/docs/api-optimizers) for the pipeline

View File

@@ -269,6 +269,27 @@ Create an optimizer for the pipeline component.
| ----------- | ---------------------------- |
| **RETURNS** | The optimizer. ~~Optimizer~~ |
## EditTreeLemmatizer.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.
> #### Example
>
> ```python
> teacher_lemmatizer = teacher.get_pipe("trainable_lemmatizer")
> student_lemmatizer = student.add_pipe("trainable_lemmatizer")
> student_scores = student_lemmatizer.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_lemmatizer.predict([eg.predicted for eg in examples])
> loss, d_loss = student_lemmatizer.get_teacher_student_loss(teacher_scores, student_scores)
> ```
| Name | Description |
| ---------------- | --------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions. |
| `student_scores` | Scores representing the student model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## EditTreeLemmatizer.use_params {#use_params tag="method, contextmanager"}
Modify the pipe's model, to use the given parameter values. At the end of the

View File

@@ -259,6 +259,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## Morphologizer.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.
> #### Example
>
> ```python
> teacher_morphologizer = teacher.get_pipe("morphologizer")
> student_morphologizer = student.add_pipe("morphologizer")
> student_scores = student_morphologizer.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_morphologizer.predict([eg.predicted for eg in examples])
> loss, d_loss = student_morphologizer.get_teacher_student_loss(teacher_scores, student_scores)
> ```
| Name | Description |
| ---------------- | --------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions. |
| `student_scores` | Scores representing the student model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## Morphologizer.create_optimizer {#create_optimizer tag="method"}
Create an optimizer for the pipeline component.

View File

@@ -234,6 +234,33 @@ predictions and gold-standard annotations, and update the component's model.
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## TrainablePipe.distill {#distill tag="method,experimental" new="4"}
Train a pipe (the student) on the predictions of another pipe (the teacher). The
student is typically trained on the probability distribution of the teacher, but
details may differ per pipe. The goal of distillation is to transfer knowledge
from the teacher to the student. This feature is experimental.
> #### Example
>
> ```python
> teacher_pipe = teacher.add_pipe("your_custom_pipe")
> student_pipe = student.add_pipe("your_custom_pipe")
> optimizer = nlp.resume_training()
> losses = student_pipe.distill(teacher_pipe, teacher_docs, student_docs, sgd=optimizer)
> ```
| Name | Description |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
| `teacher_docs` | Documents passed through teacher pipes. ~~Iterable[Doc]~~ |
| `student_docs` | Documents passed through student pipes. Must contain the same tokens as `teacher_docs` but may have different annotations. ~~Iterable[Doc]~~ |
| _keyword-only_ | |
| `drop` | Dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## TrainablePipe.rehearse {#rehearse tag="method,experimental" new="3"}
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
@@ -281,6 +308,33 @@ This method needs to be overwritten with your own custom `get_loss` method.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## TrainablePipe.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.
<Infobox variant="danger">
This method needs to be overwritten with your own custom `get_teacher_student_loss` method.
</Infobox>
> #### Example
>
> ```python
> teacher_pipe = teacher.get_pipe("your_custom_pipe")
> student_pipe = student.add_pipe("your_custom_pipe")
> student_scores = student_pipe.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_pipe.predict([eg.predicted for eg in examples])
> loss, d_loss = student_pipe.get_teacher_student_loss(teacher_scores, student_scores)
> ```
| Name | Description |
| ---------------- | --------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions. |
| `student_scores` | Scores representing the student model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## TrainablePipe.score {#score tag="method" new="3"}
Score a batch of examples.

View File

@@ -254,6 +254,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## SentenceRecognizer.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.
> #### Example
>
> ```python
> teacher_senter = teacher.get_pipe("senter")
> student_senter = student.add_pipe("senter")
> student_scores = student_senter.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_senter.predict([eg.predicted for eg in examples])
> loss, d_loss = student_senter.get_teacher_student_loss(teacher_scores, student_scores)
> ```
| Name | Description |
| ---------------- | --------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions. |
| `student_scores` | Scores representing the student model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## SentenceRecognizer.create_optimizer {#create_optimizer tag="method"}
Create an optimizer for the pipeline component.

View File

@@ -265,6 +265,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## Tagger.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.
> #### Example
>
> ```python
> teacher_tagger = teacher.get_pipe("tagger")
> student_tagger = student.add_pipe("tagger")
> student_scores = student_tagger.predict([eg.predicted for eg in examples])
> teacher_scores = teacher_tagger.predict([eg.predicted for eg in examples])
> loss, d_loss = student_tagger.get_teacher_student_loss(teacher_scores, student_scores)
> ```
| Name | Description |
| ---------------- | --------------------------------------------------------------------------- |
| `teacher_scores` | Scores representing the teacher model's predictions. |
| `student_scores` | Scores representing the student model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## Tagger.create_optimizer {#create_optimizer tag="method"}
Create an optimizer for the pipeline component.

View File

@@ -899,7 +899,8 @@ backprop passes.
Recursively wrap both the models and methods of each pipe using
[NVTX](https://nvidia.github.io/NVTX/) range markers. By default, the following
methods are wrapped: `pipe`, `predict`, `set_annotations`, `update`, `rehearse`,
`get_loss`, `initialize`, `begin_update`, `finish_update`, `update`.
`get_loss`, `get_teacher_student_loss`, `initialize`, `begin_update`,
`finish_update`, `update`.
| Name | Description |
| --------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |

View File

@@ -1370,10 +1370,12 @@ customize how the model is updated from examples, how it's initialized, how the
loss is calculated and to add evaluation scores to the training output.
| Name | Description |
| ----------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`update`](/api/pipe#update) | Learn from a batch of [`Example`](/api/example) objects containing the predictions and gold-standard annotations, and update the component's model. |
| [`distill`](/api/pipe#distill) | Learn from a teacher pipeline using a batch of [`Doc`](/api/doc) objects and update the component's model. |
| [`initialize`](/api/pipe#initialize) | Initialize the model. Typically calls into [`Model.initialize`](https://thinc.ai/docs/api-model#initialize) and can be passed custom arguments via the [`[initialize]`](/api/data-formats#config-initialize) config block that are only loaded during training or when you call [`nlp.initialize`](/api/language#initialize), not at runtime. |
| [`get_loss`](/api/pipe#get_loss) | Return a tuple of the loss and the gradient for a batch of [`Example`](/api/example) objects. |
| [`get_teacher_student_loss`](/api/pipe#get_teacher_student_loss) | Return a tuple of the loss and the gradient for the student scores relative to the teacher scores. |
| [`score`](/api/pipe#score) | Score a batch of [`Example`](/api/example) objects and return a dictionary of scores. The [`@Language.factory`](/api/language#factory) decorator can define the `default_score_weights` of the component to decide which keys of the scores to display during training and how they count towards the final score. |

<Infobox title="Custom trainable components and models" emoji="📖">