mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Avoid TrainablePipe.finish_update getting called twice during training
				
					
				
			PR #12136 fixed an issue where the tok2vec pipe was updated before gradient were accumulated. However, it introduced a new bug that cause `finish_update` to be called twice when using the training loop. This causes a fairly large slowdown. The `Language.update` method accepts the `sgd` argument for passing an optimizer. This argument has three possible values: - `Optimizer`: use the given optimizer to finish pipe updates. - `None`: use a default optimizer to finish pipe updates. - `False`: do not finish pipe updates. However, the latter option was not documented and not valid with the existing type of `sgd`. I assumed that this was a remnant of earlier spaCy versions and removed handling of `False`. However, with that change, we are passing `None` to `Language.update`. As a result, we were calling `finish_update` in both `Language.update` and in the training loop after all subbatches are processed. This change restores proper handling/use of `False`. Moreover, the role of `False` is now documented and added to the type to avoid future accidents.
This commit is contained in:
		
							parent
							
								
									9340eb8ad2
								
							
						
					
					
						commit
						c53606d3b3
					
				|  | @ -1202,7 +1202,7 @@ class Language: | |||
|         _: Optional[Any] = None, | ||||
|         *, | ||||
|         drop: float = 0.0, | ||||
|         sgd: Optional[Optimizer] = None, | ||||
|         sgd: Union[Optimizer, None, Literal[False]] = None, | ||||
|         losses: Optional[Dict[str, float]] = None, | ||||
|         component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, | ||||
|         exclude: Iterable[str] = SimpleFrozenList(), | ||||
|  | @ -1213,7 +1213,9 @@ class Language: | |||
|         examples (Iterable[Example]): A batch of examples | ||||
|         _: Should not be set - serves to catch backwards-incompatible scripts. | ||||
|         drop (float): The dropout rate. | ||||
|         sgd (Optimizer): An optimizer. | ||||
|         sgd (Union[Optimizer, None, Literal[False]]): An optimizer. Will | ||||
|             be created via create_optimizer if 'None'. No optimizer will | ||||
|             be used when set to 'False'. | ||||
|         losses (Dict[str, float]): Dictionary to update with the loss, keyed by | ||||
|             component. | ||||
|         component_cfg (Dict[str, Dict]): Config parameters for specific pipeline | ||||
|  | @ -1272,6 +1274,7 @@ class Language: | |||
|                 name not in exclude | ||||
|                 and isinstance(proc, ty.TrainableComponent) | ||||
|                 and proc.is_trainable | ||||
|                 and sgd not in (None, False) | ||||
|             ): | ||||
|                 proc.finish_update(sgd) | ||||
| 
 | ||||
|  |  | |||
|  | @ -157,6 +157,24 @@ def test_language_update_updates(): | |||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def test_language_update_does_not_update_with_sgd_false(): | ||||
|     config = Config().from_str(TAGGER_CFG_STRING) | ||||
|     nlp = load_model_from_config(config, auto_fill=True, validate=True) | ||||
| 
 | ||||
|     train_examples = [] | ||||
|     for t in TAGGER_TRAIN_DATA: | ||||
|         train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) | ||||
| 
 | ||||
|     nlp.initialize(get_examples=lambda: train_examples) | ||||
| 
 | ||||
|     docs_before_update = list(nlp.pipe([eg.predicted.copy() for eg in train_examples])) | ||||
|     nlp.update(train_examples, sgd=False) | ||||
|     docs_after_update = list(nlp.pipe([eg.predicted.copy() for eg in train_examples])) | ||||
| 
 | ||||
|     xp = get_array_module(docs_after_update[0].tensor) | ||||
|     xp.testing.assert_equal(docs_before_update[0].tensor, docs_after_update[0].tensor) | ||||
| 
 | ||||
| 
 | ||||
| def test_language_evaluate(nlp): | ||||
|     text = "hello world" | ||||
|     annots = {"doc_annotation": {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}} | ||||
|  |  | |||
|  | @ -210,7 +210,7 @@ def train_while_improving( | |||
|                 subbatch, | ||||
|                 drop=dropout, | ||||
|                 losses=losses, | ||||
|                 sgd=None, | ||||
|                 sgd=False, | ||||
|                 exclude=exclude, | ||||
|                 annotates=annotating_components, | ||||
|             ) | ||||
|  |  | |||
|  | @ -323,15 +323,15 @@ and custom registered functions if needed. See the | |||
| >     nlp.update([example], sgd=optimizer) | ||||
| > ``` | ||||
| 
 | ||||
| | Name            | Description                                                                                                                                    | | ||||
| | --------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | | ||||
| | `examples`      | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~                                                              | | ||||
| | _keyword-only_  |                                                                                                                                                | | ||||
| | `drop`          | The dropout rate. ~~float~~                                                                                                                    | | ||||
| | `sgd`           | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~                                  | | ||||
| | `losses`        | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~                                                | | ||||
| | `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~ | | ||||
| | **RETURNS**     | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                          | | ||||
| | Name            | Description                                                                                                                                                                                           | | ||||
| | --------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||||
| | `examples`      | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~                                                                                                                     | | ||||
| | _keyword-only_  |                                                                                                                                                                                                       | | ||||
| | `drop`          | The dropout rate. ~~float~~                                                                                                                                                                           | | ||||
| | `sgd`           | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if `None`. No optimizer will not be used when set to `False`. Defaults to `None` ~~Union[Optimizer, None, Literal[False]]~~ | | ||||
| | `losses`        | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~                                                                                                       | | ||||
| | `component_cfg` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. ~~Optional[Dict[str, Dict[str, Any]]]~~                                                        | | ||||
| | **RETURNS**     | The updated `losses` dictionary. ~~Dict[str, float]~~                                                                                                                                                 | | ||||
| 
 | ||||
| ## Language.distill {id="distill",tag="method,experimental",version="4"} | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user