mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-13 01:23:04 +03:00
_distill_loop: fix distill_data docstring
Make similar changes in train_while_improving, since it also had incorrect types and missing type annotations.
This commit is contained in:
parent
abc4a7d320
commit
c355367f87
|
@ -295,10 +295,9 @@ def _distill_loop(
|
||||||
teacher (Language): The teacher pipeline to distill from.
|
teacher (Language): The teacher pipeline to distill from.
|
||||||
student (Language): The student pipeline to distill into.
|
student (Language): The student pipeline to distill into.
|
||||||
optimizer: The optimizer callable.
|
optimizer: The optimizer callable.
|
||||||
distill_data (Iterable[Batch]): A generator of batches, with the distillation
|
distill_data (Iterable[List[Example]]): A generator of batches,
|
||||||
data. Each batch should be a Sized[Tuple[Input, Annot]]. The distillation
|
with the distillation data. The distillation data iterable
|
||||||
data iterable needs to take care of iterating over the epochs and
|
needs to take care of iterating over the epochs and shuffling.
|
||||||
shuffling.
|
|
||||||
evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
|
evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
|
||||||
The callback should take no arguments and return a tuple
|
The callback should take no arguments and return a tuple
|
||||||
`(main_score, other_scores)`. The main_score should be a float where
|
`(main_score, other_scores)`. The main_score should be a float where
|
||||||
|
@ -392,8 +391,8 @@ def _distill_loop(
|
||||||
def train_while_improving(
|
def train_while_improving(
|
||||||
nlp: "Language",
|
nlp: "Language",
|
||||||
optimizer: Optimizer,
|
optimizer: Optimizer,
|
||||||
train_data,
|
train_data: Iterable[List[Example]],
|
||||||
evaluate,
|
evaluate: Callable[[], Tuple[float, Dict[str, float]]],
|
||||||
*,
|
*,
|
||||||
dropout: float,
|
dropout: float,
|
||||||
eval_frequency: int,
|
eval_frequency: int,
|
||||||
|
@ -413,10 +412,9 @@ def train_while_improving(
|
||||||
Positional arguments:
|
Positional arguments:
|
||||||
nlp: The spaCy pipeline to evaluate.
|
nlp: The spaCy pipeline to evaluate.
|
||||||
optimizer: The optimizer callable.
|
optimizer: The optimizer callable.
|
||||||
train_data (Iterable[Batch]): A generator of batches, with the training
|
train_data (Iterable[List[Example]]): A generator of batches, with the
|
||||||
data. Each batch should be a Sized[Tuple[Input, Annot]]. The training
|
training data. The training data iterable needs to take care of
|
||||||
data iterable needs to take care of iterating over the epochs and
|
iterating over the epochs and shuffling.
|
||||||
shuffling.
|
|
||||||
evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
|
evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
|
||||||
The callback should take no arguments and return a tuple
|
The callback should take no arguments and return a tuple
|
||||||
`(main_score, other_scores)`. The main_score should be a float where
|
`(main_score, other_scores)`. The main_score should be a float where
|
||||||
|
@ -480,7 +478,7 @@ def train_while_improving(
|
||||||
score, other_scores = evaluate()
|
score, other_scores = evaluate()
|
||||||
else:
|
else:
|
||||||
score, other_scores = evaluate()
|
score, other_scores = evaluate()
|
||||||
optimizer.last_score = score
|
optimizer.last_score = score # type: ignore[assignment]
|
||||||
results.append((score, step))
|
results.append((score, step))
|
||||||
is_best_checkpoint = score == max(results)[0]
|
is_best_checkpoint = score == max(results)[0]
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user