Rename component_map to student_to_teacher

This commit is contained in:
Daniël de Kok 2023-01-30 10:06:08 +01:00
parent f470672972
commit bde20110c4
3 changed files with 21 additions and 21 deletions

View File

@ -1029,7 +1029,7 @@ class Language:
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = SimpleFrozenList(),
annotates: Iterable[str] = SimpleFrozenList(),
component_map: Optional[Dict[str, str]] = None,
student_to_teacher: Optional[Dict[str, str]] = None,
):
"""Update the models in the pipeline.
teacher (Language): Teacher to distill from.
@ -1045,15 +1045,15 @@ class Language:
exclude (Iterable[str]): Names of components that shouldn't be updated.
annotates (Iterable[str]): Names of components that should set
annotations on the predicted examples after updating.
component_map (Optional[Dict[str, str]]): Map student pipe name to
student_to_teacher (Optional[Dict[str, str]]): Map student pipe name to
teacher pipe name, only needed for pipes where the student pipe
name does not match the teacher pipe name.
RETURNS (Dict[str, float]): The updated losses dictionary
DOCS: https://spacy.io/api/language#distill
"""
if component_map is None:
component_map = {}
if student_to_teacher is None:
student_to_teacher = {}
if losses is None:
losses = {}
if isinstance(examples, list) and len(examples) == 0:
@ -1099,7 +1099,7 @@ class Language:
# A missing teacher pipe is not an error, some student pipes
# do not need a teacher, such as tok2vec layer losses.
teacher_pipe_name = (
component_map[name] if name in component_map else name
student_to_teacher[name] if name in student_to_teacher else name
)
teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
student_proc.distill(

View File

@ -833,7 +833,7 @@ def test_distill(teacher_tagger_name):
Example.from_dict(teacher.make_doc(t[0]), {}) for t in TAGGER_TRAIN_DATA
]
component_map = (
student_to_teacher = (
None
if teacher_tagger.name == student_tagger.name
else {student_tagger.name: teacher_tagger.name}
@ -846,7 +846,7 @@ def test_distill(teacher_tagger_name):
distill_examples,
sgd=optimizer,
losses=losses,
component_map=component_map,
student_to_teacher=student_to_teacher,
)
assert losses["tagger"] < 0.00001
@ -864,6 +864,6 @@ def test_distill(teacher_tagger_name):
distill_examples,
sgd=optimizer,
losses=losses,
component_map=component_map,
student_to_teacher=student_to_teacher,
annotates=["tagger"],
)

View File

@ -348,7 +348,7 @@ Distill the models in a student pipeline from a teacher pipeline.
> ```
| Name | Description |
| --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| -------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher` | The teacher pipeline to distill from. ~~Language~~ |
| `examples` | A batch of [`Example`](/api/example) distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | |
@ -358,7 +358,7 @@ Distill the models in a student pipeline from a teacher pipeline.
| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
| `exclude` | Names of components that shouldn't be updated. Defaults to `[]`. ~~Iterable[str]~~ |
| `annotates` | Names of components that should set annotations on the prediced examples after updating. Defaults to `[]`. ~~Iterable[str]~~ |
| `component_map` | Map student component names to teacher component names, only necessary when the names differ. Defaults to `None`. ~~Iterable[str]~~ |
| `student_to_teacher` | Map student component names to teacher component names, only necessary when the names differ. Defaults to `None`. ~~Iterable[str]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## Language.rehearse {id="rehearse",tag="method,experimental",version="3"}