From b734e5314d3b8faa7d463c265db9a823a113165c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 30 Mar 2023 09:30:42 +0200 Subject: [PATCH] Avoid `TrainablePipe.finish_update` getting called twice during training (#12450) * Avoid `TrainablePipe.finish_update` getting called twice during training PR #12136 fixed an issue where the tok2vec pipe was updated before gradient were accumulated. However, it introduced a new bug that cause `finish_update` to be called twice when using the training loop. This causes a fairly large slowdown. The `Language.update` method accepts the `sgd` argument for passing an optimizer. This argument has three possible values: - `Optimizer`: use the given optimizer to finish pipe updates. - `None`: use a default optimizer to finish pipe updates. - `False`: do not finish pipe updates. However, the latter option was not documented and not valid with the existing type of `sgd`. I assumed that this was a remnant of earlier spaCy versions and removed handling of `False`. However, with that change, we are passing `None` to `Language.update`. As a result, we were calling `finish_update` in both `Language.update` and in the training loop after all subbatches are processed. This change restores proper handling/use of `False`. Moreover, the role of `False` is now documented and added to the type to avoid future accidents. * Fix typo * Document defaults for `Language.update` --- spacy/language.py | 7 +++++-- spacy/tests/test_language.py | 18 ++++++++++++++++++ spacy/training/loop.py | 2 +- website/docs/api/language.mdx | 18 +++++++++--------- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 3b86fdde7..ce3630629 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1202,7 +1202,7 @@ class Language: _: Optional[Any] = None, *, drop: float = 0.0, - sgd: Optional[Optimizer] = None, + sgd: Union[Optimizer, None, Literal[False]] = None, losses: Optional[Dict[str, float]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, exclude: Iterable[str] = SimpleFrozenList(), @@ -1213,7 +1213,9 @@ class Language: examples (Iterable[Example]): A batch of examples _: Should not be set - serves to catch backwards-incompatible scripts. drop (float): The dropout rate. - sgd (Optimizer): An optimizer. + sgd (Union[Optimizer, None, Literal[False]]): An optimizer. Will + be created via create_optimizer if 'None'. No optimizer will + be used when set to 'False'. losses (Dict[str, float]): Dictionary to update with the loss, keyed by component. component_cfg (Dict[str, Dict]): Config parameters for specific pipeline @@ -1272,6 +1274,7 @@ class Language: name not in exclude and isinstance(proc, ty.TrainableComponent) and proc.is_trainable + and sgd not in (None, False) ): proc.finish_update(sgd) diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py index 9b8c7b9c7..08a7d28a4 100644 --- a/spacy/tests/test_language.py +++ b/spacy/tests/test_language.py @@ -157,6 +157,24 @@ def test_language_update_updates(): ) +def test_language_update_does_not_update_with_sgd_false(): + config = Config().from_str(TAGGER_CFG_STRING) + nlp = load_model_from_config(config, auto_fill=True, validate=True) + + train_examples = [] + for t in TAGGER_TRAIN_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) + + nlp.initialize(get_examples=lambda: train_examples) + + docs_before_update = list(nlp.pipe([eg.predicted.copy() for eg in train_examples])) + nlp.update(train_examples, sgd=False) + docs_after_update = list(nlp.pipe([eg.predicted.copy() for eg in train_examples])) + + xp = get_array_module(docs_after_update[0].tensor) + xp.testing.assert_equal(docs_before_update[0].tensor, docs_after_update[0].tensor) + + def test_language_evaluate(nlp): text = "hello world" annots = {"doc_annotation": {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}} diff --git a/spacy/training/loop.py b/spacy/training/loop.py index c737d7c01..587a2516c 100644 --- a/spacy/training/loop.py +++ b/spacy/training/loop.py @@ -210,7 +210,7 @@ def train_while_improving( subbatch, drop=dropout, losses=losses, - sgd=None, + sgd=False, exclude=exclude, annotates=annotating_components, ) diff --git a/website/docs/api/language.mdx b/website/docs/api/language.mdx index c25bfcee5..5cd9e4af8 100644 --- a/website/docs/api/language.mdx +++ b/website/docs/api/language.mdx @@ -323,15 +323,15 @@ and custom registered functions if needed. See the > nlp.update([example], sgd=optimizer) > ``` -| Name | Description | -| --------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | -| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ | -| _keyword-only_ | | -| `drop` | The dropout rate. ~~float~~ | -| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ | -| `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ | -| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ | -| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | +| Name | Description | +| --------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ | +| _keyword-only_ | | +| `drop` | The dropout rate. Defaults to `0.0`. ~~float~~ | +| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if `None`. No optimizer will be used when set to `False`. Defaults to `None`. ~~Union[Optimizer, None, Literal[False]]~~ | +| `losses` | Dictionary to update with the loss, keyed by pipeline component. Defaults to `None`. ~~Optional[Dict[str, float]]~~ | +| `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ | +| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ | ## Language.distill {id="distill",tag="method,experimental",version="4"}