Generalize arguments passed to the callback

This commit is contained in:
shademe 2022-11-07 11:29:04 +01:00
parent 43e5fb7b6b
commit e477eb2f7f
No known key found for this signature in database
GPG Key ID: 6FCA9FC635B2A402
2 changed files with 4 additions and 3 deletions

View File

@ -152,7 +152,7 @@ def train_while_improving(
max_steps: int,
exclude: List[str],
annotating_components: List[str],
before_update: Optional[Callable[["Language", int], None]],
before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
):
"""Train until an evaluation stops improving. Works as a generator,
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
@ -202,7 +202,8 @@ def train_while_improving(
start_time = timer()
for step, (epoch, batch) in enumerate(train_data):
if before_update:
before_update(nlp, step)
before_update_args = {"current_step": step}
before_update(nlp, before_update_args)
dropout = next(dropouts) # type: ignore
for subbatch in subdivide_batch(batch, accumulate_gradient):
nlp.update(

View File

@ -186,7 +186,7 @@ process that are used when you run [`spacy train`](/api/cli#train).
| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
| `before_to_disk` | Optional callback to modify `nlp` object right before it is saved to disk during and after training. Can be used to remove or reset config values or disable components. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
| `before_update` | Optional callback that is invoked at the start of each training step with the `nlp` object and the current step. Can be used to make deferred changes to components. Defaults to `null`. ~~Optional[Callable[[Language, int], None]]~~ |
| `before_update` | Optional callback that is invoked at the start of each training step with the `nlp` object and a `Dict` containing the following entries: `current_step`. Can be used to make deferred changes to components. Defaults to `null`. ~~Optional[Callable[[Language, Dict[str, Any]], None]]~~ |
| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |