diff --git a/spacy/errors.py b/spacy/errors.py
index 7cfc2423d..a6ec4f6a0 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -957,7 +957,10 @@ class Errors(metaclass=ErrorsWithCodes):
E4000 = ("Expected a Doc as input, but got: '{type}'")
E4001 = ("Expected input to be one of the following types: ({expected_types}), "
"but got '{received_type}'")
- E4002 = ("Backprop is not supported when is_train is not set.")
+ E4002 = ("Pipe '{name}' requires a teacher pipe for distillation.")
+ E4003 = ("Training examples for distillation must have the exact same tokens in the "
+ "reference and predicted docs.")
+ E4004 = ("Backprop is not supported when is_train is not set.")
# fmt: on
diff --git a/spacy/ml/callbacks.py b/spacy/ml/callbacks.py
index 3b60ec2ab..393f208a6 100644
--- a/spacy/ml/callbacks.py
+++ b/spacy/ml/callbacks.py
@@ -23,6 +23,7 @@ DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS = [
"update",
"rehearse",
"get_loss",
+ "get_teacher_student_loss",
"initialize",
"begin_update",
"finish_update",
diff --git a/spacy/ml/tb_framework.pyx b/spacy/ml/tb_framework.pyx
index f0316d8f9..79be13b00 100644
--- a/spacy/ml/tb_framework.pyx
+++ b/spacy/ml/tb_framework.pyx
@@ -236,7 +236,7 @@ def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[State
scores = _parse_batch(cblas, moves, &c_states[0], weights, sizes, actions=actions)
def backprop(dY):
- raise ValueError(Errors.E4002)
+ raise ValueError(Errors.E4004)
return (states, scores), backprop
diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py
index 2a2242aa4..20f83fffc 100644
--- a/spacy/pipeline/edit_tree_lemmatizer.py
+++ b/spacy/pipeline/edit_tree_lemmatizer.py
@@ -155,6 +155,25 @@ class EditTreeLemmatizer(TrainablePipe):
return float(loss), d_scores
+ def get_teacher_student_loss(
+ self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+ ) -> Tuple[float, List[Floats2d]]:
+ """Calculate the loss and its gradient for a batch of student
+ scores, relative to teacher scores.
+
+ teacher_scores: Scores representing the teacher model's predictions.
+ student_scores: Scores representing the student model's predictions.
+
+ RETURNS (Tuple[float, float]): The loss and the gradient.
+
+ DOCS: https://spacy.io/api/edittreelemmatizer#get_teacher_student_loss
+ """
+ loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
+ d_scores, loss = loss_func(student_scores, teacher_scores)
+ if self.model.ops.xp.isnan(loss):
+ raise ValueError(Errors.E910.format(name=self.name))
+ return float(loss), d_scores
+
def predict(self, docs: Iterable[Doc]) -> ActivationsT:
n_docs = len(list(docs))
if not any(len(doc) for doc in docs):
diff --git a/spacy/pipeline/pipe.pyx b/spacy/pipeline/pipe.pyx
index c5650382b..8b8fdc361 100644
--- a/spacy/pipeline/pipe.pyx
+++ b/spacy/pipeline/pipe.pyx
@@ -87,6 +87,10 @@ cdef class Pipe:
return self.scorer(examples, **scorer_kwargs)
return {}
+ @property
+ def is_distillable(self) -> bool:
+ return False
+
@property
def is_trainable(self) -> bool:
return False
diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx
index e12f116af..a6be51c3c 100644
--- a/spacy/pipeline/tagger.pyx
+++ b/spacy/pipeline/tagger.pyx
@@ -1,5 +1,6 @@
# cython: infer_types=True, profile=True, binding=True
from typing import Callable, Dict, Iterable, List, Optional, Union
+from typing import Tuple
import numpy
import srsly
from thinc.api import Model, set_dropout_rate, Config
@@ -245,7 +246,6 @@ class Tagger(TrainablePipe):
DOCS: https://spacy.io/api/tagger#rehearse
"""
- loss_func = LegacySequenceCategoricalCrossentropy()
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
@@ -259,12 +259,32 @@ class Tagger(TrainablePipe):
set_dropout_rate(self.model, drop)
tag_scores, bp_tag_scores = self.model.begin_update(docs)
tutor_tag_scores, _ = self._rehearsal_model.begin_update(docs)
- grads, loss = loss_func(tag_scores, tutor_tag_scores)
+ loss, grads = self.get_teacher_student_loss(tutor_tag_scores, tag_scores)
bp_tag_scores(grads)
- self.finish_update(sgd)
+ if sgd is not None:
+ self.finish_update(sgd)
losses[self.name] += loss
return losses
+ def get_teacher_student_loss(
+ self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+ ) -> Tuple[float, List[Floats2d]]:
+ """Calculate the loss and its gradient for a batch of student
+ scores, relative to teacher scores.
+
+ teacher_scores: Scores representing the teacher model's predictions.
+ student_scores: Scores representing the student model's predictions.
+
+ RETURNS (Tuple[float, float]): The loss and the gradient.
+
+ DOCS: https://spacy.io/api/tagger#get_teacher_student_loss
+ """
+ loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
+ d_scores, loss = loss_func(student_scores, teacher_scores)
+ if self.model.ops.xp.isnan(loss):
+ raise ValueError(Errors.E910.format(name=self.name))
+ return float(loss), d_scores
+
def get_loss(self, examples, scores):
"""Find the loss and gradient of loss for the batch of documents and
their predicted scores.
diff --git a/spacy/pipeline/trainable_pipe.pyx b/spacy/pipeline/trainable_pipe.pyx
index 5bba34e4a..77259fc0b 100644
--- a/spacy/pipeline/trainable_pipe.pyx
+++ b/spacy/pipeline/trainable_pipe.pyx
@@ -6,7 +6,7 @@ import warnings
from ..tokens.doc cimport Doc
-from ..training import validate_examples
+from ..training import validate_examples, validate_distillation_examples
from ..errors import Errors, Warnings
from .pipe import Pipe, deserialize_config
from .. import util
@@ -56,6 +56,53 @@ cdef class TrainablePipe(Pipe):
except Exception as e:
error_handler(self.name, self, [doc], e)
+
+ def distill(self,
+ teacher_pipe: Optional["TrainablePipe"],
+ examples: Iterable["Example"],
+ *,
+ drop: float=0.0,
+ sgd: Optional[Optimizer]=None,
+ losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
+ """Train a pipe (the student) on the predictions of another pipe
+ (the teacher). The student is typically trained on the probability
+ distribution of the teacher, but details may differ per pipe.
+
+ teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
+ from.
+ examples (Iterable[Example]): Distillation examples. The reference
+ and predicted docs must have the same number of tokens and the
+ same orthography.
+ drop (float): dropout rate.
+ sgd (Optional[Optimizer]): An optimizer. Will be created via
+ create_optimizer if not set.
+ losses (Optional[Dict[str, float]]): Optional record of loss during
+ distillation.
+ RETURNS: The updated losses dictionary.
+
+ DOCS: https://spacy.io/api/pipe#distill
+ """
+ # By default we require a teacher pipe, but there are downstream
+ # implementations that don't require a pipe.
+ if teacher_pipe is None:
+ raise ValueError(Errors.E4002.format(name=self.name))
+ if losses is None:
+ losses = {}
+ losses.setdefault(self.name, 0.0)
+ validate_distillation_examples(examples, "TrainablePipe.distill")
+ set_dropout_rate(self.model, drop)
+ for node in teacher_pipe.model.walk():
+ if node.name == "softmax":
+ node.attrs["softmax_normalize"] = True
+ teacher_scores = teacher_pipe.model.predict([eg.reference for eg in examples])
+ student_scores, bp_student_scores = self.model.begin_update([eg.predicted for eg in examples])
+ loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
+ bp_student_scores(d_scores)
+ if sgd is not None:
+ self.finish_update(sgd)
+ losses[self.name] += loss
+ return losses
+
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
@@ -169,6 +216,19 @@ cdef class TrainablePipe(Pipe):
"""
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_loss", name=self.name))
+ def get_teacher_student_loss(self, teacher_scores, student_scores):
+ """Calculate the loss and its gradient for a batch of student
+ scores, relative to teacher scores.
+
+ teacher_scores: Scores representing the teacher model's predictions.
+ student_scores: Scores representing the student model's predictions.
+
+ RETURNS (Tuple[float, float]): The loss and the gradient.
+
+ DOCS: https://spacy.io/api/pipe#get_teacher_student_loss
+ """
+ raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_teacher_student_loss", name=self.name))
+
def create_optimizer(self) -> Optimizer:
"""Create an optimizer for the pipeline component.
@@ -205,6 +265,14 @@ cdef class TrainablePipe(Pipe):
"""
raise NotImplementedError(Errors.E931.format(parent="Pipe", method="add_label", name=self.name))
+ @property
+ def is_distillable(self) -> bool:
+ # Normally a pipe overrides `get_teacher_student_loss` to implement
+ # distillation. In more exceptional cases, a pipe can provide its
+ # own `distill` implementation. If neither of these methods is
+ # overridden, the pipe does not implement distillation.
+ return not (self.__class__.distill is TrainablePipe.distill and self.__class__.get_teacher_student_loss is TrainablePipe.get_teacher_student_loss)
+
@property
def is_trainable(self) -> bool:
return True
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index e6119ee79..09b5f6181 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -1,6 +1,6 @@
# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
from __future__ import print_function
-from typing import List
+from typing import Dict, Iterable, List, Optional, Tuple
from cymem.cymem cimport Pool
cimport numpy as np
from itertools import islice
@@ -11,9 +11,10 @@ import random
import contextlib
import srsly
-from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps
-from thinc.api import get_array_module
-from thinc.types import Ints1d
+from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps, Optimizer
+from thinc.api import chain, softmax_activation, use_ops, get_array_module
+from thinc.legacy import LegacySequenceCategoricalCrossentropy
+from thinc.types import Floats2d, Ints1d
import numpy.random
import numpy
import warnings
@@ -30,6 +31,7 @@ from ._parser_internals.transition_system cimport Transition, TransitionSystem
from ..typedefs cimport weight_t
from ..training import validate_examples, validate_get_examples
+from ..training import validate_distillation_examples
from ..errors import Errors, Warnings
from .. import util
@@ -208,6 +210,121 @@ class Parser(TrainablePipe):
# Defined in subclasses, to avoid circular import
raise NotImplementedError
+ def distill(self,
+ teacher_pipe: Optional[TrainablePipe],
+ examples: Iterable["Example"],
+ *,
+ drop: float=0.0,
+ sgd: Optional[Optimizer]=None,
+ losses: Optional[Dict[str, float]]=None):
+ """Train a pipe (the student) on the predictions of another pipe
+ (the teacher). The student is trained on the transition probabilities
+ of the teacher.
+
+ teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
+ from.
+ examples (Iterable[Example]): Distillation examples. The reference
+ and predicted docs must have the same number of tokens and the
+ same orthography.
+ drop (float): dropout rate.
+ sgd (Optional[Optimizer]): An optimizer. Will be created via
+ create_optimizer if not set.
+ losses (Optional[Dict[str, float]]): Optional record of loss during
+ distillation.
+ RETURNS: The updated losses dictionary.
+
+ DOCS: https://spacy.io/api/dependencyparser#distill
+ """
+ if teacher_pipe is None:
+ raise ValueError(Errors.E4002.format(name=self.name))
+ if losses is None:
+ losses = {}
+ losses.setdefault(self.name, 0.0)
+
+ validate_distillation_examples(examples, "TransitionParser.distill")
+
+ set_dropout_rate(self.model, drop)
+
+ student_docs = [eg.predicted for eg in examples]
+
+ teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples])
+ student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
+
+ # Add softmax activation, so that we can compute student losses
+ # with cross-entropy loss.
+ with use_ops("numpy"):
+ teacher_model = chain(teacher_step_model, softmax_activation())
+ student_model = chain(student_step_model, softmax_activation())
+
+ max_moves = self.cfg["update_with_oracle_cut_size"]
+ if max_moves >= 1:
+ # Chop sequences into lengths of this many words, to make the
+ # batch uniform length. Since we do not have a gold standard
+ # sequence, we use the teacher's predictions as the gold
+ # standard.
+ max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
+ states = self._init_batch(teacher_step_model, student_docs, max_moves)
+ else:
+ states = self.moves.init_batch(student_docs)
+
+ loss = 0.0
+ n_moves = 0
+ while states:
+ # We do distillation as follows: (1) for every state, we compute the
+ # transition softmax distributions: (2) we backpropagate the error of
+ # the student (compared to the teacher) into the student model; (3)
+ # for all states, we move to the next state using the student's
+ # predictions.
+ teacher_scores = teacher_model.predict(states)
+ student_scores, backprop = student_model.begin_update(states)
+ state_loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
+ backprop(d_scores)
+ loss += state_loss
+ self.transition_states(states, student_scores)
+ states = [state for state in states if not state.is_final()]
+
+ # Stop when we reach the maximum number of moves, otherwise we start
+ # to process the remainder of cut sequences again.
+ if max_moves >= 1 and n_moves >= max_moves:
+ break
+ n_moves += 1
+
+ backprop_tok2vec(student_docs)
+
+ if sgd is not None:
+ self.finish_update(sgd)
+
+ losses[self.name] += loss
+
+ del backprop
+ del backprop_tok2vec
+ teacher_step_model.clear_memory()
+ student_step_model.clear_memory()
+ del teacher_model
+ del student_model
+
+ return losses
+
+
+ def get_teacher_student_loss(
+ self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+ ) -> Tuple[float, List[Floats2d]]:
+ """Calculate the loss and its gradient for a batch of student
+ scores, relative to teacher scores.
+
+ teacher_scores: Scores representing the teacher model's predictions.
+ student_scores: Scores representing the student model's predictions.
+
+ RETURNS (Tuple[float, float]): The loss and the gradient.
+
+ DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
+ """
+ loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
+ d_scores, loss = loss_func(student_scores, teacher_scores)
+ if self.model.ops.xp.isnan(loss):
+ raise ValueError(Errors.E910.format(name=self.name))
+ return float(loss), d_scores
+
def init_multitask_objectives(self, get_examples, pipeline, **cfg):
"""Setup models for secondary objectives, to benefit from multi-task
learning. This method is intended to be overridden by subclasses.
@@ -526,6 +643,40 @@ class Parser(TrainablePipe):
raise ValueError(Errors.E149) from None
return self
+ def _init_batch(self, teacher_step_model, docs, max_length):
+ """Make a square batch of length equal to the shortest transition
+ sequence or a cap. A long
+ doc will get multiple states. Let's say we have a doc of length 2*N,
+ where N is the shortest doc. We'll make two states, one representing
+ long_doc[:N], and another representing long_doc[N:]. In contrast to
+ _init_gold_batch, this version uses a teacher model to generate the
+ cut sequences."""
+ cdef:
+ StateClass start_state
+ StateClass state
+ Transition action
+ all_states = self.moves.init_batch(docs)
+ states = []
+ to_cut = []
+ for state, doc in zip(all_states, docs):
+ if not state.is_final():
+ if len(doc) < max_length:
+ states.append(state)
+ else:
+ to_cut.append(state)
+ while to_cut:
+ states.extend(state.copy() for state in to_cut)
+ # Move states forward max_length actions.
+ length = 0
+ while to_cut and length < max_length:
+ teacher_scores = teacher_step_model.predict(to_cut)
+ self.transition_states(to_cut, teacher_scores)
+ # States that are completed do not need further cutting.
+ to_cut = [state for state in to_cut if not state.is_final()]
+ length += 1
+ return states
+
+
def _init_gold_batch(self, examples, max_length):
"""Make a square batch, of length equal to the shortest transition
sequence or a cap. A long doc will get multiple states. Let's say we
diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py
index 5434a2fe7..d6cd11e55 100644
--- a/spacy/tests/parser/test_ner.py
+++ b/spacy/tests/parser/test_ner.py
@@ -617,6 +617,52 @@ def test_overfitting_IO():
assert ents[1].kb_id == 0
+def test_is_distillable():
+ nlp = English()
+ ner = nlp.add_pipe("ner")
+ assert ner.is_distillable
+
+
+def test_distill():
+ teacher = English()
+ teacher_ner = teacher.add_pipe("ner")
+ train_examples = []
+ for text, annotations in TRAIN_DATA:
+ train_examples.append(Example.from_dict(teacher.make_doc(text), annotations))
+ for ent in annotations.get("entities"):
+ teacher_ner.add_label(ent[2])
+
+ optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+ for i in range(50):
+ losses = {}
+ teacher.update(train_examples, sgd=optimizer, losses=losses)
+ assert losses["ner"] < 0.00001
+
+ student = English()
+ student_ner = student.add_pipe("ner")
+ student_ner.initialize(
+ get_examples=lambda: train_examples, labels=teacher_ner.label_data
+ )
+
+ distill_examples = [
+ Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+ ]
+
+ for i in range(100):
+ losses = {}
+ student_ner.distill(teacher_ner, distill_examples, sgd=optimizer, losses=losses)
+ assert losses["ner"] < 0.0001
+
+ # test the trained model
+ test_text = "I like London."
+ doc = student(test_text)
+ ents = doc.ents
+ assert len(ents) == 1
+ assert ents[0].text == "London"
+ assert ents[0].label_ == "LOC"
+
+
def test_beam_ner_scores():
# Test that we can get confidence values out of the beam_ner pipe
beam_width = 16
diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py
index df463b700..6dfb8a18c 100644
--- a/spacy/tests/parser/test_parse.py
+++ b/spacy/tests/parser/test_parse.py
@@ -457,6 +457,55 @@ def test_overfitting_IO(pipe_name, max_moves):
assert_equal(batch_deps_1, no_batch_deps)
+def test_is_distillable():
+ nlp = English()
+ parser = nlp.add_pipe("parser")
+ assert parser.is_distillable
+
+
+def test_distill():
+ teacher = English()
+ teacher_parser = teacher.add_pipe("parser")
+ train_examples = []
+ for text, annotations in TRAIN_DATA:
+ train_examples.append(Example.from_dict(teacher.make_doc(text), annotations))
+ for dep in annotations.get("deps", []):
+ teacher_parser.add_label(dep)
+
+ optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+ for i in range(200):
+ losses = {}
+ teacher.update(train_examples, sgd=optimizer, losses=losses)
+ assert losses["parser"] < 0.0001
+
+ student = English()
+ student_parser = student.add_pipe("parser")
+ student_parser.initialize(
+ get_examples=lambda: train_examples, labels=teacher_parser.label_data
+ )
+
+ distill_examples = [
+ Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+ ]
+
+ for i in range(200):
+ losses = {}
+ student_parser.distill(
+ teacher_parser, distill_examples, sgd=optimizer, losses=losses
+ )
+ assert losses["parser"] < 0.0001
+
+ test_text = "I like securities."
+ doc = student(test_text)
+ assert doc[0].dep_ == "nsubj"
+ assert doc[2].dep_ == "dobj"
+ assert doc[3].dep_ == "punct"
+ assert doc[0].head.i == 1
+ assert doc[2].head.i == 1
+ assert doc[3].head.i == 1
+
+
# fmt: off
@pytest.mark.slow
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
index 5eeb55aa2..b855c7a26 100644
--- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
+++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
@@ -195,6 +195,53 @@ def test_overfitting_IO():
assert doc4[3].lemma_ == "egg"
+def test_is_distillable():
+ nlp = English()
+ lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+ assert lemmatizer.is_distillable
+
+
+def test_distill():
+ teacher = English()
+ teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer")
+ teacher_lemmatizer.min_tree_freq = 1
+ train_examples = []
+ for t in TRAIN_DATA:
+ train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))
+
+ optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+ for i in range(50):
+ losses = {}
+ teacher.update(train_examples, sgd=optimizer, losses=losses)
+ assert losses["trainable_lemmatizer"] < 0.00001
+
+ student = English()
+ student_lemmatizer = student.add_pipe("trainable_lemmatizer")
+ student_lemmatizer.min_tree_freq = 1
+ student_lemmatizer.initialize(
+ get_examples=lambda: train_examples, labels=teacher_lemmatizer.label_data
+ )
+
+ distill_examples = [
+ Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+ ]
+
+ for i in range(50):
+ losses = {}
+ student_lemmatizer.distill(
+ teacher_lemmatizer, distill_examples, sgd=optimizer, losses=losses
+ )
+ assert losses["trainable_lemmatizer"] < 0.00001
+
+ test_text = "She likes blue eggs"
+ doc = student(test_text)
+ assert doc[0].lemma_ == "she"
+ assert doc[1].lemma_ == "like"
+ assert doc[2].lemma_ == "blue"
+ assert doc[3].lemma_ == "egg"
+
+
def test_lemmatizer_requires_labels():
nlp = English()
nlp.add_pipe("trainable_lemmatizer")
diff --git a/spacy/tests/pipeline/test_morphologizer.py b/spacy/tests/pipeline/test_morphologizer.py
index 70fc77304..5b9b17c01 100644
--- a/spacy/tests/pipeline/test_morphologizer.py
+++ b/spacy/tests/pipeline/test_morphologizer.py
@@ -50,6 +50,12 @@ def test_implicit_label():
nlp.initialize(get_examples=lambda: train_examples)
+def test_is_distillable():
+ nlp = English()
+ morphologizer = nlp.add_pipe("morphologizer")
+ assert morphologizer.is_distillable
+
+
def test_no_resize():
nlp = Language()
morphologizer = nlp.add_pipe("morphologizer")
diff --git a/spacy/tests/pipeline/test_senter.py b/spacy/tests/pipeline/test_senter.py
index 3deac9e9a..a771d62fa 100644
--- a/spacy/tests/pipeline/test_senter.py
+++ b/spacy/tests/pipeline/test_senter.py
@@ -11,6 +11,12 @@ from spacy.pipeline import TrainablePipe
from spacy.tests.util import make_tempdir
+def test_is_distillable():
+ nlp = English()
+ senter = nlp.add_pipe("senter")
+ assert senter.is_distillable
+
+
def test_label_types():
nlp = Language()
senter = nlp.add_pipe("senter")
diff --git a/spacy/tests/pipeline/test_tagger.py b/spacy/tests/pipeline/test_tagger.py
index a0c71198e..344859f8d 100644
--- a/spacy/tests/pipeline/test_tagger.py
+++ b/spacy/tests/pipeline/test_tagger.py
@@ -213,6 +213,52 @@ def test_overfitting_IO():
assert doc3[0].tag_ != "N"
+def test_is_distillable():
+ nlp = English()
+ tagger = nlp.add_pipe("tagger")
+ assert tagger.is_distillable
+
+
+def test_distill():
+ teacher = English()
+ teacher_tagger = teacher.add_pipe("tagger")
+ train_examples = []
+ for t in TRAIN_DATA:
+ train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))
+
+ optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+ for i in range(50):
+ losses = {}
+ teacher.update(train_examples, sgd=optimizer, losses=losses)
+ assert losses["tagger"] < 0.00001
+
+ student = English()
+ student_tagger = student.add_pipe("tagger")
+ student_tagger.min_tree_freq = 1
+ student_tagger.initialize(
+ get_examples=lambda: train_examples, labels=teacher_tagger.label_data
+ )
+
+ distill_examples = [
+ Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+ ]
+
+ for i in range(50):
+ losses = {}
+ student_tagger.distill(
+ teacher_tagger, distill_examples, sgd=optimizer, losses=losses
+ )
+ assert losses["tagger"] < 0.00001
+
+ test_text = "I like blue eggs"
+ doc = student(test_text)
+ assert doc[0].tag_ == "N"
+ assert doc[1].tag_ == "V"
+ assert doc[2].tag_ == "J"
+ assert doc[3].tag_ == "N"
+
+
def test_save_activations():
# Test if activations are correctly added to Doc when requested.
nlp = English()
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index 304209933..9c0eeb171 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -565,6 +565,12 @@ def test_initialize_examples(name, get_examples, train_data):
nlp.initialize(get_examples=get_examples())
+def test_is_distillable():
+ nlp = English()
+ textcat = nlp.add_pipe("textcat")
+ assert not textcat.is_distillable
+
+
def test_overfitting_IO():
# Simple test to try and quickly overfit the single-label textcat component - ensuring the ML models work correctly
fix_random_seed(0)
diff --git a/spacy/tests/training/test_training.py b/spacy/tests/training/test_training.py
index 7933ea31f..9fdd416b1 100644
--- a/spacy/tests/training/test_training.py
+++ b/spacy/tests/training/test_training.py
@@ -8,7 +8,7 @@ from spacy.lang.en import English
from spacy.tokens import Doc, DocBin
from spacy.training import Alignment, Corpus, Example, biluo_tags_to_offsets
from spacy.training import biluo_tags_to_spans, docs_to_json, iob_to_biluo
-from spacy.training import offsets_to_biluo_tags
+from spacy.training import offsets_to_biluo_tags, validate_distillation_examples
from spacy.training.alignment_array import AlignmentArray
from spacy.training.align import get_alignments
from spacy.training.converters import json_to_docs
@@ -365,6 +365,19 @@ def test_example_from_dict_some_ner(en_vocab):
assert ner_tags == ["U-LOC", None, None, None]
+def test_validate_distillation_examples(en_vocab):
+ words = ["a", "b", "c", "d"]
+ spaces = [True, True, False, True]
+ predicted = Doc(en_vocab, words=words, spaces=spaces)
+
+ example = Example.from_dict(predicted, {})
+ validate_distillation_examples([example], "test_validate_distillation_examples")
+
+ example = Example.from_dict(predicted, {"words": words + ["e"]})
+ with pytest.raises(ValueError, match=r"distillation"):
+ validate_distillation_examples([example], "test_validate_distillation_examples")
+
+
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_json_to_docs_no_ner(en_vocab):
data = [
diff --git a/spacy/training/__init__.py b/spacy/training/__init__.py
index 71d1fa775..454437104 100644
--- a/spacy/training/__init__.py
+++ b/spacy/training/__init__.py
@@ -1,5 +1,6 @@
from .corpus import Corpus, JsonlCorpus # noqa: F401
from .example import Example, validate_examples, validate_get_examples # noqa: F401
+from .example import validate_distillation_examples # noqa: F401
from .alignment import Alignment # noqa: F401
from .augment import dont_augment, orth_variants_augmenter # noqa: F401
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx
index 1908bf042..a36fa0d73 100644
--- a/spacy/training/example.pyx
+++ b/spacy/training/example.pyx
@@ -46,6 +46,13 @@ def validate_examples(examples, method):
raise TypeError(err)
+def validate_distillation_examples(examples, method):
+ validate_examples(examples, method)
+ for eg in examples:
+ if [token.text for token in eg.reference] != [token.text for token in eg.predicted]:
+ raise ValueError(Errors.E4003)
+
+
def validate_get_examples(get_examples, method):
"""Check that a generator of a batch of examples received during processing is valid:
the callable produces a non-empty list of Example objects.
diff --git a/website/docs/api/dependencyparser.mdx b/website/docs/api/dependencyparser.mdx
index 771a00aee..5179ce48b 100644
--- a/website/docs/api/dependencyparser.mdx
+++ b/website/docs/api/dependencyparser.mdx
@@ -131,6 +131,39 @@ and all pipeline components are applied to the `Doc` in order. Both
| `doc` | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~ |
+## DependencyParser.distill {id="distill", tag="method,experimental", version="4"}
+
+Train a pipe (the student) on the predictions of another pipe (the teacher). The
+student is typically trained on the probability distribution of the teacher, but
+details may differ per pipe. The goal of distillation is to transfer knowledge
+from the teacher to the student.
+
+The distillation is performed on ~~Example~~ objects. The `Example.reference`
+and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
+same orthography. Even though the reference does not need have to have gold
+annotations, the teacher could adds its own annotations when necessary.
+
+This feature is experimental.
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.add_pipe("parser")
+> student_pipe = student.add_pipe("parser")
+> optimizer = nlp.resume_training()
+> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
+> ```
+
+| Name | Description |
+| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop` | Dropout rate. ~~float~~ |
+| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+
## DependencyParser.pipe {id="pipe",tag="method"}
Apply the pipe to a stream of documents. This usually happens under the hood
@@ -268,6 +301,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. ~~StateClass~~ |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+## DependencyParser.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_parser = teacher.get_pipe("parser")
+> student_parser = student.add_pipe("parser")
+> student_scores = student_parser.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_parser.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_parser.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## DependencyParser.create_optimizer {id="create_optimizer",tag="method"}
Create an [`Optimizer`](https://thinc.ai/docs/api-optimizers) for the pipeline
diff --git a/website/docs/api/edittreelemmatizer.mdx b/website/docs/api/edittreelemmatizer.mdx
index 17af19e8c..2e0993657 100644
--- a/website/docs/api/edittreelemmatizer.mdx
+++ b/website/docs/api/edittreelemmatizer.mdx
@@ -115,6 +115,39 @@ and all pipeline components are applied to the `Doc` in order. Both
| `doc` | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~ |
+## EditTreeLemmatizer.distill {id="distill", tag="method,experimental", version="4"}
+
+Train a pipe (the student) on the predictions of another pipe (the teacher). The
+student is typically trained on the probability distribution of the teacher, but
+details may differ per pipe. The goal of distillation is to transfer knowledge
+from the teacher to the student.
+
+The distillation is performed on ~~Example~~ objects. The `Example.reference`
+and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
+same orthography. Even though the reference does not need have to have gold
+annotations, the teacher could adds its own annotations when necessary.
+
+This feature is experimental.
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.add_pipe("trainable_lemmatizer")
+> student_pipe = student.add_pipe("trainable_lemmatizer")
+> optimizer = nlp.resume_training()
+> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
+> ```
+
+| Name | Description |
+| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop` | Dropout rate. ~~float~~ |
+| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+
## EditTreeLemmatizer.pipe {id="pipe",tag="method"}
Apply the pipe to a stream of documents. This usually happens under the hood
@@ -269,6 +302,27 @@ Create an optimizer for the pipeline component.
| ----------- | ---------------------------- |
| **RETURNS** | The optimizer. ~~Optimizer~~ |
+## EditTreeLemmatizer.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_lemmatizer = teacher.get_pipe("trainable_lemmatizer")
+> student_lemmatizer = student.add_pipe("trainable_lemmatizer")
+> student_scores = student_lemmatizer.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_lemmatizer.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_lemmatizer.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## EditTreeLemmatizer.use_params {id="use_params",tag="method, contextmanager"}
Modify the pipe's model, to use the given parameter values. At the end of the
diff --git a/website/docs/api/entityrecognizer.mdx b/website/docs/api/entityrecognizer.mdx
index 1f386bbb6..005d5d11d 100644
--- a/website/docs/api/entityrecognizer.mdx
+++ b/website/docs/api/entityrecognizer.mdx
@@ -127,6 +127,39 @@ and all pipeline components are applied to the `Doc` in order. Both
| `doc` | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~ |
+## EntityRecognizer.distill {id="distill", tag="method,experimental", version="4"}
+
+Train a pipe (the student) on the predictions of another pipe (the teacher). The
+student is typically trained on the probability distribution of the teacher, but
+details may differ per pipe. The goal of distillation is to transfer knowledge
+from the teacher to the student.
+
+The distillation is performed on ~~Example~~ objects. The `Example.reference`
+and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
+same orthography. Even though the reference does not need have to have gold
+annotations, the teacher could adds its own annotations when necessary.
+
+This feature is experimental.
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.add_pipe("ner")
+> student_pipe = student.add_pipe("ner")
+> optimizer = nlp.resume_training()
+> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
+> ```
+
+| Name | Description |
+| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop` | Dropout rate. ~~float~~ |
+| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+
## EntityRecognizer.pipe {id="pipe",tag="method"}
Apply the pipe to a stream of documents. This usually happens under the hood
@@ -264,6 +297,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. ~~StateClass~~ |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+## EntityRecognizer.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_ner = teacher.get_pipe("ner")
+> student_ner = student.add_pipe("ner")
+> student_scores = student_ner.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_ner.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_ner.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## EntityRecognizer.create_optimizer {id="create_optimizer",tag="method"}
Create an optimizer for the pipeline component.
diff --git a/website/docs/api/morphologizer.mdx b/website/docs/api/morphologizer.mdx
index 1fda807cb..4f79458d3 100644
--- a/website/docs/api/morphologizer.mdx
+++ b/website/docs/api/morphologizer.mdx
@@ -121,6 +121,39 @@ delegate to the [`predict`](/api/morphologizer#predict) and
| `doc` | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~ |
+## Morphologizer.distill {id="distill", tag="method,experimental", version="4"}
+
+Train a pipe (the student) on the predictions of another pipe (the teacher). The
+student is typically trained on the probability distribution of the teacher, but
+details may differ per pipe. The goal of distillation is to transfer knowledge
+from the teacher to the student.
+
+The distillation is performed on ~~Example~~ objects. The `Example.reference`
+and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
+same orthography. Even though the reference does not need have to have gold
+annotations, the teacher could adds its own annotations when necessary.
+
+This feature is experimental.
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.add_pipe("morphologizer")
+> student_pipe = student.add_pipe("morphologizer")
+> optimizer = nlp.resume_training()
+> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
+> ```
+
+| Name | Description |
+| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop` | Dropout rate. ~~float~~ |
+| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+
## Morphologizer.pipe {id="pipe",tag="method"}
Apply the pipe to a stream of documents. This usually happens under the hood
@@ -259,6 +292,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+## Morphologizer.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_morphologizer = teacher.get_pipe("morphologizer")
+> student_morphologizer = student.add_pipe("morphologizer")
+> student_scores = student_morphologizer.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_morphologizer.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_morphologizer.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## Morphologizer.create_optimizer {id="create_optimizer",tag="method"}
Create an optimizer for the pipeline component.
diff --git a/website/docs/api/pipe.mdx b/website/docs/api/pipe.mdx
index b387ea586..120c8f690 100644
--- a/website/docs/api/pipe.mdx
+++ b/website/docs/api/pipe.mdx
@@ -234,6 +234,39 @@ predictions and gold-standard annotations, and update the component's model.
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+## TrainablePipe.distill {id="distill", tag="method,experimental", version="4"}
+
+Train a pipe (the student) on the predictions of another pipe (the teacher). The
+student is typically trained on the probability distribution of the teacher, but
+details may differ per pipe. The goal of distillation is to transfer knowledge
+from the teacher to the student.
+
+The distillation is performed on ~~Example~~ objects. The `Example.reference`
+and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
+same orthography. Even though the reference does not need have to have gold
+annotations, the teacher could adds its own annotations when necessary.
+
+This feature is experimental.
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.add_pipe("your_custom_pipe")
+> student_pipe = student.add_pipe("your_custom_pipe")
+> optimizer = nlp.resume_training()
+> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
+> ```
+
+| Name | Description |
+| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop` | Dropout rate. ~~float~~ |
+| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+
## TrainablePipe.rehearse {id="rehearse",tag="method,experimental",version="3"}
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
@@ -281,6 +314,34 @@ This method needs to be overwritten with your own custom `get_loss` method.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+## TrainablePipe.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+
+
+This method needs to be overwritten with your own custom
+`get_teacher_student_loss` method.
+
+
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.get_pipe("your_custom_pipe")
+> student_pipe = student.add_pipe("your_custom_pipe")
+> student_scores = student_pipe.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_pipe.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_pipe.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## TrainablePipe.score {id="score",tag="method",version="3"}
Score a batch of examples.
diff --git a/website/docs/api/sentencerecognizer.mdx b/website/docs/api/sentencerecognizer.mdx
index d5d096d76..02fd57102 100644
--- a/website/docs/api/sentencerecognizer.mdx
+++ b/website/docs/api/sentencerecognizer.mdx
@@ -106,6 +106,39 @@ and all pipeline components are applied to the `Doc` in order. Both
| `doc` | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~ |
+## SentenceRecognizer.distill {id="distill", tag="method,experimental", version="4"}
+
+Train a pipe (the student) on the predictions of another pipe (the teacher). The
+student is typically trained on the probability distribution of the teacher, but
+details may differ per pipe. The goal of distillation is to transfer knowledge
+from the teacher to the student.
+
+The distillation is performed on ~~Example~~ objects. The `Example.reference`
+and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
+same orthography. Even though the reference does not need have to have gold
+annotations, the teacher could adds its own annotations when necessary.
+
+This feature is experimental.
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.add_pipe("senter")
+> student_pipe = student.add_pipe("senter")
+> optimizer = nlp.resume_training()
+> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
+> ```
+
+| Name | Description |
+| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop` | Dropout rate. ~~float~~ |
+| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+
## SentenceRecognizer.pipe {id="pipe",tag="method"}
Apply the pipe to a stream of documents. This usually happens under the hood
@@ -254,6 +287,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+## SentenceRecognizer.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_senter = teacher.get_pipe("senter")
+> student_senter = student.add_pipe("senter")
+> student_scores = student_senter.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_senter.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_senter.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## SentenceRecognizer.create_optimizer {id="create_optimizer",tag="method"}
Create an optimizer for the pipeline component.
diff --git a/website/docs/api/tagger.mdx b/website/docs/api/tagger.mdx
index ae14df212..664fd7940 100644
--- a/website/docs/api/tagger.mdx
+++ b/website/docs/api/tagger.mdx
@@ -105,6 +105,39 @@ and all pipeline components are applied to the `Doc` in order. Both
| `doc` | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~ |
+## Tagger.distill {id="distill", tag="method,experimental", version="4"}
+
+Train a pipe (the student) on the predictions of another pipe (the teacher). The
+student is typically trained on the probability distribution of the teacher, but
+details may differ per pipe. The goal of distillation is to transfer knowledge
+from the teacher to the student.
+
+The distillation is performed on ~~Example~~ objects. The `Example.reference`
+and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
+same orthography. Even though the reference does not need have to have gold
+annotations, the teacher could adds its own annotations when necessary.
+
+This feature is experimental.
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.add_pipe("tagger")
+> student_pipe = student.add_pipe("tagger")
+> optimizer = nlp.resume_training()
+> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
+> ```
+
+| Name | Description |
+| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop` | Dropout rate. ~~float~~ |
+| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+
## Tagger.pipe {id="pipe",tag="method"}
Apply the pipe to a stream of documents. This usually happens under the hood
@@ -265,6 +298,27 @@ predicted scores.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+## Tagger.get_teacher_student_loss {id="get_teacher_student_loss", tag="method", version="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_tagger = teacher.get_pipe("tagger")
+> student_tagger = student.add_pipe("tagger")
+> student_scores = student_tagger.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_tagger.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_tagger.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name | Description |
+| ---------------- | --------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions. |
+| `student_scores` | Scores representing the student model's predictions. |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+
## Tagger.create_optimizer {id="create_optimizer",tag="method"}
Create an optimizer for the pipeline component.
diff --git a/website/docs/api/top-level.mdx b/website/docs/api/top-level.mdx
index a222cfa8f..7e47d324a 100644
--- a/website/docs/api/top-level.mdx
+++ b/website/docs/api/top-level.mdx
@@ -921,7 +921,8 @@ backprop passes.
Recursively wrap both the models and methods of each pipe using
[NVTX](https://nvidia.github.io/NVTX/) range markers. By default, the following
methods are wrapped: `pipe`, `predict`, `set_annotations`, `update`, `rehearse`,
-`get_loss`, `initialize`, `begin_update`, `finish_update`, `update`.
+`get_loss`, `get_teacher_student_loss`, `initialize`, `begin_update`,
+`finish_update`, `update`.
| Name | Description |
| --------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
diff --git a/website/docs/usage/processing-pipelines.mdx b/website/docs/usage/processing-pipelines.mdx
index 11e1cb620..08cd64aa7 100644
--- a/website/docs/usage/processing-pipelines.mdx
+++ b/website/docs/usage/processing-pipelines.mdx
@@ -1354,12 +1354,14 @@ For some use cases, it makes sense to also overwrite additional methods to
customize how the model is updated from examples, how it's initialized, how the
loss is calculated and to add evaluation scores to the training output.
-| Name | Description |
-| ------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| [`update`](/api/pipe#update) | Learn from a batch of [`Example`](/api/example) objects containing the predictions and gold-standard annotations, and update the component's model. |
-| [`initialize`](/api/pipe#initialize) | Initialize the model. Typically calls into [`Model.initialize`](https://thinc.ai/docs/api-model#initialize) and can be passed custom arguments via the [`[initialize]`](/api/data-formats#config-initialize) config block that are only loaded during training or when you call [`nlp.initialize`](/api/language#initialize), not at runtime. |
-| [`get_loss`](/api/pipe#get_loss) | Return a tuple of the loss and the gradient for a batch of [`Example`](/api/example) objects. |
-| [`score`](/api/pipe#score) | Score a batch of [`Example`](/api/example) objects and return a dictionary of scores. The [`@Language.factory`](/api/language#factory) decorator can define the `default_score_weights` of the component to decide which keys of the scores to display during training and how they count towards the final score. |
+| Name | Description |
+| ---------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`update`](/api/pipe#update) | Learn from a batch of [`Example`](/api/example) objects containing the predictions and gold-standard annotations, and update the component's model. |
+| [`distill`](/api/pipe#distill) | Learn from a teacher pipeline using a batch of [`Doc`](/api/doc) objects and update the component's model. |
+| [`initialize`](/api/pipe#initialize) | Initialize the model. Typically calls into [`Model.initialize`](https://thinc.ai/docs/api-model#initialize) and can be passed custom arguments via the [`[initialize]`](/api/data-formats#config-initialize) config block that are only loaded during training or when you call [`nlp.initialize`](/api/language#initialize), not at runtime. |
+| [`get_loss`](/api/pipe#get_loss) | Return a tuple of the loss and the gradient for a batch of [`Example`](/api/example) objects. |
+| [`get_teacher_student_loss`](/api/pipe#get_teacher_student_loss) | Return a tuple of the loss and the gradient for the student scores relative to the teacher scores. |
+| [`score`](/api/pipe#score) | Score a batch of [`Example`](/api/example) objects and return a dictionary of scores. The [`@Language.factory`](/api/language#factory) decorator can define the `default_score_weights` of the component to decide which keys of the scores to display during training and how they count towards the final score. |