mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Language.replace_listeners: Pass the replaced listener and the tok2vec pipe to the callback (#12785)
				
					
				
			* `Language.replace_listeners`: Pass the replaced listener and the `tok2vec` pipe to the callback * Update developer docs * `isort` fixes * Add error message to assertion * Add clarification to dev docs * Replace assertion with exception * Doc fixes
This commit is contained in:
		
							parent
							
								
									6f3a71999e
								
							
						
					
					
						commit
						8113cfb257
					
				|  | @ -1,14 +1,17 @@ | ||||||
| # Listeners | # Listeners | ||||||
| 
 | 
 | ||||||
| 1. [Overview](#1-overview) | - [1. Overview](#1-overview) | ||||||
| 2. [Initialization](#2-initialization) | - [2. Initialization](#2-initialization) | ||||||
|    - [A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component) |   - [2A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component) | ||||||
|    - [B. Shape inference](#2b-shape-inference) |   - [2B. Shape inference](#2b-shape-inference) | ||||||
| 3. [Internal communication](#3-internal-communication) | - [3. Internal communication](#3-internal-communication) | ||||||
|    - [A. During prediction](#3a-during-prediction) |   - [3A. During prediction](#3a-during-prediction) | ||||||
|    - [B. During training](#3b-during-training) |   - [3B. During training](#3b-during-training) | ||||||
|    - [C. Frozen components](#3c-frozen-components) |     - [Training with multiple listeners](#training-with-multiple-listeners) | ||||||
| 4. [Replacing listener with standalone](#4-replacing-listener-with-standalone) |   - [3C. Frozen components](#3c-frozen-components) | ||||||
|  |     - [The Tok2Vec or Transformer is frozen](#the-tok2vec-or-transformer-is-frozen) | ||||||
|  |     - [The upstream component is frozen](#the-upstream-component-is-frozen) | ||||||
|  | - [4. Replacing listener with standalone](#4-replacing-listener-with-standalone) | ||||||
| 
 | 
 | ||||||
| ## 1. Overview | ## 1. Overview | ||||||
| 
 | 
 | ||||||
|  | @ -62,7 +65,7 @@ of this `find_listener()` method will specifically identify sublayers of a model | ||||||
| 
 | 
 | ||||||
| If it's a Transformer-based pipeline, a | If it's a Transformer-based pipeline, a | ||||||
| [`transformer` component](https://github.com/explosion/spacy-transformers/blob/master/spacy_transformers/pipeline_component.py) | [`transformer` component](https://github.com/explosion/spacy-transformers/blob/master/spacy_transformers/pipeline_component.py) | ||||||
| has a similar implementation but its `find_listener()` function will specifically look for `TransformerListener`  | has a similar implementation but its `find_listener()` function will specifically look for `TransformerListener` | ||||||
| sublayers of downstream components. | sublayers of downstream components. | ||||||
| 
 | 
 | ||||||
| ### 2B. Shape inference | ### 2B. Shape inference | ||||||
|  | @ -154,7 +157,7 @@ as a tagger or a parser. This used to be impossible before 3.1, but has become s | ||||||
| embedding component in the [`annotating_components`](https://spacy.io/usage/training#annotating-components) | embedding component in the [`annotating_components`](https://spacy.io/usage/training#annotating-components) | ||||||
| list of the config. This works like any other "annotating component" because it relies on the `Doc` attributes. | list of the config. This works like any other "annotating component" because it relies on the `Doc` attributes. | ||||||
| 
 | 
 | ||||||
| However, if the `Tok2Vec` or `Transformer` is frozen, and not present in `annotating_components`, and a related  | However, if the `Tok2Vec` or `Transformer` is frozen, and not present in `annotating_components`, and a related | ||||||
| listener isn't frozen, then a `W086` warning is shown and further training of the pipeline will likely end with `E954`. | listener isn't frozen, then a `W086` warning is shown and further training of the pipeline will likely end with `E954`. | ||||||
| 
 | 
 | ||||||
| #### The upstream component is frozen | #### The upstream component is frozen | ||||||
|  | @ -216,5 +219,17 @@ new_model = tok2vec_model.attrs["replace_listener"](new_model) | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| The new config and model are then properly stored on the `nlp` object. | The new config and model are then properly stored on the `nlp` object. | ||||||
| Note that this functionality (running the replacement for a transformer listener) was broken prior to  | Note that this functionality (running the replacement for a transformer listener) was broken prior to | ||||||
| `spacy-transformers` 1.0.5. | `spacy-transformers` 1.0.5. | ||||||
|  | 
 | ||||||
|  | In spaCy 3.7, `Language.replace_listeners` was updated to pass the following additional arguments to the `replace_listener` callback: | ||||||
|  | the listener to be replaced and the `tok2vec`/`transformer` pipe from which the new model was copied. To maintain backwards-compatibility, | ||||||
|  | the method only passes these extra arguments for callbacks that support them: | ||||||
|  | 
 | ||||||
|  | ``` | ||||||
|  | def replace_listener_pre_37(copied_tok2vec_model): | ||||||
|  |   ... | ||||||
|  | 
 | ||||||
|  | def replace_listener_post_37(copied_tok2vec_model, replaced_listener, tok2vec_pipe): | ||||||
|  |   ... | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | @ -981,6 +981,8 @@ class Errors(metaclass=ErrorsWithCodes): | ||||||
|              " 'min_length': {min_length}, 'max_length': {max_length}") |              " 'min_length': {min_length}, 'max_length': {max_length}") | ||||||
|     E1054 = ("The text, including whitespace, must match between reference and " |     E1054 = ("The text, including whitespace, must match between reference and " | ||||||
|              "predicted docs when training {component}.") |              "predicted docs when training {component}.") | ||||||
|  |     E1055 = ("The 'replace_listener' callback expects {num_params} parameters, " | ||||||
|  |              "but only callbacks with one or three parameters are supported") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # Deprecated model shortcuts, only used in errors and warnings | # Deprecated model shortcuts, only used in errors and warnings | ||||||
|  |  | ||||||
|  | @ -1,4 +1,5 @@ | ||||||
| import functools | import functools | ||||||
|  | import inspect | ||||||
| import itertools | import itertools | ||||||
| import multiprocessing as mp | import multiprocessing as mp | ||||||
| import random | import random | ||||||
|  | @ -2033,8 +2034,20 @@ class Language: | ||||||
|             # Go over the listener layers and replace them |             # Go over the listener layers and replace them | ||||||
|             for listener in pipe_listeners: |             for listener in pipe_listeners: | ||||||
|                 new_model = tok2vec_model.copy() |                 new_model = tok2vec_model.copy() | ||||||
|                 if "replace_listener" in tok2vec_model.attrs: |                 replace_listener_func = tok2vec_model.attrs.get("replace_listener") | ||||||
|                     new_model = tok2vec_model.attrs["replace_listener"](new_model) |                 if replace_listener_func is not None: | ||||||
|  |                     # Pass the extra args to the callback without breaking compatibility with | ||||||
|  |                     # old library versions that only expect a single parameter. | ||||||
|  |                     num_params = len( | ||||||
|  |                         inspect.signature(replace_listener_func).parameters | ||||||
|  |                     ) | ||||||
|  |                     if num_params == 1: | ||||||
|  |                         new_model = replace_listener_func(new_model) | ||||||
|  |                     elif num_params == 3: | ||||||
|  |                         new_model = replace_listener_func(new_model, listener, tok2vec) | ||||||
|  |                     else: | ||||||
|  |                         raise ValueError(Errors.E1055.format(num_params=num_params)) | ||||||
|  | 
 | ||||||
|                 util.replace_model_node(pipe.model, listener, new_model)  # type: ignore[attr-defined] |                 util.replace_model_node(pipe.model, listener, new_model)  # type: ignore[attr-defined] | ||||||
|                 tok2vec.remove_listener(listener, pipe_name) |                 tok2vec.remove_listener(listener, pipe_name) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user