diff --git a/spacy/errors.py b/spacy/errors.py
index e800be1fa..e38931985 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -949,6 +949,7 @@ class Errors(metaclass=ErrorsWithCodes):
E4000 = ("Expected a Doc as input, but got: '{type}'")
E4001 = ("Expected input to be one of the following types: ({expected_types}), "
"but got '{received_type}'")
+    E4002 = ("Pipe '{name}' requires a teacher pipe for distillation.")
# fmt: on
diff --git a/spacy/ml/callbacks.py b/spacy/ml/callbacks.py
index 3b60ec2ab..9b24c71bd 100644
--- a/spacy/ml/callbacks.py
+++ b/spacy/ml/callbacks.py
@@ -23,6 +23,7 @@ DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS = [
"update",
"rehearse",
"get_loss",
+    "get_teacher_student_loss",
"initialize",
"begin_update",
"finish_update",
diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py
index 2a2242aa4..c2027f054 100644
--- a/spacy/pipeline/edit_tree_lemmatizer.py
+++ b/spacy/pipeline/edit_tree_lemmatizer.py
@@ -155,6 +155,23 @@ class EditTreeLemmatizer(TrainablePipe):
return float(loss), d_scores
+ def get_teacher_student_loss(
+ self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+ ) -> Tuple[float, List[Floats2d]]:
+ """Calculate the loss and its gradient for a batch of student
+ scores, relative to teacher scores.
+
+ teacher_scores: Scores representing the teacher model's predictions.
+ student_scores: Scores representing the student model's predictions.
+
+ DOCS: https://spacy.io/api/edittreelemmatizer#get_teacher_student_loss
+ """
+ loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
+ d_scores, loss = loss_func(student_scores, teacher_scores)
+ if self.model.ops.xp.isnan(loss):
+ raise ValueError(Errors.E910.format(name=self.name))
+ return float(loss), d_scores
+
def predict(self, docs: Iterable[Doc]) -> ActivationsT:
n_docs = len(list(docs))
if not any(len(doc) for doc in docs):
diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx
index e12f116af..41e6634f9 100644
--- a/spacy/pipeline/tagger.pyx
+++ b/spacy/pipeline/tagger.pyx
@@ -1,5 +1,6 @@
# cython: infer_types=True, profile=True, binding=True
from typing import Callable, Dict, Iterable, List, Optional, Union
+from typing import Tuple
import numpy
import srsly
from thinc.api import Model, set_dropout_rate, Config
@@ -245,7 +246,6 @@ class Tagger(TrainablePipe):
DOCS: https://spacy.io/api/tagger#rehearse
"""
- loss_func = LegacySequenceCategoricalCrossentropy()
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
@@ -259,12 +259,30 @@ class Tagger(TrainablePipe):
set_dropout_rate(self.model, drop)
tag_scores, bp_tag_scores = self.model.begin_update(docs)
tutor_tag_scores, _ = self._rehearsal_model.begin_update(docs)
- grads, loss = loss_func(tag_scores, tutor_tag_scores)
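+        # The rehearsal (initial) model plays the role of the teacher here,
+        # so rehearsal can reuse the teacher-student distillation loss.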
+ loss, grads = self.get_teacher_student_loss(tutor_tag_scores, tag_scores)
bp_tag_scores(grads)
- self.finish_update(sgd)
+ if sgd is not None:
+ self.finish_update(sgd)
losses[self.name] += loss
return losses
+ def get_teacher_student_loss(
+ self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+ ) -> Tuple[float, List[Floats2d]]:
+ """Calculate the loss and its gradient for a batch of student
+ scores, relative to teacher scores.
+
+ teacher_scores: Scores representing the teacher model's predictions.
+ student_scores: Scores representing the student model's predictions.
+
+ DOCS: https://spacy.io/api/tagger#get_teacher_student_loss
+ """
+ loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
+ d_scores, loss = loss_func(student_scores, teacher_scores)
+ if self.model.ops.xp.isnan(loss):
+ raise ValueError(Errors.E910.format(name=self.name))
+ return float(loss), d_scores
+
def get_loss(self, examples, scores):
"""Find the loss and gradient of loss for the batch of documents and
their predicted scores.
diff --git a/spacy/pipeline/trainable_pipe.pyx b/spacy/pipeline/trainable_pipe.pyx
index 5bba34e4a..875a55448 100644
--- a/spacy/pipeline/trainable_pipe.pyx
+++ b/spacy/pipeline/trainable_pipe.pyx
@@ -56,6 +56,56 @@ cdef class TrainablePipe(Pipe):
except Exception as e:
error_handler(self.name, self, [doc], e)
+
+ def distill(self,
+ teacher_pipe: Optional["TrainablePipe"],
+ teacher_docs: Iterable["Doc"],
+ student_docs: Iterable["Doc"],
+ *,
+ drop: float=0.0,
+ sgd: Optional[Optimizer]=None,
+ losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
+ """Train a pipe (the student) on the predictions of another pipe
+ (the teacher). The student is typically trained on the probability
+ distribution of the teacher, but details may differ per pipe.
+
+ teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
+ from.
+ teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
+ student_docs (Iterable[Doc]): Documents passed through student pipes.
+ Must contain the same tokens as `teacher_docs` but may have
+ different annotations.
+ drop (float): dropout rate.
+ sgd (Optional[Optimizer]): An optimizer. Will be created via
+ create_optimizer if not set.
+ losses (Optional[Dict[str, float]]): Optional record of loss during
+ distillation.
+ RETURNS: The updated losses dictionary.
+ """
+        # By default we require a teacher pipe, but there are downstream
+        # implementations that don't require a teacher pipe.
+ if teacher_pipe is None:
+ raise ValueError(Errors.E4002.format(name=self.name))
+ if losses is None:
+ losses = {}
+ losses.setdefault(self.name, 0.0)
+ if not any(len(doc) for doc in teacher_docs):
+ return losses
+ if not any(len(doc) for doc in student_docs):
+ return losses
+ set_dropout_rate(self.model, drop)
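+        # Ensure that the teacher's softmax layer returns normalized
+        # probabilities during prediction, so that the student is distilled
+        # against a proper probability distribution rather than raw scores.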
+ for node in teacher_pipe.model.walk():
+ if node.name == "softmax":
+ node.attrs["softmax_normalize"] = True
+ teacher_scores = teacher_pipe.model.predict(teacher_docs)
+ student_scores, bp_student_scores = self.model.begin_update(student_docs)
+ loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
+ bp_student_scores(d_scores)
+ if sgd is not None:
+ self.finish_update(sgd)
+ losses[self.name] += loss
+ return losses
+
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
@@ -169,6 +219,17 @@ cdef class TrainablePipe(Pipe):
"""
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_loss", name=self.name))
+ def get_teacher_student_loss(self, teacher_scores, student_scores):
+ """Calculate the loss and its gradient for a batch of student
+ scores, relative to teacher scores.
+
+ teacher_scores: Scores representing the teacher model's predictions.
+ student_scores: Scores representing the student model's predictions.
+
+ DOCS: https://spacy.io/api/pipe#get_teacher_student_loss
+ """
+ raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_teacher_student_loss", name=self.name))
+
def create_optimizer(self) -> Optimizer:
"""Create an optimizer for the pipeline component.
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 9d7b258c6..92dddc181 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -1,5 +1,6 @@
# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
from __future__ import print_function
+from typing import Dict, Iterable, List, Optional, Tuple
from cymem.cymem cimport Pool
cimport numpy as np
from itertools import islice
@@ -9,7 +10,10 @@ from libc.stdlib cimport calloc, free
import random
import srsly
-from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps
+from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps, Optimizer
+from thinc.api import LegacySequenceCategoricalCrossentropy, chain, softmax_activation, use_ops
+from thinc.extra.search cimport Beam
+from thinc.types import Floats2d
import numpy.random
import numpy
import warnings
@@ -203,6 +207,120 @@ cdef class Parser(TrainablePipe):
# Defined in subclasses, to avoid circular import
raise NotImplementedError
+ def distill(self,
+ teacher_pipe: Optional[TrainablePipe],
+ teacher_docs: Iterable[Doc],
+ student_docs: Iterable[Doc],
+ *,
+ drop: float=0.0,
+ sgd: Optional[Optimizer]=None,
+ losses: Optional[Dict[str, float]]=None):
+ """Train a pipe (the student) on the predictions of another pipe
+ (the teacher). The student is trained on the transition probabilities
+ of the teacher.
+
+ teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
+ from.
+ teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
+ student_docs (Iterable[Doc]): Documents passed through student pipes.
+ Must contain the same tokens as `teacher_docs` but may have
+ different annotations.
+ drop (float): dropout rate.
+ sgd (Optional[Optimizer]): An optimizer. Will be created via
+ create_optimizer if not set.
+ losses (Optional[Dict[str, float]]): Optional record of loss during
+ distillation.
+ RETURNS: The updated losses dictionary.
+ """
+ if teacher_pipe is None:
+ raise ValueError(Errors.E4002.format(name=self.name))
+ if losses is None:
+ losses = {}
+ losses.setdefault(self.name, 0.0)
+
+ if not any(len(doc) for doc in teacher_docs):
+ return losses
+ if not any(len(doc) for doc in student_docs):
+ return losses
+
+ set_dropout_rate(self.model, drop)
+
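+        # Applying the parser model to docs returns a "step model" that maps
+        # batches of parser states to transition scores.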
+ teacher_step_model = teacher_pipe.model.predict(teacher_docs)
+ student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
+
+ # Add softmax activation, so that we can compute student losses
+ # with cross-entropy loss.
+ with use_ops("numpy"):
+ teacher_model = chain(teacher_step_model, softmax_activation())
+ student_model = chain(student_step_model, softmax_activation())
+
+ max_moves = self.cfg["update_with_oracle_cut_size"]
+ if max_moves >= 1:
+ # Chop sequences into lengths of this many words, to make the
+ # batch uniform length. Since we do not have a gold standard
+ # sequence, we use the teacher's predictions as the gold
+ # standard.
+ max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
+ states = self._init_batch(teacher_step_model, student_docs, max_moves)
+ else:
+ states = self.moves.init_batch(student_docs)
+
+ loss = 0.0
+ n_moves = 0
+ while states:
+            # We do distillation as follows: (1) for every state, we compute the
+            # transition softmax distributions; (2) we backpropagate the error of
+ # the student (compared to the teacher) into the student model; (3)
+ # for all states, we move to the next state using the student's
+ # predictions.
+ teacher_scores = teacher_model.predict(states)
+ student_scores, backprop = student_model.begin_update(states)
+ state_loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
+ backprop(d_scores)
+ loss += state_loss
+ self.transition_states(states, student_scores)
+ states = [state for state in states if not state.is_final()]
+
+ # Stop when we reach the maximum number of moves, otherwise we start
+ # to process the remainder of cut sequences again.
+ if max_moves >= 1 and n_moves >= max_moves:
+ break
+ n_moves += 1
+
+ backprop_tok2vec(student_docs)
+
+ if sgd is not None:
+ self.finish_update(sgd)
+
+ losses[self.name] += loss
+
+ del backprop
+ del backprop_tok2vec
+ teacher_step_model.clear_memory()
+ student_step_model.clear_memory()
+ del teacher_model
+ del student_model
+
+ return losses
+
+
+ def get_teacher_student_loss(
+ self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+ ) -> Tuple[float, List[Floats2d]]:
+ """Calculate the loss and its gradient for a batch of student
+ scores, relative to teacher scores.
+
+ teacher_scores: Scores representing the teacher model's predictions.
+ student_scores: Scores representing the student model's predictions.
+
+ DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
+ """
+ loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
+ d_scores, loss = loss_func(student_scores, teacher_scores)
+ if self.model.ops.xp.isnan(loss):
+ raise ValueError(Errors.E910.format(name=self.name))
+ return float(loss), d_scores
+
def init_multitask_objectives(self, get_examples, pipeline, **cfg):
"""Setup models for secondary objectives, to benefit from multi-task
learning. This method is intended to be overridden by subclasses.
@@ -625,6 +743,40 @@ cdef class Parser(TrainablePipe):
raise ValueError(Errors.E149) from None
return self
+ def _init_batch(self, teacher_step_model, docs, max_length):
+ """Make a square batch of length equal to the shortest transition
+ sequence or a cap. A long
+ doc will get multiple states. Let's say we have a doc of length 2*N,
+ where N is the shortest doc. We'll make two states, one representing
+ long_doc[:N], and another representing long_doc[N:]. In contrast to
+ _init_gold_batch, this version uses a teacher model to generate the
+ cut sequences."""
+ cdef:
+ StateClass start_state
+ StateClass state
+ Transition action
+ all_states = self.moves.init_batch(docs)
+ states = []
+ to_cut = []
+ for state, doc in zip(all_states, docs):
+ if not state.is_final():
+ if len(doc) < max_length:
+ states.append(state)
+ else:
+ to_cut.append(state)
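+        # Each pass of the loop below adds one segment per unfinished long
+        # doc: snapshot the current states for training, then advance the
+        # originals by max_length transitions using the teacher's predictions.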
+ while to_cut:
+ states.extend(state.copy() for state in to_cut)
+ # Move states forward max_length actions.
+ length = 0
+ while to_cut and length < max_length:
+ teacher_scores = teacher_step_model.predict(to_cut)
+ self.transition_states(to_cut, teacher_scores)
+ # States that are completed do not need further cutting.
+ to_cut = [state for state in to_cut if not state.is_final()]
+ length += 1
+ return states
+
+
def _init_gold_batch(self, examples, max_length):
"""Make a square batch, of length equal to the shortest transition
sequence or a cap. A long
diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py
index 00889efdc..082b424b8 100644
--- a/spacy/tests/parser/test_ner.py
+++ b/spacy/tests/parser/test_ner.py
@@ -617,6 +617,44 @@ def test_overfitting_IO(use_upper):
assert ents[1].kb_id == 0
+def test_distill():
+ teacher = English()
+ teacher_ner = teacher.add_pipe("ner")
+ train_examples = []
+ for text, annotations in TRAIN_DATA:
+ train_examples.append(Example.from_dict(teacher.make_doc(text), annotations))
+ for ent in annotations.get("entities"):
+ teacher_ner.add_label(ent[2])
+
+ optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+ for i in range(50):
+ losses = {}
+ teacher.update(train_examples, sgd=optimizer, losses=losses)
+ assert losses["ner"] < 0.00001
+
+ student = English()
+ student_ner = student.add_pipe("ner")
+ student_ner.initialize(
+ get_examples=lambda: train_examples, labels=teacher_ner.label_data
+ )
+
+ docs = [eg.predicted for eg in train_examples]
+
+ for i in range(100):
+ losses = {}
+ student_ner.distill(teacher_ner, docs, docs, sgd=optimizer, losses=losses)
+ assert losses["ner"] < 0.0001
+
+ # test the trained model
+ test_text = "I like London."
+ doc = student(test_text)
+ ents = doc.ents
+ assert len(ents) == 1
+ assert ents[0].text == "London"
+ assert ents[0].label_ == "LOC"
+
+
def test_beam_ner_scores():
# Test that we can get confidence values out of the beam_ner pipe
beam_width = 16
diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py
index aaf31ed56..79b0d6c5e 100644
--- a/spacy/tests/parser/test_parse.py
+++ b/spacy/tests/parser/test_parse.py
@@ -396,6 +396,45 @@ def test_overfitting_IO(pipe_name):
assert_equal(batch_deps_1, no_batch_deps)
+def test_distill():
+ teacher = English()
+ teacher_parser = teacher.add_pipe("parser")
+ train_examples = []
+ for text, annotations in TRAIN_DATA:
+ train_examples.append(Example.from_dict(teacher.make_doc(text), annotations))
+ for dep in annotations.get("deps", []):
+ teacher_parser.add_label(dep)
+
+ optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+ for i in range(200):
+ losses = {}
+ teacher.update(train_examples, sgd=optimizer, losses=losses)
+ assert losses["parser"] < 0.0001
+
+ student = English()
+ student_parser = student.add_pipe("parser")
+ student_parser.initialize(
+ get_examples=lambda: train_examples, labels=teacher_parser.label_data
+ )
+
+ docs = [eg.predicted for eg in train_examples]
+
+ for i in range(200):
+ losses = {}
+ student_parser.distill(teacher_parser, docs, docs, sgd=optimizer, losses=losses)
+ assert losses["parser"] < 0.0001
+
+ test_text = "I like securities."
+ doc = student(test_text)
+ assert doc[0].dep_ == "nsubj"
+ assert doc[2].dep_ == "dobj"
+ assert doc[3].dep_ == "punct"
+ assert doc[0].head.i == 1
+ assert doc[2].head.i == 1
+ assert doc[3].head.i == 1
+
+
# fmt: off
@pytest.mark.slow
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
index 5eeb55aa2..99bd06dce 100644
--- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
+++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
@@ -195,6 +195,45 @@ def test_overfitting_IO():
assert doc4[3].lemma_ == "egg"
+def test_distill():
+ teacher = English()
+ teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer")
+ teacher_lemmatizer.min_tree_freq = 1
+ train_examples = []
+ for t in TRAIN_DATA:
+ train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))
+
+ optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+ for i in range(50):
+ losses = {}
+ teacher.update(train_examples, sgd=optimizer, losses=losses)
+ assert losses["trainable_lemmatizer"] < 0.00001
+
+ student = English()
+ student_lemmatizer = student.add_pipe("trainable_lemmatizer")
+ student_lemmatizer.min_tree_freq = 1
+ student_lemmatizer.initialize(
+ get_examples=lambda: train_examples, labels=teacher_lemmatizer.label_data
+ )
+
+ docs = [eg.predicted for eg in train_examples]
+
+ for i in range(50):
+ losses = {}
+ student_lemmatizer.distill(
+ teacher_lemmatizer, docs, docs, sgd=optimizer, losses=losses
+ )
+ assert losses["trainable_lemmatizer"] < 0.00001
+
+ test_text = "She likes blue eggs"
+ doc = student(test_text)
+ assert doc[0].lemma_ == "she"
+ assert doc[1].lemma_ == "like"
+ assert doc[2].lemma_ == "blue"
+ assert doc[3].lemma_ == "egg"
+
+
def test_lemmatizer_requires_labels():
nlp = English()
nlp.add_pipe("trainable_lemmatizer")
diff --git a/spacy/tests/pipeline/test_tagger.py b/spacy/tests/pipeline/test_tagger.py
index a0c71198e..8b5226053 100644
--- a/spacy/tests/pipeline/test_tagger.py
+++ b/spacy/tests/pipeline/test_tagger.py
@@ -213,6 +213,42 @@ def test_overfitting_IO():
assert doc3[0].tag_ != "N"
+def test_distill():
+ teacher = English()
+ teacher_tagger = teacher.add_pipe("tagger")
+ train_examples = []
+ for t in TRAIN_DATA:
+ train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))
+
+ optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+ for i in range(50):
+ losses = {}
+ teacher.update(train_examples, sgd=optimizer, losses=losses)
+ assert losses["tagger"] < 0.00001
+
+ student = English()
+ student_tagger = student.add_pipe("tagger")
+ student_tagger.initialize(
+ get_examples=lambda: train_examples, labels=teacher_tagger.label_data
+ )
+
+ docs = [eg.predicted for eg in train_examples]
+
+ for i in range(50):
+ losses = {}
+ student_tagger.distill(teacher_tagger, docs, docs, sgd=optimizer, losses=losses)
+ assert losses["tagger"] < 0.00001
+
+ test_text = "I like blue eggs"
+ doc = student(test_text)
+ assert doc[0].tag_ == "N"
+ assert doc[1].tag_ == "V"
+ assert doc[2].tag_ == "J"
+ assert doc[3].tag_ == "N"
+
+
def test_save_activations():
# Test if activations are correctly added to Doc when requested.
nlp = English()
diff --git a/website/docs/api/dependencyparser.md b/website/docs/api/dependencyparser.md
index c30d39b57..a69f8a681 100644
--- a/website/docs/api/dependencyparser.md
+++ b/website/docs/api/dependencyparser.md
@@ -268,6 +268,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. ~~StateClass~~ |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+## DependencyParser.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_parser = teacher.get_pipe("parser")
+> student_parser = student.add_pipe("parser")
+> student_scores = student_parser.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_parser.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_parser.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## DependencyParser.create_optimizer {#create_optimizer tag="method"}
Create an [`Optimizer`](https://thinc.ai/docs/api-optimizers) for the pipeline
diff --git a/website/docs/api/edittreelemmatizer.md b/website/docs/api/edittreelemmatizer.md
index 8bee74316..b5ae0cae5 100644
--- a/website/docs/api/edittreelemmatizer.md
+++ b/website/docs/api/edittreelemmatizer.md
@@ -269,6 +269,27 @@ Create an optimizer for the pipeline component.
| ----------- | ---------------------------- |
| **RETURNS** | The optimizer. ~~Optimizer~~ |
+## EditTreeLemmatizer.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_lemmatizer = teacher.get_pipe("trainable_lemmatizer")
+> student_lemmatizer = student.add_pipe("trainable_lemmatizer")
+> student_scores = student_lemmatizer.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_lemmatizer.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_lemmatizer.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## EditTreeLemmatizer.use_params {#use_params tag="method, contextmanager"}
Modify the pipe's model, to use the given parameter values. At the end of the
diff --git a/website/docs/api/morphologizer.md b/website/docs/api/morphologizer.md
index 97444b157..ba63a0690 100644
--- a/website/docs/api/morphologizer.md
+++ b/website/docs/api/morphologizer.md
@@ -259,6 +259,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+## Morphologizer.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_morphologizer = teacher.get_pipe("morphologizer")
+> student_morphologizer = student.add_pipe("morphologizer")
+> student_scores = student_morphologizer.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_morphologizer.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_morphologizer.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## Morphologizer.create_optimizer {#create_optimizer tag="method"}
Create an optimizer for the pipeline component.
diff --git a/website/docs/api/pipe.md b/website/docs/api/pipe.md
index 70a4648b6..287488a03 100644
--- a/website/docs/api/pipe.md
+++ b/website/docs/api/pipe.md
@@ -234,6 +234,33 @@ predictions and gold-standard annotations, and update the component's model.
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+## TrainablePipe.distill {#distill tag="method,experimental" new="4"}
+
+Train a pipe (the student) on the predictions of another pipe (the teacher). The
+student is typically trained on the probability distribution of the teacher, but
+details may differ per pipe. The goal of distillation is to transfer knowledge
+from the teacher to the student. This feature is experimental.
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.get_pipe("your_custom_pipe")
+> student_pipe = student.add_pipe("your_custom_pipe")
+> optimizer = student.resume_training()
+> losses = student_pipe.distill(teacher_pipe, teacher_docs, student_docs, sgd=optimizer)
+> ```
+
+| Name | Description |
+| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `teacher_docs` | Documents passed through teacher pipes. ~~Iterable[Doc]~~ |
+| `student_docs` | Documents passed through student pipes. Must contain the same tokens as `teacher_docs` but may have different annotations. ~~Iterable[Doc]~~ |
+| _keyword-only_ | |
+| `drop` | Dropout rate. ~~float~~ |
+| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+
## TrainablePipe.rehearse {#rehearse tag="method,experimental" new="3"}
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
@@ -281,6 +308,33 @@ This method needs to be overwritten with your own custom `get_loss` method.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+## TrainablePipe.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+<Infobox variant="danger">
+
+This method needs to be overwritten with your own custom
+`get_teacher_student_loss` method.
+
+</Infobox>
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.get_pipe("your_custom_pipe")
+> student_pipe = student.add_pipe("your_custom_pipe")
+> student_scores = student_pipe.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_pipe.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_pipe.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## TrainablePipe.score {#score tag="method" new="3"}
Score a batch of examples.
diff --git a/website/docs/api/sentencerecognizer.md b/website/docs/api/sentencerecognizer.md
index 03744e1b5..1de271094 100644
--- a/website/docs/api/sentencerecognizer.md
+++ b/website/docs/api/sentencerecognizer.md
@@ -254,6 +254,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+## SentenceRecognizer.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_senter = teacher.get_pipe("senter")
+> student_senter = student.add_pipe("senter")
+> student_scores = student_senter.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_senter.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_senter.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## SentenceRecognizer.create_optimizer {#create_optimizer tag="method"}
Create an optimizer for the pipeline component.
diff --git a/website/docs/api/tagger.md b/website/docs/api/tagger.md
index 102793377..9f9279391 100644
--- a/website/docs/api/tagger.md
+++ b/website/docs/api/tagger.md
@@ -265,6 +265,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+## Tagger.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_tagger = teacher.get_pipe("tagger")
+> student_tagger = student.add_pipe("tagger")
+> student_scores = student_tagger.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_tagger.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_tagger.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## Tagger.create_optimizer {#create_optimizer tag="method"}
Create an optimizer for the pipeline component.
diff --git a/website/docs/api/top-level.md b/website/docs/api/top-level.md
index 26a5d42f4..19c497f81 100644
--- a/website/docs/api/top-level.md
+++ b/website/docs/api/top-level.md
@@ -899,7 +899,8 @@ backprop passes.
Recursively wrap both the models and methods of each pipe using
[NVTX](https://nvidia.github.io/NVTX/) range markers. By default, the following
methods are wrapped: `pipe`, `predict`, `set_annotations`, `update`, `rehearse`,
-`get_loss`, `initialize`, `begin_update`, `finish_update`, `update`.
+`get_loss`, `get_teacher_student_loss`, `initialize`, `begin_update`,
+`finish_update`, `update`.
| Name | Description |
| --------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
diff --git a/website/docs/usage/processing-pipelines.md b/website/docs/usage/processing-pipelines.md
index b3940458b..7e69dbdee 100644
--- a/website/docs/usage/processing-pipelines.md
+++ b/website/docs/usage/processing-pipelines.md
@@ -1369,12 +1369,14 @@ For some use cases, it makes sense to also overwrite additional methods to
customize how the model is updated from examples, how it's initialized, how the
loss is calculated and to add evaluation scores to the training output.
-| Name | Description |
-| ------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| [`update`](/api/pipe#update) | Learn from a batch of [`Example`](/api/example) objects containing the predictions and gold-standard annotations, and update the component's model. |
-| [`initialize`](/api/pipe#initialize) | Initialize the model. Typically calls into [`Model.initialize`](https://thinc.ai/docs/api-model#initialize) and can be passed custom arguments via the [`[initialize]`](/api/data-formats#config-initialize) config block that are only loaded during training or when you call [`nlp.initialize`](/api/language#initialize), not at runtime. |
-| [`get_loss`](/api/pipe#get_loss) | Return a tuple of the loss and the gradient for a batch of [`Example`](/api/example) objects. |
-| [`score`](/api/pipe#score) | Score a batch of [`Example`](/api/example) objects and return a dictionary of scores. The [`@Language.factory`](/api/language#factory) decorator can define the `default_score_weights` of the component to decide which keys of the scores to display during training and how they count towards the final score. |
+| Name | Description |
+| ---------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`update`](/api/pipe#update) | Learn from a batch of [`Example`](/api/example) objects containing the predictions and gold-standard annotations, and update the component's model. |
+| [`distill`](/api/pipe#distill) | Learn from a teacher pipeline using a batch of [`Doc`](/api/doc) objects and update the component's model. |
+| [`initialize`](/api/pipe#initialize) | Initialize the model. Typically calls into [`Model.initialize`](https://thinc.ai/docs/api-model#initialize) and can be passed custom arguments via the [`[initialize]`](/api/data-formats#config-initialize) config block that are only loaded during training or when you call [`nlp.initialize`](/api/language#initialize), not at runtime. |
+| [`get_loss`](/api/pipe#get_loss) | Return a tuple of the loss and the gradient for a batch of [`Example`](/api/example) objects. |
+| [`get_teacher_student_loss`](/api/pipe#get_teacher_student_loss) | Return a tuple of the loss and the gradient for the student scores relative to the teacher scores. |
+| [`score`](/api/pipe#score) | Score a batch of [`Example`](/api/example) objects and return a dictionary of scores. The [`@Language.factory`](/api/language#factory) decorator can define the `default_score_weights` of the component to decide which keys of the scores to display during training and how they count towards the final score. |
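+
+For example, a custom trainable component can support the default
+[`distill`](/api/pipe#distill) implementation by providing
+`get_teacher_student_loss`, to which the loss computation is delegated. The
+sketch below shows one minimal way to implement it with a sequence
+cross-entropy loss, assuming the component produces one score matrix per
+`Doc`; the `MyPipe` class is a hypothetical placeholder, not part of the spaCy
+API.
+
+```python
+from typing import List, Tuple
+from thinc.api import LegacySequenceCategoricalCrossentropy
+from thinc.types import Floats2d
+
+from spacy.errors import Errors
+from spacy.pipeline import TrainablePipe
+
+
+class MyPipe(TrainablePipe):
+    def get_teacher_student_loss(
+        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+    ) -> Tuple[float, List[Floats2d]]:
+        # The teacher's predictions act as soft labels for the student.
+        loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
+        d_scores, loss = loss_func(student_scores, teacher_scores)
+        if self.model.ops.xp.isnan(loss):
+            raise ValueError(Errors.E910.format(name=self.name))
+        return float(loss), d_scores
+```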