diff --git a/spacy/training/loop.py b/spacy/training/loop.py
index ae73ea6c1..953fffaed 100644
--- a/spacy/training/loop.py
+++ b/spacy/training/loop.py
@@ -295,10 +295,9 @@ def _distill_loop(
     teacher (Language): The teacher pipeline to distill from.
     student (Language): The student pipeline to distill into.
     optimizer: The optimizer callable.
-    distill_data (Iterable[Batch]): A generator of batches, with the distillation
-        data. Each batch should be a Sized[Tuple[Input, Annot]]. The distillation
-        data iterable needs to take care of iterating over the epochs and
-        shuffling.
+    distill_data (Iterable[List[Example]]): A generator of batches,
+        with the distillation data. The distillation data iterable
+        needs to take care of iterating over the epochs and shuffling.
     evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
         The callback should take no arguments and return a tuple
         `(main_score, other_scores)`. The main_score should be a float where
@@ -392,8 +391,8 @@ def _distill_loop(
 def train_while_improving(
     nlp: "Language",
     optimizer: Optimizer,
-    train_data,
-    evaluate,
+    train_data: Iterable[List[Example]],
+    evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
     eval_frequency: int,
@@ -413,10 +412,9 @@ def train_while_improving(
     Positional arguments:
     nlp: The spaCy pipeline to evaluate.
     optimizer: The optimizer callable.
-    train_data (Iterable[Batch]): A generator of batches, with the training
-        data. Each batch should be a Sized[Tuple[Input, Annot]]. The training
-        data iterable needs to take care of iterating over the epochs and
-        shuffling.
+    train_data (Iterable[List[Example]]): A generator of batches, with the
+        training data. The training data iterable needs to take care of
+        iterating over the epochs and shuffling.
     evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
         The callback should take no arguments and return a tuple
         `(main_score, other_scores)`. The main_score should be a float where
@@ -480,7 +478,7 @@ def train_while_improving(
                     score, other_scores = evaluate()
             else:
                 score, other_scores = evaluate()
-            optimizer.last_score = score
+            optimizer.last_score = score  # type: ignore[assignment]
             results.append((score, step))
             is_best_checkpoint = score == max(results)[0]
         else:
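
For context, here is a minimal sketch of the kind of `Iterable[List[Example]]` batch generator that the updated docstrings and annotations describe. The `make_batches` helper, the toy data, and the blank `nlp` object are illustrative assumptions, not part of this change:

```python
# Minimal sketch (illustrative, not part of this patch): building an
# Iterable[List[Example]] batch generator, which is how `train_data` and
# `distill_data` are now annotated. The toy data and the `make_batches`
# helper below are assumptions for demonstration only.
import random
from typing import Iterable, List

import spacy
from spacy.training import Example
from spacy.util import minibatch

nlp = spacy.blank("en")
raw = [
    ("I like cats", {"cats": {"ANIMAL": 1.0, "VEHICLE": 0.0}}),
    ("I like trucks", {"cats": {"ANIMAL": 0.0, "VEHICLE": 1.0}}),
]
examples = [
    Example.from_dict(nlp.make_doc(text), annots) for text, annots in raw
]


def make_batches(
    examples: List[Example], *, epochs: int, size: int
) -> Iterable[List[Example]]:
    # As the docstrings note, the iterable itself is responsible for
    # iterating over epochs and shuffling.
    for _ in range(epochs):
        random.shuffle(examples)
        yield from minibatch(examples, size=size)
```

A generator built this way can be passed as the data argument that `train_while_improving` and `_distill_loop` expect under the new annotations.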