Add Language.distill (#12116)

* Add `Language.distill`

This method is the distillation counterpart of `Language.update`.  It
takes a teacher `Language` instance and distills the student pipes on
the teacher pipes.

* Apply suggestions from code review

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Clarify that how Example is used in distillation

* Update transition parser distill docstring for examples argument

* Pass optimizer to `TrainablePipe.distill`

* Annotate pipe before update

As discussed internally, we want to let a pipe annotate before doing an
update with gold/silver data. Otherwise, the output may be (too)
informed by the gold/silver data.

* Rename `component_map` to `student_to_teacher`

* Better synopsis in `Language.distill` docstring

* `name` -> `student_name`

* Fix labels type in docstring

* Mark distill test as slow

* Fix `student_to_teacher` type in docs

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
Daniël de Kok 2023-01-30 12:44:11 +01:00 committed by GitHub
parent ec45f704b1
commit 6b07be2110
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 290 additions and 68 deletions

View File

@ -22,7 +22,7 @@ from . import ty
from .tokens.underscore import Underscore from .tokens.underscore import Underscore
from .vocab import Vocab, create_vocab from .vocab import Vocab, create_vocab
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .training import Example, validate_examples from .training import Example, validate_examples, validate_distillation_examples
from .training.initialize import init_vocab, init_tok2vec from .training.initialize import init_vocab, init_tok2vec
from .scorer import Scorer from .scorer import Scorer
from .util import registry, SimpleFrozenList, _pipe, raise_error, _DEFAULT_EMPTY_PIPES from .util import registry, SimpleFrozenList, _pipe, raise_error, _DEFAULT_EMPTY_PIPES
@ -1017,6 +1017,102 @@ class Language:
raise ValueError(Errors.E005.format(name=name, returned_type=type(doc))) raise ValueError(Errors.E005.format(name=name, returned_type=type(doc)))
return doc return doc
def distill(
self,
teacher: "Language",
examples: Iterable[Example],
*,
drop: float = 0.0,
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = SimpleFrozenList(),
annotates: Iterable[str] = SimpleFrozenList(),
student_to_teacher: Optional[Dict[str, str]] = None,
):
"""Distill the models in a student pipeline from a teacher pipeline.
teacher (Language): Teacher to distill from.
examples (Iterable[Example]): Distillation examples. The reference
(teacher) and predicted (student) docs must have the same number of
tokens and the same orthography.
drop (float): The dropout rate.
sgd (Optional[Optimizer]): An optimizer.
losses (Optional(Dict[str, float])): Dictionary to update with the loss,
keyed by component.
component_cfg (Optional[Dict[str, Dict[str, Any]]]): Config parameters
for specific pipeline components, keyed by component name.
exclude (Iterable[str]): Names of components that shouldn't be updated.
annotates (Iterable[str]): Names of components that should set
annotations on the predicted examples after updating.
student_to_teacher (Optional[Dict[str, str]]): Map student pipe name to
teacher pipe name, only needed for pipes where the student pipe
name does not match the teacher pipe name.
RETURNS (Dict[str, float]): The updated losses dictionary
DOCS: https://spacy.io/api/language#distill
"""
if student_to_teacher is None:
student_to_teacher = {}
if losses is None:
losses = {}
if isinstance(examples, list) and len(examples) == 0:
return losses
validate_distillation_examples(examples, "Language.distill")
examples = _copy_examples(examples)
if sgd is None:
if self._optimizer is None:
self._optimizer = self.create_optimizer()
sgd = self._optimizer
if component_cfg is None:
component_cfg = {}
pipe_kwargs = {}
for student_name, student_proc in self.pipeline:
component_cfg.setdefault(student_name, {})
pipe_kwargs[student_name] = deepcopy(component_cfg[student_name])
component_cfg[student_name].setdefault("drop", drop)
pipe_kwargs[student_name].setdefault("batch_size", self.batch_size)
teacher_pipes = dict(teacher.pipeline)
for student_name, student_proc in self.pipeline:
if student_name in annotates:
for doc, eg in zip(
_pipe(
(eg.predicted for eg in examples),
proc=student_proc,
name=student_name,
default_error_handler=self.default_error_handler,
kwargs=pipe_kwargs[student_name],
),
examples,
):
eg.predicted = doc
if (
student_name not in exclude
and isinstance(student_proc, ty.DistillableComponent)
and student_proc.is_distillable
):
# A missing teacher pipe is not an error, some student pipes
# do not need a teacher, such as tok2vec layer losses.
teacher_name = (
student_to_teacher[student_name]
if student_name in student_to_teacher
else student_name
)
teacher_pipe = teacher_pipes.get(teacher_name, None)
student_proc.distill(
teacher_pipe,
examples,
sgd=sgd,
losses=losses,
**component_cfg[student_name],
)
return losses
def disable_pipes(self, *names) -> "DisabledPipes": def disable_pipes(self, *names) -> "DisabledPipes":
"""Disable one or more pipeline components. If used as a context """Disable one or more pipeline components. If used as a context
manager, the pipeline will be restored to the initial state at the end manager, the pipeline will be restored to the initial state at the end
@ -1242,12 +1338,16 @@ class Language:
self, self,
get_examples: Optional[Callable[[], Iterable[Example]]] = None, get_examples: Optional[Callable[[], Iterable[Example]]] = None,
*, *,
labels: Optional[Dict[str, Any]] = None,
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
) -> Optimizer: ) -> Optimizer:
"""Initialize the pipe for training, using data examples if available. """Initialize the pipe for training, using data examples if available.
get_examples (Callable[[], Iterable[Example]]): Optional function that get_examples (Callable[[], Iterable[Example]]): Optional function that
returns gold-standard Example objects. returns gold-standard Example objects.
labels (Optional[Dict[str, Any]]): Labels to pass to pipe initialization,
using the names of the pipes as keys. Overrides labels that are in
the model configuration.
sgd (Optional[Optimizer]): An optimizer to use for updates. If not sgd (Optional[Optimizer]): An optimizer to use for updates. If not
provided, will be created using the .create_optimizer() method. provided, will be created using the .create_optimizer() method.
RETURNS (thinc.api.Optimizer): The optimizer. RETURNS (thinc.api.Optimizer): The optimizer.
@ -1292,6 +1392,8 @@ class Language:
for name, proc in self.pipeline: for name, proc in self.pipeline:
if isinstance(proc, ty.InitializableComponent): if isinstance(proc, ty.InitializableComponent):
p_settings = I["components"].get(name, {}) p_settings = I["components"].get(name, {})
if labels is not None and name in labels:
p_settings["labels"] = labels[name]
p_settings = validate_init_settings( p_settings = validate_init_settings(
proc.initialize, p_settings, section="components", name=name proc.initialize, p_settings, section="components", name=name
) )
@ -1725,6 +1827,7 @@ class Language:
# using the nlp.config with all defaults. # using the nlp.config with all defaults.
config = util.copy_config(config) config = util.copy_config(config)
orig_pipeline = config.pop("components", {}) orig_pipeline = config.pop("components", {})
orig_distill = config.pop("distill", None)
orig_pretraining = config.pop("pretraining", None) orig_pretraining = config.pop("pretraining", None)
config["components"] = {} config["components"] = {}
if auto_fill: if auto_fill:
@ -1733,6 +1836,9 @@ class Language:
filled = config filled = config
filled["components"] = orig_pipeline filled["components"] = orig_pipeline
config["components"] = orig_pipeline config["components"] = orig_pipeline
if orig_distill is not None:
filled["distill"] = orig_distill
config["distill"] = orig_distill
if orig_pretraining is not None: if orig_pretraining is not None:
filled["pretraining"] = orig_pretraining filled["pretraining"] = orig_pretraining
config["pretraining"] = orig_pretraining config["pretraining"] = orig_pretraining

View File

@ -71,8 +71,8 @@ cdef class TrainablePipe(Pipe):
teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
from. from.
examples (Iterable[Example]): Distillation examples. The reference examples (Iterable[Example]): Distillation examples. The reference
and predicted docs must have the same number of tokens and the (teacher) and predicted (student) docs must have the same number of
same orthography. tokens and the same orthography.
drop (float): dropout rate. drop (float): dropout rate.
sgd (Optional[Optimizer]): An optimizer. Will be created via sgd (Optional[Optimizer]): An optimizer. Will be created via
create_optimizer if not set. create_optimizer if not set.

View File

@ -224,8 +224,8 @@ class Parser(TrainablePipe):
teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
from. from.
examples (Iterable[Example]): Distillation examples. The reference examples (Iterable[Example]): Distillation examples. The reference
and predicted docs must have the same number of tokens and the (teacher) and predicted (student) docs must have the same number of
same orthography. tokens and the same orthography.
drop (float): dropout rate. drop (float): dropout rate.
sgd (Optional[Optimizer]): An optimizer. Will be created via sgd (Optional[Optimizer]): An optimizer. Will be created via
create_optimizer if not set. create_optimizer if not set.

View File

@ -26,6 +26,12 @@ except ImportError:
pass pass
TAGGER_TRAIN_DATA = [
("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
("Eat blue ham", {"tags": ["V", "J", "N"]}),
]
def evil_component(doc): def evil_component(doc):
if "2" in doc.text: if "2" in doc.text:
raise ValueError("no dice") raise ValueError("no dice")
@ -799,3 +805,66 @@ def test_component_return():
nlp.add_pipe("test_component_bad_pipe") nlp.add_pipe("test_component_bad_pipe")
with pytest.raises(ValueError, match="instead of a Doc"): with pytest.raises(ValueError, match="instead of a Doc"):
nlp("text") nlp("text")
@pytest.mark.slow
@pytest.mark.parametrize("teacher_tagger_name", ["tagger", "teacher_tagger"])
def test_distill(teacher_tagger_name):
teacher = English()
teacher_tagger = teacher.add_pipe("tagger", name=teacher_tagger_name)
train_examples = []
for t in TAGGER_TRAIN_DATA:
train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))
optimizer = teacher.initialize(get_examples=lambda: train_examples)
for i in range(50):
losses = {}
teacher.update(train_examples, sgd=optimizer, losses=losses)
assert losses[teacher_tagger_name] < 0.00001
student = English()
student_tagger = student.add_pipe("tagger")
student_tagger.min_tree_freq = 1
student_tagger.initialize(
get_examples=lambda: train_examples, labels=teacher_tagger.label_data
)
distill_examples = [
Example.from_dict(teacher.make_doc(t[0]), {}) for t in TAGGER_TRAIN_DATA
]
student_to_teacher = (
None
if teacher_tagger.name == student_tagger.name
else {student_tagger.name: teacher_tagger.name}
)
for i in range(50):
losses = {}
student.distill(
teacher,
distill_examples,
sgd=optimizer,
losses=losses,
student_to_teacher=student_to_teacher,
)
assert losses["tagger"] < 0.00001
test_text = "I like blue eggs"
doc = student(test_text)
assert doc[0].tag_ == "N"
assert doc[1].tag_ == "V"
assert doc[2].tag_ == "J"
assert doc[3].tag_ == "N"
# Do an extra update to check if annotates works, though we can't really
# validate the resuls, since the annotations are ephemeral.
student.distill(
teacher,
distill_examples,
sgd=optimizer,
losses=losses,
student_to_teacher=student_to_teacher,
annotates=["tagger"],
)

View File

@ -26,6 +26,25 @@ class TrainableComponent(Protocol):
... ...
@runtime_checkable
class DistillableComponent(Protocol):
is_distillable: bool
def distill(
self,
teacher_pipe: Optional[TrainableComponent],
examples: Iterable["Example"],
*,
drop: float = 0.0,
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None
) -> Dict[str, float]:
...
def finish_update(self, sgd: Optimizer) -> None:
...
@runtime_checkable @runtime_checkable
class InitializableComponent(Protocol): class InitializableComponent(Protocol):
def initialize( def initialize(

View File

@ -155,9 +155,9 @@ This feature is experimental.
> ``` > ```
| Name | Description | | Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- | | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ | | `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ | | `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `drop` | Dropout rate. ~~float~~ | | `drop` | Dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | | `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |

View File

@ -139,9 +139,9 @@ This feature is experimental.
> ``` > ```
| Name | Description | | Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- | | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ | | `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ | | `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `drop` | Dropout rate. ~~float~~ | | `drop` | Dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | | `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |

View File

@ -151,9 +151,9 @@ This feature is experimental.
> ``` > ```
| Name | Description | | Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- | | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ | | `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ | | `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `drop` | Dropout rate. ~~float~~ | | `drop` | Dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | | `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |

View File

@ -333,6 +333,34 @@ and custom registered functions if needed. See the
| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ | | `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | | **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## Language.distill {id="distill",tag="method,experimental",version="4"}
Distill the models in a student pipeline from a teacher pipeline.
> #### Example
>
> ```python
>
> teacher = spacy.load("en_core_web_lg")
> student = English()
> student.add_pipe("tagger")
> student.distill(teacher, examples, sgd=optimizer)
> ```
| Name | Description |
| -------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher` | The teacher pipeline to distill from. ~~Language~~ |
| `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | |
| `drop` | The dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
| `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ |
| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
| `exclude` | Names of components that shouldn't be updated. Defaults to `[]`. ~~Iterable[str]~~ |
| `annotates` | Names of components that should set annotations on the prediced examples after updating. Defaults to `[]`. ~~Iterable[str]~~ |
| `student_to_teacher` | Map student component names to teacher component names, only necessary when the names differ. Defaults to `None`. ~~Optional[Dict[str, str]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## Language.rehearse {id="rehearse",tag="method,experimental",version="3"} ## Language.rehearse {id="rehearse",tag="method,experimental",version="3"}
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the

View File

@ -145,9 +145,9 @@ This feature is experimental.
> ``` > ```
| Name | Description | | Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- | | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ | | `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ | | `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `drop` | Dropout rate. ~~float~~ | | `drop` | Dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | | `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |

View File

@ -258,9 +258,9 @@ This feature is experimental.
> ``` > ```
| Name | Description | | Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- | | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ | | `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ | | `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `drop` | Dropout rate. ~~float~~ | | `drop` | Dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | | `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |

View File

@ -130,9 +130,9 @@ This feature is experimental.
> ``` > ```
| Name | Description | | Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- | | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ | | `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ | | `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `drop` | Dropout rate. ~~float~~ | | `drop` | Dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | | `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |

View File

@ -129,9 +129,9 @@ This feature is experimental.
> ``` > ```
| Name | Description | | Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- | | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ | | `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
| `examples` | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ | | `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `drop` | Dropout rate. ~~float~~ | | `drop` | Dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | | `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |