mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Language.replace_listeners
: Pass the replaced listener and the tok2vec
pipe to the callback (#12785)
* `Language.replace_listeners`: Pass the replaced listener and the `tok2vec` pipe to the callback * Update developer docs * `isort` fixes * Add error message to assertion * Add clarification to dev docs * Replace assertion with exception * Doc fixes
This commit is contained in:
parent
6f3a71999e
commit
8113cfb257
|
@ -1,14 +1,17 @@
|
||||||
# Listeners
|
# Listeners
|
||||||
|
|
||||||
1. [Overview](#1-overview)
|
- [1. Overview](#1-overview)
|
||||||
2. [Initialization](#2-initialization)
|
- [2. Initialization](#2-initialization)
|
||||||
- [A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
|
- [2A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
|
||||||
- [B. Shape inference](#2b-shape-inference)
|
- [2B. Shape inference](#2b-shape-inference)
|
||||||
3. [Internal communication](#3-internal-communication)
|
- [3. Internal communication](#3-internal-communication)
|
||||||
- [A. During prediction](#3a-during-prediction)
|
- [3A. During prediction](#3a-during-prediction)
|
||||||
- [B. During training](#3b-during-training)
|
- [3B. During training](#3b-during-training)
|
||||||
- [C. Frozen components](#3c-frozen-components)
|
- [Training with multiple listeners](#training-with-multiple-listeners)
|
||||||
4. [Replacing listener with standalone](#4-replacing-listener-with-standalone)
|
- [3C. Frozen components](#3c-frozen-components)
|
||||||
|
- [The Tok2Vec or Transformer is frozen](#the-tok2vec-or-transformer-is-frozen)
|
||||||
|
- [The upstream component is frozen](#the-upstream-component-is-frozen)
|
||||||
|
- [4. Replacing listener with standalone](#4-replacing-listener-with-standalone)
|
||||||
|
|
||||||
## 1. Overview
|
## 1. Overview
|
||||||
|
|
||||||
|
@ -62,7 +65,7 @@ of this `find_listener()` method will specifically identify sublayers of a model
|
||||||
|
|
||||||
If it's a Transformer-based pipeline, a
|
If it's a Transformer-based pipeline, a
|
||||||
[`transformer` component](https://github.com/explosion/spacy-transformers/blob/master/spacy_transformers/pipeline_component.py)
|
[`transformer` component](https://github.com/explosion/spacy-transformers/blob/master/spacy_transformers/pipeline_component.py)
|
||||||
has a similar implementation but its `find_listener()` function will specifically look for `TransformerListener`
|
has a similar implementation but its `find_listener()` function will specifically look for `TransformerListener`
|
||||||
sublayers of downstream components.
|
sublayers of downstream components.
|
||||||
|
|
||||||
### 2B. Shape inference
|
### 2B. Shape inference
|
||||||
|
@ -154,7 +157,7 @@ as a tagger or a parser. This used to be impossible before 3.1, but has become s
|
||||||
embedding component in the [`annotating_components`](https://spacy.io/usage/training#annotating-components)
|
embedding component in the [`annotating_components`](https://spacy.io/usage/training#annotating-components)
|
||||||
list of the config. This works like any other "annotating component" because it relies on the `Doc` attributes.
|
list of the config. This works like any other "annotating component" because it relies on the `Doc` attributes.
|
||||||
|
|
||||||
However, if the `Tok2Vec` or `Transformer` is frozen, and not present in `annotating_components`, and a related
|
However, if the `Tok2Vec` or `Transformer` is frozen, and not present in `annotating_components`, and a related
|
||||||
listener isn't frozen, then a `W086` warning is shown and further training of the pipeline will likely end with `E954`.
|
listener isn't frozen, then a `W086` warning is shown and further training of the pipeline will likely end with `E954`.
|
||||||
|
|
||||||
#### The upstream component is frozen
|
#### The upstream component is frozen
|
||||||
|
@ -216,5 +219,17 @@ new_model = tok2vec_model.attrs["replace_listener"](new_model)
|
||||||
```
|
```
|
||||||
|
|
||||||
The new config and model are then properly stored on the `nlp` object.
|
The new config and model are then properly stored on the `nlp` object.
|
||||||
Note that this functionality (running the replacement for a transformer listener) was broken prior to
|
Note that this functionality (running the replacement for a transformer listener) was broken prior to
|
||||||
`spacy-transformers` 1.0.5.
|
`spacy-transformers` 1.0.5.
|
||||||
|
|
||||||
|
In spaCy 3.7, `Language.replace_listeners` was updated to pass the following additional arguments to the `replace_listener` callback:
|
||||||
|
the listener to be replaced and the `tok2vec`/`transformer` pipe from which the new model was copied. To maintain backwards-compatiblity,
|
||||||
|
the method only passes these extra arguments for callbacks that support them:
|
||||||
|
|
||||||
|
```
|
||||||
|
def replace_listener_pre_37(copied_tok2vec_model):
|
||||||
|
...
|
||||||
|
|
||||||
|
def replace_listener_post_37(copied_tok2vec_model, replaced_listener, tok2vec_pipe):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
|
@ -981,6 +981,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
" 'min_length': {min_length}, 'max_length': {max_length}")
|
" 'min_length': {min_length}, 'max_length': {max_length}")
|
||||||
E1054 = ("The text, including whitespace, must match between reference and "
|
E1054 = ("The text, including whitespace, must match between reference and "
|
||||||
"predicted docs when training {component}.")
|
"predicted docs when training {component}.")
|
||||||
|
E1055 = ("The 'replace_listener' callback expects {num_params} parameters, "
|
||||||
|
"but only callbacks with one or three parameters are supported")
|
||||||
|
|
||||||
|
|
||||||
# Deprecated model shortcuts, only used in errors and warnings
|
# Deprecated model shortcuts, only used in errors and warnings
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import random
|
import random
|
||||||
|
@ -2033,8 +2034,20 @@ class Language:
|
||||||
# Go over the listener layers and replace them
|
# Go over the listener layers and replace them
|
||||||
for listener in pipe_listeners:
|
for listener in pipe_listeners:
|
||||||
new_model = tok2vec_model.copy()
|
new_model = tok2vec_model.copy()
|
||||||
if "replace_listener" in tok2vec_model.attrs:
|
replace_listener_func = tok2vec_model.attrs.get("replace_listener")
|
||||||
new_model = tok2vec_model.attrs["replace_listener"](new_model)
|
if replace_listener_func is not None:
|
||||||
|
# Pass the extra args to the callback without breaking compatibility with
|
||||||
|
# old library versions that only expect a single parameter.
|
||||||
|
num_params = len(
|
||||||
|
inspect.signature(replace_listener_func).parameters
|
||||||
|
)
|
||||||
|
if num_params == 1:
|
||||||
|
new_model = replace_listener_func(new_model)
|
||||||
|
elif num_params == 3:
|
||||||
|
new_model = replace_listener_func(new_model, listener, tok2vec)
|
||||||
|
else:
|
||||||
|
raise ValueError(Errors.E1055.format(num_params=num_params))
|
||||||
|
|
||||||
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
|
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
|
||||||
tok2vec.remove_listener(listener, pipe_name)
|
tok2vec.remove_listener(listener, pipe_name)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user