Avoid TrainablePipe.finish_update getting called twice during training (#12450)

* Avoid `TrainablePipe.finish_update` getting called twice during training

PR #12136 fixed an issue where the tok2vec pipe was updated before
gradient were accumulated. However, it introduced a new bug that cause
`finish_update` to be called twice when using the training loop. This
causes a fairly large slowdown.

The `Language.update` method accepts the `sgd` argument for passing an
optimizer. This argument has three possible values:

- `Optimizer`: use the given optimizer to finish pipe updates.
- `None`: use a default optimizer to finish pipe updates.
- `False`: do not finish pipe updates.

However, the latter option was not documented and not valid with the
existing type of `sgd`. I assumed that this was a remnant of earlier
spaCy versions and removed handling of `False`.

However, with that change, we are passing `None` to `Language.update`.
As a result, we were calling `finish_update` in both `Language.update`
and in the training loop after all subbatches are processed.

This change restores proper handling/use of `False`. Moreover, the role
of `False` is now documented and added to the type to avoid future
accidents.

* Fix typo

* Document defaults for `Language.update`
This commit is contained in:
Daniël de Kok 2023-03-30 09:30:42 +02:00 committed by GitHub
parent a653dec654
commit b734e5314d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 12 deletions

View File

@ -1202,7 +1202,7 @@ class Language:
_: Optional[Any] = None, _: Optional[Any] = None,
*, *,
drop: float = 0.0, drop: float = 0.0,
sgd: Optional[Optimizer] = None, sgd: Union[Optimizer, None, Literal[False]] = None,
losses: Optional[Dict[str, float]] = None, losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = SimpleFrozenList(), exclude: Iterable[str] = SimpleFrozenList(),
@ -1213,7 +1213,9 @@ class Language:
examples (Iterable[Example]): A batch of examples examples (Iterable[Example]): A batch of examples
_: Should not be set - serves to catch backwards-incompatible scripts. _: Should not be set - serves to catch backwards-incompatible scripts.
drop (float): The dropout rate. drop (float): The dropout rate.
sgd (Optimizer): An optimizer. sgd (Union[Optimizer, None, Literal[False]]): An optimizer. Will
be created via create_optimizer if 'None'. No optimizer will
be used when set to 'False'.
losses (Dict[str, float]): Dictionary to update with the loss, keyed by losses (Dict[str, float]): Dictionary to update with the loss, keyed by
component. component.
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
@ -1272,6 +1274,7 @@ class Language:
name not in exclude name not in exclude
and isinstance(proc, ty.TrainableComponent) and isinstance(proc, ty.TrainableComponent)
and proc.is_trainable and proc.is_trainable
and sgd not in (None, False)
): ):
proc.finish_update(sgd) proc.finish_update(sgd)

View File

@ -157,6 +157,24 @@ def test_language_update_updates():
) )
def test_language_update_does_not_update_with_sgd_false():
config = Config().from_str(TAGGER_CFG_STRING)
nlp = load_model_from_config(config, auto_fill=True, validate=True)
train_examples = []
for t in TAGGER_TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
nlp.initialize(get_examples=lambda: train_examples)
docs_before_update = list(nlp.pipe([eg.predicted.copy() for eg in train_examples]))
nlp.update(train_examples, sgd=False)
docs_after_update = list(nlp.pipe([eg.predicted.copy() for eg in train_examples]))
xp = get_array_module(docs_after_update[0].tensor)
xp.testing.assert_equal(docs_before_update[0].tensor, docs_after_update[0].tensor)
def test_language_evaluate(nlp): def test_language_evaluate(nlp):
text = "hello world" text = "hello world"
annots = {"doc_annotation": {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}} annots = {"doc_annotation": {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}}

View File

@ -210,7 +210,7 @@ def train_while_improving(
subbatch, subbatch,
drop=dropout, drop=dropout,
losses=losses, losses=losses,
sgd=None, sgd=False,
exclude=exclude, exclude=exclude,
annotates=annotating_components, annotates=annotating_components,
) )

View File

@ -324,12 +324,12 @@ and custom registered functions if needed. See the
> ``` > ```
| Name | Description | | Name | Description |
| --------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | | --------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ | | `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `drop` | The dropout rate. ~~float~~ | | `drop` | The dropout rate. Defaults to `0.0`. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | | `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if `None`. No optimizer will be used when set to `False`. Defaults to `None`. ~~Union[Optimizer, None, Literal[False]]~~ |
| `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ | | `losses` | Dictionary to update with the loss, keyed by pipeline component. Defaults to `None`. ~~Optional[Dict[str, float]]~~ |
| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ | | `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | | **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |