_distill_loop: fix distill_data docstring

Make similar changes in train_while_improving, since it also had
incorrect types and missing type annotations.
commit c355367f87
parent abc4a7d320
Author: Daniël de Kok
Date:   2023-04-19 20:13:58 +02:00

@@ -295,10 +295,9 @@ def _distill_loop(
     teacher (Language): The teacher pipeline to distill from.
     student (Language): The student pipeline to distill into.
     optimizer: The optimizer callable.
-    distill_data (Iterable[Batch]): A generator of batches, with the distillation
-        data. Each batch should be a Sized[Tuple[Input, Annot]]. The distillation
-        data iterable needs to take care of iterating over the epochs and
-        shuffling.
+    distill_data (Iterable[List[Example]]): A generator of batches,
+        with the distillation data. The distillation data iterable
+        needs to take care of iterating over the epochs and shuffling.
     evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
         The callback should take no arguments and return a tuple
         `(main_score, other_scores)`. The main_score should be a float where
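The corrected docstring keeps the same contract: each batch is a list of `Example` objects, and the iterable itself is responsible for epoch iteration and shuffling. A minimal sketch of that contract in plain Python (integers stand in for spaCy `Example` objects; `infinite_batches` is a hypothetical helper, not part of spaCy):

```python
import random
from itertools import islice
from typing import Iterable, List, TypeVar

T = TypeVar("T")

def infinite_batches(examples: List[T], batch_size: int) -> Iterable[List[T]]:
    """Yield shuffled batches forever, starting a new epoch when exhausted.

    This mirrors the docstring's contract: the iterable handles epochs and
    shuffling, so the training loop can simply pull batches from it.
    """
    while True:  # each full pass over the data is one epoch
        shuffled = list(examples)
        random.shuffle(shuffled)
        for start in range(0, len(shuffled), batch_size):
            yield shuffled[start:start + batch_size]

# One epoch of 10 items with batch_size=4 yields batches of 4, 4, and 2.
batches = list(islice(infinite_batches(list(range(10)), batch_size=4), 3))
```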
@@ -392,8 +391,8 @@ def _distill_loop(
 def train_while_improving(
     nlp: "Language",
     optimizer: Optimizer,
-    train_data,
-    evaluate,
+    train_data: Iterable[List[Example]],
+    evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
     eval_frequency: int,
@@ -413,10 +412,9 @@ def train_while_improving(
     Positional arguments:
     nlp: The spaCy pipeline to evaluate.
     optimizer: The optimizer callable.
-    train_data (Iterable[Batch]): A generator of batches, with the training
-        data. Each batch should be a Sized[Tuple[Input, Annot]]. The training
-        data iterable needs to take care of iterating over the epochs and
-        shuffling.
+    train_data (Iterable[List[Example]]): A generator of batches, with the
+        training data. The training data iterable needs to take care of
+        iterating over the epochs and shuffling.
     evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
         The callback should take no arguments and return a tuple
         `(main_score, other_scores)`. The main_score should be a float where
@@ -480,7 +478,7 @@ def train_while_improving(
                     score, other_scores = evaluate()
             else:
                 score, other_scores = evaluate()
-            optimizer.last_score = score
+            optimizer.last_score = score  # type: ignore[assignment]
             results.append((score, step))
             is_best_checkpoint = score == max(results)[0]
         else:
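The `# type: ignore[assignment]` in the last hunk suppresses a mismatch between the bare float being assigned and the declared type of `last_score` on the optimizer. A hypothetical stub illustrating the pattern (the attribute's declared type here is an assumption for illustration, not thinc's actual declaration):

```python
from typing import Optional, Tuple

class Optimizer:
    # Hypothetical stub: the attribute is declared with a stricter type
    # than the bare float the training loop assigns to it.
    last_score: Optional[Tuple[int, float]] = None

opt = Optimizer()
score = 0.85
# The assignment works at runtime (annotations are not enforced), but a
# type checker such as mypy flags it as an incompatible assignment;
# the error-code-specific ignore silences only that diagnostic.
opt.last_score = score  # type: ignore[assignment]
```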