From 6b07be2110be529a96ead3362a7ef85e64170862 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Mon, 30 Jan 2023 12:44:11 +0100
Subject: [PATCH] Add `Language.distill` (#12116)

* Add `Language.distill`

This method is the distillation counterpart of `Language.update`. It
takes a teacher `Language` instance and distills the student pipes on
the teacher pipes.

* Apply suggestions from code review

Co-authored-by: Madeesh Kannan

* Clarify how Example is used in distillation

* Update transition parser distill docstring for examples argument

* Pass optimizer to `TrainablePipe.distill`

* Annotate pipe before update

As discussed internally, we want to let a pipe annotate before doing an
update with gold/silver data. Otherwise, the output may be (too)
informed by the gold/silver data.

* Rename `component_map` to `student_to_teacher`

* Better synopsis in `Language.distill` docstring

* `name` -> `student_name`

* Fix labels type in docstring

* Mark distill test as slow

* Fix `student_to_teacher` type in docs

---------

Co-authored-by: Madeesh Kannan
---
 spacy/language.py                       | 108 +++++++++++++++++++++++-
 spacy/pipeline/trainable_pipe.pyx       |   4 +-
 spacy/pipeline/transition_parser.pyx    |   4 +-
 spacy/tests/test_language.py            |  69 +++++++++++++++
 spacy/ty.py                             |  19 +++++
 website/docs/api/dependencyparser.mdx   |  18 ++--
 website/docs/api/edittreelemmatizer.mdx |  18 ++--
 website/docs/api/entityrecognizer.mdx   |  18 ++--
 website/docs/api/language.mdx           |  28 ++++++
 website/docs/api/morphologizer.mdx      |  18 ++--
 website/docs/api/pipe.mdx               |  18 ++--
 website/docs/api/sentencerecognizer.mdx |  18 ++--
 website/docs/api/tagger.mdx             |  18 ++--
 13 files changed, 290 insertions(+), 68 deletions(-)

diff --git a/spacy/language.py b/spacy/language.py
index 50b60fb97..2e3c6d2a2 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -22,7 +22,7 @@ from . import ty
 from .tokens.underscore import Underscore
 from .vocab import Vocab, create_vocab
 from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
-from .training import Example, validate_examples
+from .training import Example, validate_examples, validate_distillation_examples
 from .training.initialize import init_vocab, init_tok2vec
 from .scorer import Scorer
 from .util import registry, SimpleFrozenList, _pipe, raise_error, _DEFAULT_EMPTY_PIPES
@@ -1017,6 +1017,102 @@ class Language:
             raise ValueError(Errors.E005.format(name=name, returned_type=type(doc)))
         return doc
 
+    def distill(
+        self,
+        teacher: "Language",
+        examples: Iterable[Example],
+        *,
+        drop: float = 0.0,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None,
+        component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+        exclude: Iterable[str] = SimpleFrozenList(),
+        annotates: Iterable[str] = SimpleFrozenList(),
+        student_to_teacher: Optional[Dict[str, str]] = None,
+    ):
+        """Distill the models in a student pipeline from a teacher pipeline.
+
+        teacher (Language): Teacher to distill from.
+        examples (Iterable[Example]): Distillation examples. The reference
+            (teacher) and predicted (student) docs must have the same number of
+            tokens and the same orthography.
+        drop (float): The dropout rate.
+        sgd (Optional[Optimizer]): An optimizer.
+        losses (Optional[Dict[str, float]]): Dictionary to update with the loss,
+            keyed by component.
+        component_cfg (Optional[Dict[str, Dict[str, Any]]]): Config parameters
+            for specific pipeline components, keyed by component name.
+        exclude (Iterable[str]): Names of components that shouldn't be updated.
+        annotates (Iterable[str]): Names of components that should set
+            annotations on the predicted examples after updating.
+        student_to_teacher (Optional[Dict[str, str]]): Map student pipe name to
+            teacher pipe name, only needed for pipes where the student pipe
+            name does not match the teacher pipe name.
+        RETURNS (Dict[str, float]): The updated losses dictionary.
+
+        DOCS: https://spacy.io/api/language#distill
+        """
+        if student_to_teacher is None:
+            student_to_teacher = {}
+        if losses is None:
+            losses = {}
+        if isinstance(examples, list) and len(examples) == 0:
+            return losses
+
+        validate_distillation_examples(examples, "Language.distill")
+        examples = _copy_examples(examples)
+
+        if sgd is None:
+            if self._optimizer is None:
+                self._optimizer = self.create_optimizer()
+            sgd = self._optimizer
+
+        if component_cfg is None:
+            component_cfg = {}
+        pipe_kwargs = {}
+        for student_name, student_proc in self.pipeline:
+            component_cfg.setdefault(student_name, {})
+            pipe_kwargs[student_name] = deepcopy(component_cfg[student_name])
+            component_cfg[student_name].setdefault("drop", drop)
+            pipe_kwargs[student_name].setdefault("batch_size", self.batch_size)
+
+        teacher_pipes = dict(teacher.pipeline)
+        for student_name, student_proc in self.pipeline:
+            if student_name in annotates:
+                for doc, eg in zip(
+                    _pipe(
+                        (eg.predicted for eg in examples),
+                        proc=student_proc,
+                        name=student_name,
+                        default_error_handler=self.default_error_handler,
+                        kwargs=pipe_kwargs[student_name],
+                    ),
+                    examples,
+                ):
+                    eg.predicted = doc
+
+            if (
+                student_name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+            ):
+                # A missing teacher pipe is not an error, some student pipes
+                # do not need a teacher, such as tok2vec layer losses.
+                teacher_name = (
+                    student_to_teacher[student_name]
+                    if student_name in student_to_teacher
+                    else student_name
+                )
+                teacher_pipe = teacher_pipes.get(teacher_name, None)
+                student_proc.distill(
+                    teacher_pipe,
+                    examples,
+                    sgd=sgd,
+                    losses=losses,
+                    **component_cfg[student_name],
+                )
+
+        return losses
+
     def disable_pipes(self, *names) -> "DisabledPipes":
         """Disable one or more pipeline components. If used as a context
         manager, the pipeline will be restored to the initial state at the end
@@ -1242,12 +1338,16 @@ class Language:
         self,
         get_examples: Optional[Callable[[], Iterable[Example]]] = None,
         *,
+        labels: Optional[Dict[str, Any]] = None,
         sgd: Optional[Optimizer] = None,
     ) -> Optimizer:
         """Initialize the pipe for training, using data examples if available.
 
         get_examples (Callable[[], Iterable[Example]]): Optional function that
             returns gold-standard Example objects.
+        labels (Optional[Dict[str, Any]]): Labels to pass to pipe initialization,
+            using the names of the pipes as keys. Overrides labels that are in
+            the model configuration.
         sgd (Optional[Optimizer]): An optimizer to use for updates. If not
             provided, will be created using the .create_optimizer() method.
         RETURNS (thinc.api.Optimizer): The optimizer.
@@ -1292,6 +1392,8 @@ class Language:
         for name, proc in self.pipeline:
             if isinstance(proc, ty.InitializableComponent):
                 p_settings = I["components"].get(name, {})
+                if labels is not None and name in labels:
+                    p_settings["labels"] = labels[name]
                 p_settings = validate_init_settings(
                     proc.initialize, p_settings, section="components", name=name
                 )
@@ -1725,6 +1827,7 @@ class Language:
         # using the nlp.config with all defaults.
         config = util.copy_config(config)
         orig_pipeline = config.pop("components", {})
+        orig_distill = config.pop("distill", None)
         orig_pretraining = config.pop("pretraining", None)
         config["components"] = {}
         if auto_fill:
@@ -1733,6 +1836,9 @@ class Language:
             filled = config
         filled["components"] = orig_pipeline
         config["components"] = orig_pipeline
+        if orig_distill is not None:
+            filled["distill"] = orig_distill
+            config["distill"] = orig_distill
         if orig_pretraining is not None:
             filled["pretraining"] = orig_pretraining
             config["pretraining"] = orig_pretraining
diff --git a/spacy/pipeline/trainable_pipe.pyx b/spacy/pipeline/trainable_pipe.pyx
index 77259fc0b..fcffd11ee 100644
--- a/spacy/pipeline/trainable_pipe.pyx
+++ b/spacy/pipeline/trainable_pipe.pyx
@@ -71,8 +71,8 @@ cdef class TrainablePipe(Pipe):
         teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
             from.
         examples (Iterable[Example]): Distillation examples. The reference
-            and predicted docs must have the same number of tokens and the
-            same orthography.
+            (teacher) and predicted (student) docs must have the same number of
+            tokens and the same orthography.
         drop (float): dropout rate.
         sgd (Optional[Optimizer]): An optimizer. Will be created via
             create_optimizer if not set.
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index a2b6c167f..9e50dd7b2 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -224,8 +224,8 @@ class Parser(TrainablePipe):
         teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
             from.
         examples (Iterable[Example]): Distillation examples. The reference
-            and predicted docs must have the same number of tokens and the
-            same orthography.
+            (teacher) and predicted (student) docs must have the same number of
+            tokens and the same orthography.
         drop (float): dropout rate.
         sgd (Optional[Optimizer]): An optimizer. Will be created via
             create_optimizer if not set.
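For orientation, here is a minimal usage sketch of the API added above (not part of the patch). It assumes any trained pipeline with a `tagger` as the teacher; `en_core_web_sm`, the example texts, and the step count are placeholders:

```python
import spacy
from spacy.lang.en import English
from spacy.training import Example

texts = ["I like green eggs", "Eat blue ham"]

teacher = spacy.load("en_core_web_sm")  # assumed: trained pipeline with a "tagger"
student = English()
student.add_pipe("tagger")

# Distillation examples need no gold annotations: the teacher's predictions
# supply the targets, so empty annotation dicts are enough.
examples = [Example.from_dict(student.make_doc(text), {}) for text in texts]

# Initialize the student with the teacher's labels (via the `labels` argument
# this patch adds to `Language.initialize`) so both taggers share a label set.
optimizer = student.initialize(
    get_examples=lambda: examples,
    labels={"tagger": teacher.get_pipe("tagger").label_data},
)

losses = {}
for _ in range(50):
    student.distill(teacher, examples, sgd=optimizer, losses=losses)
```

If the teacher's pipe had a different name, `student_to_teacher` would map the student name to it, as exercised by the test below.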
diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py
index 03790eb86..89fa08ec7 100644
--- a/spacy/tests/test_language.py
+++ b/spacy/tests/test_language.py
@@ -26,6 +26,12 @@ except ImportError:
     pass
 
 
+TAGGER_TRAIN_DATA = [
+    ("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
+    ("Eat blue ham", {"tags": ["V", "J", "N"]}),
+]
+
+
 def evil_component(doc):
     if "2" in doc.text:
         raise ValueError("no dice")
@@ -799,3 +805,66 @@ def test_component_return():
     nlp.add_pipe("test_component_bad_pipe")
     with pytest.raises(ValueError, match="instead of a Doc"):
         nlp("text")
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("teacher_tagger_name", ["tagger", "teacher_tagger"])
+def test_distill(teacher_tagger_name):
+    teacher = English()
+    teacher_tagger = teacher.add_pipe("tagger", name=teacher_tagger_name)
+    train_examples = []
+    for t in TAGGER_TRAIN_DATA:
+        train_examples.append(Example.from_dict(teacher.make_doc(t[0]), t[1]))
+
+    optimizer = teacher.initialize(get_examples=lambda: train_examples)
+
+    for i in range(50):
+        losses = {}
+        teacher.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses[teacher_tagger_name] < 0.00001
+
+    student = English()
+    student_tagger = student.add_pipe("tagger")
+    student_tagger.min_tree_freq = 1
+    student_tagger.initialize(
+        get_examples=lambda: train_examples, labels=teacher_tagger.label_data
+    )
+
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TAGGER_TRAIN_DATA
+    ]
+
+    student_to_teacher = (
+        None
+        if teacher_tagger.name == student_tagger.name
+        else {student_tagger.name: teacher_tagger.name}
+    )
+
+    for i in range(50):
+        losses = {}
+        student.distill(
+            teacher,
+            distill_examples,
+            sgd=optimizer,
+            losses=losses,
+            student_to_teacher=student_to_teacher,
+        )
+    assert losses["tagger"] < 0.00001
+
+    test_text = "I like blue eggs"
+    doc = student(test_text)
+    assert doc[0].tag_ == "N"
+    assert doc[1].tag_ == "V"
+    assert doc[2].tag_ == "J"
+    assert doc[3].tag_ == "N"
+
+    # Do an extra update to check if annotates works, though we can't really
+    # validate the results, since the annotations are ephemeral.
+    student.distill(
+        teacher,
+        distill_examples,
+        sgd=optimizer,
+        losses=losses,
+        student_to_teacher=student_to_teacher,
+        annotates=["tagger"],
+    )
diff --git a/spacy/ty.py b/spacy/ty.py
index 52b38d515..f6dece840 100644
--- a/spacy/ty.py
+++ b/spacy/ty.py
@@ -26,6 +26,25 @@ class TrainableComponent(Protocol):
         ...
 
 
+@runtime_checkable
+class DistillableComponent(Protocol):
+    is_distillable: bool
+
+    def distill(
+        self,
+        teacher_pipe: Optional[TrainableComponent],
+        examples: Iterable["Example"],
+        *,
+        drop: float = 0.0,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None
+    ) -> Dict[str, float]:
+        ...
+
+    def finish_update(self, sgd: Optimizer) -> None:
+        ...
+
+
 @runtime_checkable
 class InitializableComponent(Protocol):
     def initialize(
diff --git a/website/docs/api/dependencyparser.mdx b/website/docs/api/dependencyparser.mdx
index 5179ce48b..296d6d87d 100644
--- a/website/docs/api/dependencyparser.mdx
+++ b/website/docs/api/dependencyparser.mdx
@@ -154,15 +154,15 @@ This feature is experimental.
 > losses = student.distill(teacher_pipe, examples, sgd=optimizer)
 > ```
 
-| Name           | Description |
-| -------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
-| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
-| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
-| _keyword-only_ | |
-| `drop`         | Dropout rate. ~~float~~ |
-| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
-| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
-| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+| Name           | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples`     | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop`         | Dropout rate. ~~float~~ |
+| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
 
 ## DependencyParser.pipe {id="pipe",tag="method"}
 
diff --git a/website/docs/api/edittreelemmatizer.mdx b/website/docs/api/edittreelemmatizer.mdx
index 2e0993657..c8b5c7180 100644
--- a/website/docs/api/edittreelemmatizer.mdx
+++ b/website/docs/api/edittreelemmatizer.mdx
@@ -138,15 +138,15 @@ This feature is experimental.
 > losses = student.distill(teacher_pipe, examples, sgd=optimizer)
 > ```
 
-| Name           | Description |
-| -------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
-| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
-| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
-| _keyword-only_ | |
-| `drop`         | Dropout rate. ~~float~~ |
-| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
-| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
-| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+| Name           | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples`     | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop`         | Dropout rate. ~~float~~ |
+| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
 
 ## EditTreeLemmatizer.pipe {id="pipe",tag="method"}
 
diff --git a/website/docs/api/entityrecognizer.mdx b/website/docs/api/entityrecognizer.mdx
index 005d5d11d..f503cc998 100644
--- a/website/docs/api/entityrecognizer.mdx
+++ b/website/docs/api/entityrecognizer.mdx
@@ -150,15 +150,15 @@ This feature is experimental.
 > losses = student.distill(teacher_pipe, examples, sgd=optimizer)
 > ```
 
-| Name           | Description |
-| -------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
-| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
-| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
-| _keyword-only_ | |
-| `drop`         | Dropout rate. ~~float~~ |
-| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
-| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
-| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+| Name           | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples`     | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop`         | Dropout rate. ~~float~~ |
+| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
 
 ## EntityRecognizer.pipe {id="pipe",tag="method"}
 
diff --git a/website/docs/api/language.mdx b/website/docs/api/language.mdx
index a34ea7242..c25bfcee5 100644
--- a/website/docs/api/language.mdx
+++ b/website/docs/api/language.mdx
@@ -333,6 +333,34 @@ and custom registered functions if needed. See the
 | `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
 | **RETURNS**     | The updated `losses` dictionary. ~~Dict[str, float]~~ |
 
+## Language.distill {id="distill",tag="method,experimental",version="4"}
+
+Distill the models in a student pipeline from a teacher pipeline.
+
+> #### Example
+>
+> ```python
+>
+> teacher = spacy.load("en_core_web_lg")
+> student = English()
+> student.add_pipe("tagger")
+> student.distill(teacher, examples, sgd=optimizer)
+> ```
+
+| Name                 | Description |
+| -------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher`            | The teacher pipeline to distill from. ~~Language~~ |
+| `examples`           | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
+| _keyword-only_       | |
+| `drop`               | The dropout rate. ~~float~~ |
+| `sgd`                | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses`             | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ |
+| `component_cfg`      | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
+| `exclude`            | Names of components that shouldn't be updated. Defaults to `[]`. ~~Iterable[str]~~ |
+| `annotates`          | Names of components that should set annotations on the predicted examples after updating. Defaults to `[]`. ~~Iterable[str]~~ |
+| `student_to_teacher` | Map student component names to teacher component names, only necessary when the names differ. Defaults to `None`. ~~Optional[Dict[str, str]]~~ |
+| **RETURNS**          | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+
 ## Language.rehearse {id="rehearse",tag="method,experimental",version="3"}
 
 Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
diff --git a/website/docs/api/morphologizer.mdx b/website/docs/api/morphologizer.mdx
index 4f79458d3..4660ec312 100644
--- a/website/docs/api/morphologizer.mdx
+++ b/website/docs/api/morphologizer.mdx
@@ -144,15 +144,15 @@ This feature is experimental.
 > losses = student.distill(teacher_pipe, examples, sgd=optimizer)
 > ```
 
-| Name           | Description |
-| -------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
-| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
-| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
-| _keyword-only_ | |
-| `drop`         | Dropout rate. ~~float~~ |
-| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
-| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
-| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+| Name           | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples`     | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop`         | Dropout rate. ~~float~~ |
+| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
 
 ## Morphologizer.pipe {id="pipe",tag="method"}
 
diff --git a/website/docs/api/pipe.mdx b/website/docs/api/pipe.mdx
index 120c8f690..e1e7f5d70 100644
--- a/website/docs/api/pipe.mdx
+++ b/website/docs/api/pipe.mdx
@@ -257,15 +257,15 @@ This feature is experimental.
 > losses = student.distill(teacher_pipe, examples, sgd=optimizer)
 > ```
 
-| Name           | Description |
-| -------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
-| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
-| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
-| _keyword-only_ | |
-| `drop`         | Dropout rate. ~~float~~ |
-| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
-| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
-| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+| Name           | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples`     | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop`         | Dropout rate. ~~float~~ |
+| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
 
 ## TrainablePipe.rehearse {id="rehearse",tag="method,experimental",version="3"}
 
diff --git a/website/docs/api/sentencerecognizer.mdx b/website/docs/api/sentencerecognizer.mdx
index 02fd57102..dfb7ed308 100644
--- a/website/docs/api/sentencerecognizer.mdx
+++ b/website/docs/api/sentencerecognizer.mdx
@@ -129,15 +129,15 @@ This feature is experimental.
 > losses = student.distill(teacher_pipe, examples, sgd=optimizer)
 > ```
 
-| Name           | Description |
-| -------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
-| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
-| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
-| _keyword-only_ | |
-| `drop`         | Dropout rate. ~~float~~ |
-| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
-| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
-| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+| Name           | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples`     | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop`         | Dropout rate. ~~float~~ |
+| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
 
 ## SentenceRecognizer.pipe {id="pipe",tag="method"}
 
diff --git a/website/docs/api/tagger.mdx b/website/docs/api/tagger.mdx
index 664fd7940..35e7a23b1 100644
--- a/website/docs/api/tagger.mdx
+++ b/website/docs/api/tagger.mdx
@@ -128,15 +128,15 @@ This feature is experimental.
 > losses = student.distill(teacher_pipe, examples, sgd=optimizer)
 > ```
 
-| Name           | Description |
-| -------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
-| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
-| `examples`     | Distillation examples. The reference and predicted docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
-| _keyword-only_ | |
-| `drop`         | Dropout rate. ~~float~~ |
-| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
-| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
-| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+| Name           | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `teacher_pipe` | The teacher pipe to learn from. ~~Optional[TrainablePipe]~~ |
+| `examples`     | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
+| _keyword-only_ | |
+| `drop`         | Dropout rate. ~~float~~ |
+| `sgd`          | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
+| `losses`       | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
+| **RETURNS**    | The updated `losses` dictionary. ~~Dict[str, float]~~ |
 
 ## Tagger.pipe {id="pipe",tag="method"}
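A note on dispatch: `Language.distill` recognizes distillable components structurally, via the runtime-checkable `ty.DistillableComponent` protocol added above, rather than by inheritance. A minimal hypothetical sketch (not part of the patch) of a component that would be picked up:

```python
# Hypothetical example: any pipe exposing `is_distillable`, `distill` and
# `finish_update` matches the runtime-checkable DistillableComponent protocol
# that Language.distill dispatches on (spaCy v4 with this patch applied).
from typing import Dict, Iterable, Optional

from thinc.api import Optimizer

from spacy import ty
from spacy.training import Example


class NoopDistillablePipe:
    """Toy pipe that satisfies the protocol but learns nothing."""

    is_distillable = True

    def distill(
        self,
        teacher_pipe,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        losses = {} if losses is None else losses
        losses.setdefault("noop", 0.0)
        return losses

    def finish_update(self, sgd: Optimizer) -> None:
        pass


# Structural check, no subclassing of ty.DistillableComponent required.
assert isinstance(NoopDistillablePipe(), ty.DistillableComponent)
```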