Avoid TrainablePipe.finish_update getting called twice during training (#12450)

* Avoid `TrainablePipe.finish_update` getting called twice during training

PR #12136 fixed an issue where the tok2vec pipe was updated before
gradients were accumulated. However, it introduced a new bug that causes
`finish_update` to be called twice when using the training loop,
resulting in a fairly large slowdown.

The `Language.update` method accepts the `sgd` argument for passing an
optimizer. This argument has three possible values, illustrated in the
sketch after this list:

- `Optimizer`: use the given optimizer to finish pipe updates.
- `None`: use a default optimizer to finish pipe updates.
- `False`: do not finish pipe updates.
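
A minimal sketch of the three call forms, assuming an initialized
pipeline `nlp` and a batch of `Example` objects named `examples` (both
names are illustrative, not part of the change):

```python
# Optimizer: accumulate gradients and finish pipe updates with it.
optimizer = nlp.initialize()
nlp.update(examples, sgd=optimizer)

# None: finish pipe updates with a default optimizer created via
# create_optimizer.
nlp.update(examples, sgd=None)

# False: accumulate gradients, but do not finish pipe updates; the
# caller is responsible for calling finish_update later.
nlp.update(examples, sgd=False)
```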

However, the `False` option was not documented and not permitted by the
existing type annotation of `sgd`. I assumed that it was a remnant of
earlier spaCy versions and removed the handling of `False`.

However, with that change, the training loop ended up passing `None` to
`Language.update`. As a result, `finish_update` was called both in
`Language.update` and in the training loop after all subbatches were
processed.

This change restores proper handling/use of `False`. Moreover, the role
of `False` is now documented and added to the type annotation to avoid
future accidents.
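
Building on the sketch above, this is roughly the pattern the training
loop relies on; the subbatch size and the `hasattr` check are
simplifications of the real loop, which checks for trainable components
explicitly:

```python
from spacy.util import minibatch

losses = {}
optimizer = nlp.initialize()

# Accumulate gradients over all subbatches without applying them.
for subbatch in minibatch(examples, size=8):
    nlp.update(subbatch, losses=losses, sgd=False)

# Apply the accumulated gradients exactly once for the whole batch.
for name, proc in nlp.pipeline:
    if hasattr(proc, "finish_update"):
        proc.finish_update(optimizer)
```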

* Fix typo

* Document defaults for `Language.update`

Daniël de Kok, 2023-03-30 09:30:42 +02:00, committed by GitHub
commit b734e5314d (parent a653dec654)
4 changed files with 33 additions and 12 deletions


@@ -1202,7 +1202,7 @@ class Language:
         _: Optional[Any] = None,
         *,
         drop: float = 0.0,
-        sgd: Optional[Optimizer] = None,
+        sgd: Union[Optimizer, None, Literal[False]] = None,
         losses: Optional[Dict[str, float]] = None,
         component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
         exclude: Iterable[str] = SimpleFrozenList(),
@@ -1213,7 +1213,9 @@ class Language:
         examples (Iterable[Example]): A batch of examples
         _: Should not be set - serves to catch backwards-incompatible scripts.
         drop (float): The dropout rate.
-        sgd (Optimizer): An optimizer.
+        sgd (Union[Optimizer, None, Literal[False]]): An optimizer. Will
+            be created via create_optimizer if 'None'. No optimizer will
+            be used when set to 'False'.
         losses (Dict[str, float]): Dictionary to update with the loss, keyed by
             component.
         component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
@@ -1272,6 +1274,7 @@
                 name not in exclude
                 and isinstance(proc, ty.TrainableComponent)
                 and proc.is_trainable
+                and sgd not in (None, False)
             ):
                 proc.finish_update(sgd)


@@ -157,6 +157,24 @@ def test_language_update_updates():
     )
+
+
+def test_language_update_does_not_update_with_sgd_false():
+    config = Config().from_str(TAGGER_CFG_STRING)
+    nlp = load_model_from_config(config, auto_fill=True, validate=True)
+    train_examples = []
+    for t in TAGGER_TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    nlp.initialize(get_examples=lambda: train_examples)
+    docs_before_update = list(nlp.pipe([eg.predicted.copy() for eg in train_examples]))
+    nlp.update(train_examples, sgd=False)
+    docs_after_update = list(nlp.pipe([eg.predicted.copy() for eg in train_examples]))
+    xp = get_array_module(docs_after_update[0].tensor)
+    xp.testing.assert_equal(docs_before_update[0].tensor, docs_after_update[0].tensor)


 def test_language_evaluate(nlp):
     text = "hello world"
     annots = {"doc_annotation": {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}}


@@ -210,7 +210,7 @@ def train_while_improving(
             subbatch,
             drop=dropout,
             losses=losses,
-            sgd=None,
+            sgd=False,
             exclude=exclude,
             annotates=annotating_components,
         )


@@ -323,15 +323,15 @@ and custom registered functions if needed. See the
 > nlp.update([example], sgd=optimizer)
 > ```

-| Name            | Description |
-| --------------- | ----------- |
-| `examples`      | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
-| _keyword-only_  | |
-| `drop`          | The dropout rate. ~~float~~ |
-| `sgd`           | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
-| `losses`        | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ |
-| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
-| **RETURNS**     | The updated `losses` dictionary. ~~Dict[str, float]~~ |
+| Name            | Description |
+| --------------- | ----------- |
+| `examples`      | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
+| _keyword-only_  | |
+| `drop`          | The dropout rate. Defaults to `0.0`. ~~float~~ |
+| `sgd`           | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if `None`. No optimizer will be used when set to `False`. Defaults to `None`. ~~Union[Optimizer, None, Literal[False]]~~ |
+| `losses`        | Dictionary to update with the loss, keyed by pipeline component. Defaults to `None`. ~~Optional[Dict[str, float]]~~ |
+| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
+| **RETURNS**     | The updated `losses` dictionary. ~~Dict[str, float]~~ |

 ## Language.distill {id="distill",tag="method,experimental",version="4"}