Rename component_map to student_to_teacher

This commit is contained in:
Daniël de Kok 2023-01-30 10:06:08 +01:00
parent f470672972
commit bde20110c4
3 changed files with 21 additions and 21 deletions

View File

@ -1029,7 +1029,7 @@ class Language:
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = SimpleFrozenList(), exclude: Iterable[str] = SimpleFrozenList(),
annotates: Iterable[str] = SimpleFrozenList(), annotates: Iterable[str] = SimpleFrozenList(),
component_map: Optional[Dict[str, str]] = None, student_to_teacher: Optional[Dict[str, str]] = None,
): ):
"""Update the models in the pipeline. """Update the models in the pipeline.
teacher (Language): Teacher to distill from. teacher (Language): Teacher to distill from.
@ -1045,15 +1045,15 @@ class Language:
exclude (Iterable[str]): Names of components that shouldn't be updated. exclude (Iterable[str]): Names of components that shouldn't be updated.
annotates (Iterable[str]): Names of components that should set annotates (Iterable[str]): Names of components that should set
annotations on the predicted examples after updating. annotations on the predicted examples after updating.
component_map (Optional[Dict[str, str]]): Map student pipe name to student_to_teacher (Optional[Dict[str, str]]): Map student pipe name to
teacher pipe name, only needed for pipes where the student pipe teacher pipe name, only needed for pipes where the student pipe
name does not match the teacher pipe name. name does not match the teacher pipe name.
RETURNS (Dict[str, float]): The updated losses dictionary RETURNS (Dict[str, float]): The updated losses dictionary
DOCS: https://spacy.io/api/language#distill DOCS: https://spacy.io/api/language#distill
""" """
if component_map is None: if student_to_teacher is None:
component_map = {} student_to_teacher = {}
if losses is None: if losses is None:
losses = {} losses = {}
if isinstance(examples, list) and len(examples) == 0: if isinstance(examples, list) and len(examples) == 0:
@ -1099,7 +1099,7 @@ class Language:
# A missing teacher pipe is not an error, some student pipes # A missing teacher pipe is not an error, some student pipes
# do not need a teacher, such as tok2vec layer losses. # do not need a teacher, such as tok2vec layer losses.
teacher_pipe_name = ( teacher_pipe_name = (
component_map[name] if name in component_map else name student_to_teacher[name] if name in student_to_teacher else name
) )
teacher_pipe = teacher_pipes.get(teacher_pipe_name, None) teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
student_proc.distill( student_proc.distill(

View File

@ -833,7 +833,7 @@ def test_distill(teacher_tagger_name):
Example.from_dict(teacher.make_doc(t[0]), {}) for t in TAGGER_TRAIN_DATA Example.from_dict(teacher.make_doc(t[0]), {}) for t in TAGGER_TRAIN_DATA
] ]
component_map = ( student_to_teacher = (
None None
if teacher_tagger.name == student_tagger.name if teacher_tagger.name == student_tagger.name
else {student_tagger.name: teacher_tagger.name} else {student_tagger.name: teacher_tagger.name}
@ -846,7 +846,7 @@ def test_distill(teacher_tagger_name):
distill_examples, distill_examples,
sgd=optimizer, sgd=optimizer,
losses=losses, losses=losses,
component_map=component_map, student_to_teacher=student_to_teacher,
) )
assert losses["tagger"] < 0.00001 assert losses["tagger"] < 0.00001
@ -864,6 +864,6 @@ def test_distill(teacher_tagger_name):
distill_examples, distill_examples,
sgd=optimizer, sgd=optimizer,
losses=losses, losses=losses,
component_map=component_map, student_to_teacher=student_to_teacher,
annotates=["tagger"], annotates=["tagger"],
) )

View File

@ -347,19 +347,19 @@ Distill the models in a student pipeline from a teacher pipeline.
> student.distill(teacher, examples, sgd=optimizer) > student.distill(teacher, examples, sgd=optimizer)
> ``` > ```
| Name | Description | | Name | Description |
| --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | -------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher` | The teacher pipeline to distill from. ~~Language~~ | | `teacher` | The teacher pipeline to distill from. ~~Language~~ |
| `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ | | `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `drop` | The dropout rate. ~~float~~ | | `drop` | The dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | | `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
| `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ | | `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ |
| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ | | `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
| `exclude` | Names of components that shouldn't be updated. Defaults to `[]`. ~~Iterable[str]~~ | | `exclude` | Names of components that shouldn't be updated. Defaults to `[]`. ~~Iterable[str]~~ |
| `annotates` | Names of components that should set annotations on the prediced examples after updating. Defaults to `[]`. ~~Iterable[str]~~ | | `annotates` | Names of components that should set annotations on the prediced examples after updating. Defaults to `[]`. ~~Iterable[str]~~ |
| `component_map` | Map student component names to teacher component names, only necessary when the names differ. Defaults to `None`. ~~Iterable[str]~~ | | `student_to_teacher` | Map student component names to teacher component names, only necessary when the names differ. Defaults to `None`. ~~Iterable[str]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | | **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## Language.rehearse {id="rehearse",tag="method,experimental",version="3"} ## Language.rehearse {id="rehearse",tag="method,experimental",version="3"}