Apply suggestions from code review

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
Daniël de Kok 2023-01-12 12:12:58 +01:00 committed by GitHub
parent b1b8b72a0c
commit ad1a330a41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 15 additions and 3 deletions

View File

@ -949,7 +949,7 @@ class Errors(metaclass=ErrorsWithCodes):
E4000 = ("Expected a Doc as input, but got: '{type}'")
E4001 = ("Expected input to be one of the following types: ({expected_types}), "
"but got '{received_type}'")
E4002 = ("Pipe '{name}' requires teacher pipe for distillation.")
E4002 = ("Pipe '{name}' requires a teacher pipe for distillation.")
# fmt: on

View File

@ -164,6 +164,8 @@ class EditTreeLemmatizer(TrainablePipe):
teacher_scores: Scores representing the teacher model's predictions.
student_scores: Scores representing the student model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://spacy.io/api/edittreelemmatizer#get_teacher_student_loss
"""
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)

View File

@ -275,6 +275,8 @@ class Tagger(TrainablePipe):
teacher_scores: Scores representing the teacher model's predictions.
student_scores: Scores representing the student model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://spacy.io/api/tagger#get_teacher_student_loss
"""
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)

View File

@ -81,6 +81,8 @@ cdef class TrainablePipe(Pipe):
losses (Optional[Dict[str, float]]): Optional record of loss during
distillation.
RETURNS: The updated losses dictionary.
DOCS: https://spacy.io/api/pipe#distill
"""
# By default we require a teacher pipe, but there are downstream
# implementations that don't require a pipe.
@ -226,6 +228,8 @@ cdef class TrainablePipe(Pipe):
teacher_scores: Scores representing the teacher model's predictions.
student_scores: Scores representing the student model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://spacy.io/api/pipe#get_teacher_student_loss
"""
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="get_teacher_student_loss", name=self.name))

View File

@ -231,6 +231,8 @@ cdef class Parser(TrainablePipe):
losses (Optional[Dict[str, float]]): Optional record of loss during
distillation.
RETURNS: The updated losses dictionary.
DOCS: https://spacy.io/api/dependencyparser#distill
"""
if teacher_pipe is None:
raise ValueError(Errors.E4002.format(name=self.name))
@ -313,6 +315,8 @@ cdef class Parser(TrainablePipe):
teacher_scores: Scores representing the teacher model's predictions.
student_scores: Scores representing the student model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
"""
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)

View File

@ -234,7 +234,7 @@ predictions and gold-standard annotations, and update the component's model.
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## TrainablePipe.distill {#rehearse tag="method,experimental" new="4"}
## TrainablePipe.distill {#distill tag="method,experimental" new="4"}
Train a pipe (the student) on the predictions of another pipe (the teacher). The
student is typically trained on the probability distribution of the teacher, but
@ -308,7 +308,7 @@ This method needs to be overwritten with your own custom `get_loss` method.
| `scores` | Scores representing the model's predictions. |
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## TrainablePipe.get_teacher_student_loss {#get_teacher_student_loss tag="method"}
## TrainablePipe.get_teacher_student_loss {#get_teacher_student_loss tag="method" new="4"}
Calculate the loss and its gradient for the batch of student scores relative to
the teacher scores.