_distill_loop: fix distill_data docstring

Make similar changes in train_while_improving, since it also had
incorrect types and missing type annotations.
Daniël de Kok 2023-04-19 20:13:58 +02:00
parent abc4a7d320
commit c355367f87


@@ -295,10 +295,9 @@ def _distill_loop(
     teacher (Language): The teacher pipeline to distill from.
     student (Language): The student pipeline to distill into.
     optimizer: The optimizer callable.
-    distill_data (Iterable[Batch]): A generator of batches, with the distillation
-        data. Each batch should be a Sized[Tuple[Input, Annot]]. The distillation
-        data iterable needs to take care of iterating over the epochs and
-        shuffling.
+    distill_data (Iterable[List[Example]]): A generator of batches,
+        with the distillation data. The distillation data iterable
+        needs to take care of iterating over the epochs and shuffling.
     evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
         The callback should take no arguments and return a tuple
         `(main_score, other_scores)`. The main_score should be a float where
@@ -392,8 +391,8 @@ def _distill_loop(
 def train_while_improving(
     nlp: "Language",
     optimizer: Optimizer,
-    train_data,
-    evaluate,
+    train_data: Iterable[List[Example]],
+    evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
     eval_frequency: int,
@@ -413,10 +412,9 @@ def train_while_improving(
     Positional arguments:
     nlp: The spaCy pipeline to evaluate.
     optimizer: The optimizer callable.
-    train_data (Iterable[Batch]): A generator of batches, with the training
-        data. Each batch should be a Sized[Tuple[Input, Annot]]. The training
-        data iterable needs to take care of iterating over the epochs and
-        shuffling.
+    train_data (Iterable[List[Example]]): A generator of batches, with the
+        training data. The training data iterable needs to take care of
+        iterating over the epochs and shuffling.
     evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
         The callback should take no arguments and return a tuple
         `(main_score, other_scores)`. The main_score should be a float where
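
The updated annotation Iterable[List[Example]] for both distill_data and train_data describes a generator that yields batches of Example objects and, as the docstrings note, is itself responsible for epochs and shuffling. As a minimal sketch of such an iterable (the batch_examples helper, the batch size, and the epoch count here are illustrative and not part of this commit or of spaCy's own batching utilities):

    import random
    from typing import Iterable, List

    from spacy.training import Example


    def batch_examples(
        examples: List[Example], batch_size: int, epochs: int
    ) -> Iterable[List[Example]]:
        # The iterable handles epochs and shuffling itself, as required
        # by the distill_data / train_data docstrings above.
        for _ in range(epochs):
            random.shuffle(examples)
            for start in range(0, len(examples), batch_size):
                yield examples[start : start + batch_size]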
@@ -480,7 +478,7 @@ def train_while_improving(
                     score, other_scores = evaluate()
             else:
                 score, other_scores = evaluate()
-            optimizer.last_score = score
+            optimizer.last_score = score  # type: ignore[assignment]
             results.append((score, step))
             is_best_checkpoint = score == max(results)[0]
         else:
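
The evaluate parameter is now annotated as Callable[[], Tuple[float, Dict[str, float]]]: a zero-argument callback returning (main_score, other_scores), where main_score is what the hunk above assigns to optimizer.last_score. A sketch of a compatible callback, assuming a closure over an nlp pipeline and a dev set (the make_evaluate helper and the choice of token_acc as the main score are assumptions for illustration, not what spaCy's training loop configures):

    from typing import Callable, Dict, List, Tuple

    from spacy.language import Language
    from spacy.training import Example


    def make_evaluate(
        nlp: Language, dev_examples: List[Example]
    ) -> Callable[[], Tuple[float, Dict[str, float]]]:
        def evaluate() -> Tuple[float, Dict[str, float]]:
            scores = nlp.evaluate(dev_examples)
            # Pick one metric as the main score (assumption: token accuracy)
            # and report the remaining flat numeric scores alongside it.
            main_score = float(scores.get("token_acc") or 0.0)
            other_scores = {
                key: float(value)
                for key, value in scores.items()
                if isinstance(value, (int, float))
            }
            return main_score, other_scores

        return evaluate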