Mirror of https://github.com/explosion/spaCy.git (synced 2025-08-04 20:30:24 +03:00)
Add TrainablePipe.{distill,get_teacher_student_loss}
This change adds two methods:

- `TrainablePipe::distill`, which performs a training step of a student pipe against a teacher pipe, given a batch of `Doc`s.
- `TrainablePipe::get_teacher_student_loss`, which computes the loss of a student relative to the teacher.

Both methods are also implemented in the tagger, edit tree lemmatizer, and parser pipes, to enable distillation in those pipes and to serve as an example for other pipes.
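A minimal sketch of how the new API is meant to be used end to end, assuming a trained `teacher` pipeline and a smaller `student` pipeline whose taggers share a label set (the model names and the single-step loop are illustrative, not part of this commit):

```python
import spacy

# Hypothetical package names; any pair of pipelines whose taggers
# were initialized with the same labels would work.
teacher = spacy.load("en_teacher_model")
student = spacy.load("en_student_model")
teacher_tagger = teacher.get_pipe("tagger")
student_tagger = student.get_pipe("tagger")

optimizer = student.resume_training()
# Same tokens for teacher and student; annotations may differ.
docs = [student.make_doc("Distillation transfers knowledge to the student.")]

losses = {}
# One distillation step: the student is updated towards the teacher's
# probability distribution over tags, with no gold annotations needed.
student_tagger.distill(teacher_tagger, docs, docs, sgd=optimizer, losses=losses)
print(losses["tagger"])
```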
This commit is contained in:
parent d30ba9b7b8
commit adead2a104
@@ -949,6 +949,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E4000 = ("Expected a Doc as input, but got: '{type}'")
     E4001 = ("Expected input to be one of the following types: ({expected_types}), "
              "but got '{received_type}'")
+    E4002 = ("Pipe '{name}' requires teacher pipe for distillation.")

     # fmt: on
@@ -23,6 +23,7 @@ DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS = [
    "update",
    "rehearse",
    "get_loss",
+   "get_student_teacher_loss",
    "initialize",
    "begin_update",
    "finish_update",
@@ -155,6 +155,23 @@ class EditTreeLemmatizer(TrainablePipe):
         return float(loss), d_scores

+    def get_teacher_student_loss(
+        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+    ) -> Tuple[float, List[Floats2d]]:
+        """Calculate the loss and its gradient for a batch of student
+        scores, relative to teacher scores.
+
+        teacher_scores: Scores representing the teacher model's predictions.
+        student_scores: Scores representing the student model's predictions.
+
+        DOCS: https://spacy.io/api/edittreelemmatizer#get_teacher_student_loss
+        """
+        loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
+        d_scores, loss = loss_func(student_scores, teacher_scores)
+        if self.model.ops.xp.isnan(loss):
+            raise ValueError(Errors.E910.format(name=self.name))
+        return float(loss), d_scores
+
     def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         n_docs = len(list(docs))
         if not any(len(doc) for doc in docs):
@@ -1,5 +1,6 @@
 # cython: infer_types=True, profile=True, binding=True
 from typing import Callable, Dict, Iterable, List, Optional, Union
+from typing import Tuple
 import numpy
 import srsly
 from thinc.api import Model, set_dropout_rate, Config
@@ -245,7 +246,6 @@ class Tagger(TrainablePipe):

         DOCS: https://spacy.io/api/tagger#rehearse
         """
-        loss_func = LegacySequenceCategoricalCrossentropy()
         if losses is None:
             losses = {}
         losses.setdefault(self.name, 0.0)
@@ -259,12 +259,30 @@ class Tagger(TrainablePipe):
         set_dropout_rate(self.model, drop)
         tag_scores, bp_tag_scores = self.model.begin_update(docs)
         tutor_tag_scores, _ = self._rehearsal_model.begin_update(docs)
-        grads, loss = loss_func(tag_scores, tutor_tag_scores)
+        loss, grads = self.get_teacher_student_loss(tutor_tag_scores, tag_scores)
         bp_tag_scores(grads)
-        self.finish_update(sgd)
+        if sgd is not None:
+            self.finish_update(sgd)
         losses[self.name] += loss
         return losses

+    def get_teacher_student_loss(
+        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+    ) -> Tuple[float, List[Floats2d]]:
+        """Calculate the loss and its gradient for a batch of student
+        scores, relative to teacher scores.
+
+        teacher_scores: Scores representing the teacher model's predictions.
+        student_scores: Scores representing the student model's predictions.
+
+        DOCS: https://spacy.io/api/tagger#get_teacher_student_loss
+        """
+        loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
+        d_scores, loss = loss_func(student_scores, teacher_scores)
+        if self.model.ops.xp.isnan(loss):
+            raise ValueError(Errors.E910.format(name=self.name))
+        return float(loss), d_scores
+
     def get_loss(self, examples, scores):
         """Find the loss and gradient of loss for the batch of documents and
         their predicted scores.
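The lemmatizer and tagger hunks above (and the parser hunk below) share one loss computation: the teacher's scores are treated as soft gold labels for the student. A self-contained sketch of that computation on toy arrays, using thinc's `SequenceCategoricalCrossentropy` as a stand-in for the legacy variant the diff imports; the score values are invented:

```python
import numpy
from thinc.api import SequenceCategoricalCrossentropy

# One sequence of two tokens scored over three labels. The teacher's
# normalized scores act as the target distribution for the student.
teacher_scores = [numpy.asarray([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]], dtype="f")]
student_scores = [numpy.asarray([[0.6, 0.3, 0.1], [0.3, 0.5, 0.2]], dtype="f")]

loss_func = SequenceCategoricalCrossentropy(normalize=False)
d_scores, loss = loss_func(student_scores, teacher_scores)

print(float(loss))        # scalar loss over the batch
print(d_scores[0].shape)  # gradient matches the student scores: (2, 3)
```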
@@ -56,6 +56,56 @@ cdef class TrainablePipe(Pipe):
         except Exception as e:
             error_handler(self.name, self, [doc], e)

+    def distill(self,
+                teacher_pipe: Optional["TrainablePipe"],
+                teacher_docs: Iterable["Doc"],
+                student_docs: Iterable["Doc"],
+                *,
+                drop: float=0.0,
+                sgd: Optional[Optimizer]=None,
+                losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
+        """Train a pipe (the student) on the predictions of another pipe
+        (the teacher). The student is typically trained on the probability
+        distribution of the teacher, but details may differ per pipe.
+
+        teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
+            from.
+        teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
+        student_docs (Iterable[Doc]): Documents passed through student pipes.
+            Must contain the same tokens as `teacher_docs` but may have
+            different annotations.
+        drop (float): dropout rate.
+        sgd (Optional[Optimizer]): An optimizer. Will be created via
+            create_optimizer if not set.
+        losses (Optional[Dict[str, float]]): Optional record of loss during
+            distillation.
+        RETURNS: The updated losses dictionary.
+        """
+        # By default we require a teacher pipe, but there are downstream
+        # implementations that don't require a pipe.
+        if teacher_pipe is None:
+            raise ValueError(Errors.E4002.format(name=self.name))
+        if losses is None:
+            losses = {}
+        losses.setdefault(self.name, 0.0)
+        if not any(len(doc) for doc in teacher_docs):
+            return losses
+        if not any(len(doc) for doc in student_docs):
+            return losses
+        set_dropout_rate(self.model, drop)
+        for node in teacher_pipe.model.walk():
+            if node.name == "softmax":
+                node.attrs["softmax_normalize"] = True
+        teacher_scores = teacher_pipe.model.predict(teacher_docs)
+        student_scores, bp_student_scores = self.model.begin_update(student_docs)
+        loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
+        bp_student_scores(d_scores)
+        if sgd is not None:
+            self.finish_update(sgd)
+        losses[self.name] += loss
+        return losses
+
     def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
         """Apply the pipe to a stream of documents. This usually happens under
         the hood when the nlp object is called on a text and all components are
@@ -169,6 +219,17 @@ cdef class TrainablePipe(Pipe):
         """
         raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_loss", name=self.name))

+    def get_teacher_student_loss(self, teacher_scores, student_scores):
+        """Calculate the loss and its gradient for a batch of student
+        scores, relative to teacher scores.
+
+        teacher_scores: Scores representing the teacher model's predictions.
+        student_scores: Scores representing the student model's predictions.
+
+        DOCS: https://spacy.io/api/pipe#get_teacher_student_loss
+        """
+        raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_teacher_student_loss", name=self.name))
+
     def create_optimizer(self) -> Optimizer:
         """Create an optimizer for the pipeline component.
@@ -1,5 +1,6 @@
 # cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
 from __future__ import print_function
+from typing import Dict, Iterable, List, Optional, Tuple
 from cymem.cymem cimport Pool
 cimport numpy as np
 from itertools import islice
@@ -9,7 +10,10 @@ from libc.stdlib cimport calloc, free
 import random

 import srsly
-from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps
+from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps, Optimizer
+from thinc.api import LegacySequenceCategoricalCrossentropy, chain, softmax_activation, use_ops
+from thinc.extra.search cimport Beam
+from thinc.types import Floats2d
 import numpy.random
 import numpy
 import warnings
@@ -203,6 +207,120 @@ cdef class Parser(TrainablePipe):
         # Defined in subclasses, to avoid circular import
         raise NotImplementedError

+    def distill(self,
+                teacher_pipe: Optional[TrainablePipe],
+                teacher_docs: Iterable[Doc],
+                student_docs: Iterable[Doc],
+                *,
+                drop: float=0.0,
+                sgd: Optional[Optimizer]=None,
+                losses: Optional[Dict[str, float]]=None):
+        """Train a pipe (the student) on the predictions of another pipe
+        (the teacher). The student is trained on the transition probabilities
+        of the teacher.
+
+        teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
+            from.
+        teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
+        student_docs (Iterable[Doc]): Documents passed through student pipes.
+            Must contain the same tokens as `teacher_docs` but may have
+            different annotations.
+        drop (float): dropout rate.
+        sgd (Optional[Optimizer]): An optimizer. Will be created via
+            create_optimizer if not set.
+        losses (Optional[Dict[str, float]]): Optional record of loss during
+            distillation.
+        RETURNS: The updated losses dictionary.
+        """
+        if teacher_pipe is None:
+            raise ValueError(Errors.E4002.format(name=self.name))
+        if losses is None:
+            losses = {}
+        losses.setdefault(self.name, 0.0)
+
+        if not any(len(doc) for doc in teacher_docs):
+            return losses
+        if not any(len(doc) for doc in student_docs):
+            return losses
+
+        set_dropout_rate(self.model, drop)
+
+        teacher_step_model = teacher_pipe.model.predict(teacher_docs)
+        student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
+
+        # Add softmax activation, so that we can compute student losses
+        # with cross-entropy loss.
+        with use_ops("numpy"):
+            teacher_model = chain(teacher_step_model, softmax_activation())
+            student_model = chain(student_step_model, softmax_activation())
+
+        max_moves = self.cfg["update_with_oracle_cut_size"]
+        if max_moves >= 1:
+            # Chop sequences into lengths of this many words, to make the
+            # batch uniform length. Since we do not have a gold standard
+            # sequence, we use the teacher's predictions as the gold
+            # standard.
+            max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
+            states = self._init_batch(teacher_step_model, student_docs, max_moves)
+        else:
+            states = self.moves.init_batch(student_docs)

+        loss = 0.0
+        n_moves = 0
+        while states:
+            # We do distillation as follows: (1) for every state, we compute the
+            # transition softmax distributions: (2) we backpropagate the error of
+            # the student (compared to the teacher) into the student model; (3)
+            # for all states, we move to the next state using the student's
+            # predictions.
+            teacher_scores = teacher_model.predict(states)
+            student_scores, backprop = student_model.begin_update(states)
+            state_loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
+            backprop(d_scores)
+            loss += state_loss
+            self.transition_states(states, student_scores)
+            states = [state for state in states if not state.is_final()]
+
+            # Stop when we reach the maximum number of moves, otherwise we start
+            # to process the remainder of cut sequences again.
+            if max_moves >= 1 and n_moves >= max_moves:
+                break
+            n_moves += 1
+
+        backprop_tok2vec(student_docs)
+
+        if sgd is not None:
+            self.finish_update(sgd)
+
+        losses[self.name] += loss
+
+        del backprop
+        del backprop_tok2vec
+        teacher_step_model.clear_memory()
+        student_step_model.clear_memory()
+        del teacher_model
+        del student_model
+
+        return losses
+
+    def get_teacher_student_loss(
+        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
+    ) -> Tuple[float, List[Floats2d]]:
+        """Calculate the loss and its gradient for a batch of student
+        scores, relative to teacher scores.
+
+        teacher_scores: Scores representing the teacher model's predictions.
+        student_scores: Scores representing the student model's predictions.
+
+        DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
+        """
+        loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
+        d_scores, loss = loss_func(student_scores, teacher_scores)
+        if self.model.ops.xp.isnan(loss):
+            raise ValueError(Errors.E910.format(name=self.name))
+        return float(loss), d_scores
+
     def init_multitask_objectives(self, get_examples, pipeline, **cfg):
         """Setup models for secondary objectives, to benefit from multi-task
         learning. This method is intended to be overridden by subclasses.
@@ -625,6 +743,40 @@ cdef class Parser(TrainablePipe):
             raise ValueError(Errors.E149) from None
         return self

+    def _init_batch(self, teacher_step_model, docs, max_length):
+        """Make a square batch of length equal to the shortest transition
+        sequence or a cap. A long
+        doc will get multiple states. Let's say we have a doc of length 2*N,
+        where N is the shortest doc. We'll make two states, one representing
+        long_doc[:N], and another representing long_doc[N:]. In contrast to
+        _init_gold_batch, this version uses a teacher model to generate the
+        cut sequences."""
+        cdef:
+            StateClass start_state
+            StateClass state
+            Transition action
+        all_states = self.moves.init_batch(docs)
+        states = []
+        to_cut = []
+        for state, doc in zip(all_states, docs):
+            if not state.is_final():
+                if len(doc) < max_length:
+                    states.append(state)
+                else:
+                    to_cut.append(state)
+        while to_cut:
+            states.extend(state.copy() for state in to_cut)
+            # Move states forward max_length actions.
+            length = 0
+            while to_cut and length < max_length:
+                teacher_scores = teacher_step_model.predict(to_cut)
+                self.transition_states(to_cut, teacher_scores)
+                # States that are completed do not need further cutting.
+                to_cut = [state for state in to_cut if not state.is_final()]
+                length += 1
+        return states
+
+
     def _init_gold_batch(self, examples, max_length):
         """Make a square batch, of length equal to the shortest transition
         sequence or a cap. A long
@@ -617,6 +617,44 @@ def test_overfitting_IO(use_upper):
     assert ents[1].kb_id == 0


+def test_distill():
+    teacher = English()
+    teacher_ner = teacher.add_pipe("ner")
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(teacher.make_doc(text), annotations))
+        for ent in annotations.get("entities"):
+            teacher_ner.add_label(ent[2])
+
+    optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+    for i in range(50):
+        losses = {}
+        teacher.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["ner"] < 0.00001
+
+    student = English()
+    student_ner = student.add_pipe("ner")
+    student_ner.initialize(
+        get_examples=lambda: train_examples, labels=teacher_ner.label_data
+    )
+
+    docs = [eg.predicted for eg in train_examples]
+
+    for i in range(100):
+        losses = {}
+        student_ner.distill(teacher_ner, docs, docs, sgd=optimizer, losses=losses)
+    assert losses["ner"] < 0.0001
+
+    # test the trained model
+    test_text = "I like London."
+    doc = student(test_text)
+    ents = doc.ents
+    assert len(ents) == 1
+    assert ents[0].text == "London"
+    assert ents[0].label_ == "LOC"
+
+
 def test_beam_ner_scores():
     # Test that we can get confidence values out of the beam_ner pipe
     beam_width = 16
@@ -396,6 +396,45 @@ def test_overfitting_IO(pipe_name):
     assert_equal(batch_deps_1, no_batch_deps)


+def test_distill():
+    teacher = English()
+    teacher_parser = teacher.add_pipe("parser")
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(teacher.make_doc(text), annotations))
+        for dep in annotations.get("deps", []):
+            teacher_parser.add_label(dep)
+
+    optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+    for i in range(200):
+        losses = {}
+        teacher.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["parser"] < 0.0001
+
+    student = English()
+    student_parser = student.add_pipe("parser")
+    student_parser.initialize(
+        get_examples=lambda: train_examples, labels=teacher_parser.label_data
+    )
+
+    docs = [eg.predicted for eg in train_examples]
+
+    for i in range(200):
+        losses = {}
+        student_parser.distill(teacher_parser, docs, docs, sgd=optimizer, losses=losses)
+    assert losses["parser"] < 0.0001
+
+    test_text = "I like securities."
+    doc = student(test_text)
+    assert doc[0].dep_ == "nsubj"
+    assert doc[2].dep_ == "dobj"
+    assert doc[3].dep_ == "punct"
+    assert doc[0].head.i == 1
+    assert doc[2].head.i == 1
+    assert doc[3].head.i == 1
+
+
 # fmt: off
 @pytest.mark.slow
 @pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
@@ -195,6 +195,45 @@ def test_overfitting_IO():
     assert doc4[3].lemma_ == "egg"


+def test_distill():
+    teacher = English()
+    teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer")
+    teacher_lemmatizer.min_tree_freq = 1
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))
+
+    optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+    for i in range(50):
+        losses = {}
+        teacher.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["trainable_lemmatizer"] < 0.00001
+
+    student = English()
+    student_lemmatizer = student.add_pipe("trainable_lemmatizer")
+    student_lemmatizer.min_tree_freq = 1
+    student_lemmatizer.initialize(
+        get_examples=lambda: train_examples, labels=teacher_lemmatizer.label_data
+    )
+
+    docs = [eg.predicted for eg in train_examples]
+
+    for i in range(50):
+        losses = {}
+        student_lemmatizer.distill(
+            teacher_lemmatizer, docs, docs, sgd=optimizer, losses=losses
+        )
+    assert losses["trainable_lemmatizer"] < 0.00001
+
+    test_text = "She likes blue eggs"
+    doc = student(test_text)
+    assert doc[0].lemma_ == "she"
+    assert doc[1].lemma_ == "like"
+    assert doc[2].lemma_ == "blue"
+    assert doc[3].lemma_ == "egg"
+
+
 def test_lemmatizer_requires_labels():
     nlp = English()
     nlp.add_pipe("trainable_lemmatizer")
@@ -213,6 +213,42 @@ def test_overfitting_IO():
     assert doc3[0].tag_ != "N"


+def test_distill():
+    teacher = English()
+    teacher_tagger = teacher.add_pipe("tagger")
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))
+
+    optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+    for i in range(50):
+        losses = {}
+        teacher.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["tagger"] < 0.00001
+
+    student = English()
+    student_tagger = student.add_pipe("tagger")
+    student_tagger.min_tree_freq = 1
+    student_tagger.initialize(
+        get_examples=lambda: train_examples, labels=teacher_tagger.label_data
+    )
+
+    docs = [eg.predicted for eg in train_examples]
+
+    for i in range(50):
+        losses = {}
+        student_tagger.distill(teacher_tagger, docs, docs, sgd=optimizer, losses=losses)
+    assert losses["tagger"] < 0.00001
+
+    test_text = "I like blue eggs"
+    doc = student(test_text)
+    assert doc[0].tag_ == "N"
+    assert doc[1].tag_ == "V"
+    assert doc[2].tag_ == "J"
+    assert doc[3].tag_ == "N"
+
+
 def test_save_activations():
     # Test if activations are correctly added to Doc when requested.
     nlp = English()
@@ -268,6 +268,27 @@ predicted scores.
 | `scores`    | Scores representing the model's predictions. ~~StateClass~~                 |
 | **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

+## DependencyParser.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_parser = teacher.get_pipe("parser")
+> student_parser = student.add_pipe("parser")
+> student_scores = student_parser.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_parser.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_parser.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name             | Description                                                                  |
+| ---------------- | ---------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions.                         |
+| `student_scores` | Scores representing the student model's predictions.                         |
+| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~  |
+
 ## DependencyParser.create_optimizer {#create_optimizer tag="method"}

 Create an [`Optimizer`](https://thinc.ai/docs/api-optimizers) for the pipeline
@@ -269,6 +269,27 @@ Create an optimizer for the pipeline component.
 | ----------- | ---------------------------- |
 | **RETURNS** | The optimizer. ~~Optimizer~~ |

+## EditTreeLemmatizer.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_lemmatizer = teacher.get_pipe("trainable_lemmatizer")
+> student_lemmatizer = student.add_pipe("trainable_lemmatizer")
+> student_scores = student_lemmatizer.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_lemmatizer.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_lemmatizer.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name             | Description                                                                  |
+| ---------------- | ---------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions.                         |
+| `student_scores` | Scores representing the student model's predictions.                         |
+| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~  |
+
 ## EditTreeLemmatizer.use_params {#use_params tag="method, contextmanager"}

 Modify the pipe's model, to use the given parameter values. At the end of the
@@ -259,6 +259,27 @@ predicted scores.
 | `scores`    | Scores representing the model's predictions.                                 |
 | **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

+## Morphologizer.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_morphologizer = teacher.get_pipe("morphologizer")
+> student_morphologizer = student.add_pipe("morphologizer")
+> student_scores = student_morphologizer.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_morphologizer.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_morphologizer.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name             | Description                                                                  |
+| ---------------- | ---------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions.                         |
+| `student_scores` | Scores representing the student model's predictions.                         |
+| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~  |
+
 ## Morphologizer.create_optimizer {#create_optimizer tag="method"}

 Create an optimizer for the pipeline component.
@@ -234,6 +234,33 @@ predictions and gold-standard annotations, and update the component's model.
 | `losses`    | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
 | **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                     |

+## TrainablePipe.distill {#distill tag="method,experimental" new="4"}
+
+Train a pipe (the student) on the predictions of another pipe (the teacher). The
+student is typically trained on the probability distribution of the teacher, but
+details may differ per pipe. The goal of distillation is to transfer knowledge
+from the teacher to the student. This feature is experimental.
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.add_pipe("your_custom_pipe")
+> student_pipe = student.add_pipe("your_custom_pipe")
+> optimizer = nlp.resume_training()
+> losses = student_pipe.distill(teacher_pipe, teacher_docs, student_docs, sgd=optimizer)
+> ```
+
+| Name           | Description                                                                                                                                   |
+| -------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~                                                                                   |
+| `teacher_docs` | Documents passed through teacher pipes. ~~Iterable[Doc]~~                                                                                     |
+| `student_docs` | Documents passed through student pipes. Must contain the same tokens as `teacher_docs` but may have different annotations. ~~Iterable[Doc]~~ |
+| _keyword-only_ |                                                                                                                                               |
+| `drop`         | Dropout rate. ~~float~~                                                                                                                       |
+| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~                                 |
+| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~                  |
+| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                         |
+
 ## TrainablePipe.rehearse {#rehearse tag="method,experimental" new="3"}

 Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
@@ -281,6 +308,33 @@ This method needs to be overwritten with your own custom `get_loss` method.
 | `scores`    | Scores representing the model's predictions.                                 |
 | **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

+## TrainablePipe.get_teacher_student_loss {#get_teacher_student_loss tag="method"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+<Infobox variant="danger">
+
+This method needs to be overwritten with your own custom `get_teacher_student_loss` method.
+
+</Infobox>
+
+> #### Example
+>
+> ```python
+> teacher_pipe = teacher.get_pipe("your_custom_pipe")
+> student_pipe = student.add_pipe("your_custom_pipe")
+> student_scores = student_pipe.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_pipe.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_pipe.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name             | Description                                                                  |
+| ---------------- | ---------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions.                         |
+| `student_scores` | Scores representing the student model's predictions.                         |
+| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~  |
+
 ## TrainablePipe.score {#score tag="method" new="3"}

 Score a batch of examples.
@@ -254,6 +254,27 @@ predicted scores.
 | `scores`    | Scores representing the model's predictions.                                 |
 | **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

+## SentenceRecognizer.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_senter = teacher.get_pipe("senter")
+> student_senter = student.add_pipe("senter")
+> student_scores = student_senter.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_senter.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_senter.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name             | Description                                                                  |
+| ---------------- | ---------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions.                         |
+| `student_scores` | Scores representing the student model's predictions.                         |
+| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~  |
+
 ## SentenceRecognizer.create_optimizer {#create_optimizer tag="method"}

 Create an optimizer for the pipeline component.
@@ -265,6 +265,27 @@ predicted scores.
 | `scores`    | Scores representing the model's predictions.                                 |
 | **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |

+## Tagger.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
+
+Calculate the loss and its gradient for the batch of student scores relative to
+the teacher scores.
+
+> #### Example
+>
+> ```python
+> teacher_tagger = teacher.get_pipe("tagger")
+> student_tagger = student.add_pipe("tagger")
+> student_scores = student_tagger.predict([eg.predicted for eg in examples])
+> teacher_scores = teacher_tagger.predict([eg.predicted for eg in examples])
+> loss, d_loss = student_tagger.get_teacher_student_loss(teacher_scores, student_scores)
+> ```
+
+| Name             | Description                                                                  |
+| ---------------- | ---------------------------------------------------------------------------- |
+| `teacher_scores` | Scores representing the teacher model's predictions.                         |
+| `student_scores` | Scores representing the student model's predictions.                         |
+| **RETURNS**      | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~  |
+
 ## Tagger.create_optimizer {#create_optimizer tag="method"}

 Create an optimizer for the pipeline component.
@@ -899,7 +899,8 @@ backprop passes.
 Recursively wrap both the models and methods of each pipe using
 [NVTX](https://nvidia.github.io/NVTX/) range markers. By default, the following
 methods are wrapped: `pipe`, `predict`, `set_annotations`, `update`, `rehearse`,
-`get_loss`, `initialize`, `begin_update`, `finish_update`, `update`.
+`get_loss`, `get_student_teacher_loss`, `initialize`, `begin_update`,
+`finish_update`, `update`.

 | Name | Description |
 | ---- | ----------- |
@@ -1369,12 +1369,14 @@ For some use cases, it makes sense to also overwrite additional methods to
 customize how the model is updated from examples, how it's initialized, how the
 loss is calculated and to add evaluation scores to the training output.

 | Name | Description |
 | ---- | ----------- |
 | [`update`](/api/pipe#update) | Learn from a batch of [`Example`](/api/example) objects containing the predictions and gold-standard annotations, and update the component's model. |
+| [`distill`](/api/pipe#distill) | Learn from a teacher pipeline using a batch of [`Doc`](/api/doc) objects and update the component's model. |
 | [`initialize`](/api/pipe#initialize) | Initialize the model. Typically calls into [`Model.initialize`](https://thinc.ai/docs/api-model#initialize) and can be passed custom arguments via the [`[initialize]`](/api/data-formats#config-initialize) config block that are only loaded during training or when you call [`nlp.initialize`](/api/language#initialize), not at runtime. |
 | [`get_loss`](/api/pipe#get_loss) | Return a tuple of the loss and the gradient for a batch of [`Example`](/api/example) objects. |
+| [`get_teacher_student_loss`](/api/pipe#get_teacher_student_loss) | Return a tuple of the loss and the gradient for the student scores relative to the teacher scores. |
 | [`score`](/api/pipe#score) | Score a batch of [`Example`](/api/example) objects and return a dictionary of scores. The [`@Language.factory`](/api/language#factory) decorator can define the `default_score_weights` of the component to decide which keys of the scores to display during training and how they count towards the final score. |

 <Infobox title="Custom trainable components and models" emoji="📖">
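The table in this last hunk lists `distill` and `get_teacher_student_loss` among the methods a custom trainable component can provide. A minimal sketch of such an override, following the pattern the commit adds to the tagger and lemmatizer; the `CustomPipe` class is illustrative, and thinc's non-legacy `SequenceCategoricalCrossentropy` stands in for the legacy variant used in the diff:

```python
from typing import List, Tuple

from thinc.api import SequenceCategoricalCrossentropy
from thinc.types import Floats2d

from spacy.pipeline import TrainablePipe


class CustomPipe(TrainablePipe):
    def get_teacher_student_loss(
        self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
    ) -> Tuple[float, List[Floats2d]]:
        # The teacher's (normalized) scores serve as the gold standard
        # for the student; no annotated examples are needed.
        loss_func = SequenceCategoricalCrossentropy(normalize=False)
        d_scores, loss = loss_func(student_scores, teacher_scores)
        if self.model.ops.xp.isnan(loss):
            raise ValueError(f"Non-finite loss computed in pipe '{self.name}'")
        return float(loss), d_scores
```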