diff --git a/spacy/language.py b/spacy/language.py index cc0c9b55a..b16056465 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1029,7 +1029,7 @@ class Language: component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, exclude: Iterable[str] = SimpleFrozenList(), annotates: Iterable[str] = SimpleFrozenList(), - component_map: Optional[Dict[str, str]] = None, + student_to_teacher: Optional[Dict[str, str]] = None, ): """Update the models in the pipeline. teacher (Language): Teacher to distill from. @@ -1045,15 +1045,15 @@ class Language: exclude (Iterable[str]): Names of components that shouldn't be updated. annotates (Iterable[str]): Names of components that should set annotations on the predicted examples after updating. - component_map (Optional[Dict[str, str]]): Map student pipe name to + student_to_teacher (Optional[Dict[str, str]]): Map student pipe name to teacher pipe name, only needed for pipes where the student pipe name does not match the teacher pipe name. RETURNS (Dict[str, float]): The updated losses dictionary DOCS: https://spacy.io/api/language#distill """ - if component_map is None: - component_map = {} + if student_to_teacher is None: + student_to_teacher = {} if losses is None: losses = {} if isinstance(examples, list) and len(examples) == 0: @@ -1099,7 +1099,7 @@ class Language: # A missing teacher pipe is not an error, some student pipes # do not need a teacher, such as tok2vec layer losses. teacher_pipe_name = ( - component_map[name] if name in component_map else name + student_to_teacher[name] if name in student_to_teacher else name ) teacher_pipe = teacher_pipes.get(teacher_pipe_name, None) student_proc.distill( diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py index 1f20fe455..e7fd5d024 100644 --- a/spacy/tests/test_language.py +++ b/spacy/tests/test_language.py @@ -833,7 +833,7 @@ def test_distill(teacher_tagger_name): Example.from_dict(teacher.make_doc(t[0]), {}) for t in TAGGER_TRAIN_DATA ] - component_map = ( + student_to_teacher = ( None if teacher_tagger.name == student_tagger.name else {student_tagger.name: teacher_tagger.name} @@ -846,7 +846,7 @@ def test_distill(teacher_tagger_name): distill_examples, sgd=optimizer, losses=losses, - component_map=component_map, + student_to_teacher=student_to_teacher, ) assert losses["tagger"] < 0.00001 @@ -864,6 +864,6 @@ def test_distill(teacher_tagger_name): distill_examples, sgd=optimizer, losses=losses, - component_map=component_map, + student_to_teacher=student_to_teacher, annotates=["tagger"], ) diff --git a/website/docs/api/language.mdx b/website/docs/api/language.mdx index 8262d6e68..916e4e82b 100644 --- a/website/docs/api/language.mdx +++ b/website/docs/api/language.mdx @@ -347,19 +347,19 @@ Distill the models in a student pipeline from a teacher pipeline. > student.distill(teacher, examples, sgd=optimizer) > ``` -| Name | Description | -| --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `teacher` | The teacher pipeline to distill from. ~~Language~~ | -| `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ | -| _keyword-only_ | | -| `drop` | The dropout rate. ~~float~~ | -| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | -| `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ | -| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ | -| `exclude` | Names of components that shouldn't be updated. Defaults to `[]`. ~~Iterable[str]~~ | -| `annotates` | Names of components that should set annotations on the prediced examples after updating. Defaults to `[]`. ~~Iterable[str]~~ | -| `component_map` | Map student component names to teacher component names, only necessary when the names differ. Defaults to `None`. ~~Iterable[str]~~ | -| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | +| Name | Description | +| -------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `teacher` | The teacher pipeline to distill from. ~~Language~~ | +| `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ | +| _keyword-only_ | | +| `drop` | The dropout rate. ~~float~~ | +| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | +| `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ | +| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ | +| `exclude` | Names of components that shouldn't be updated. Defaults to `[]`. ~~Iterable[str]~~ | +| `annotates` | Names of components that should set annotations on the prediced examples after updating. Defaults to `[]`. ~~Iterable[str]~~ | +| `student_to_teacher` | Map student component names to teacher component names, only necessary when the names differ. Defaults to `None`. ~~Iterable[str]~~ | +| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | ## Language.rehearse {id="rehearse",tag="method,experimental",version="3"}