mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Language.replace_listeners
: Pass the replaced listener and the tok2vec
pipe to the callback (#12785)
* `Language.replace_listeners`: Pass the replaced listener and the `tok2vec` pipe to the callback * Update developer docs * `isort` fixes * Add error message to assertion * Add clarification to dev docs * Replace assertion with exception * Doc fixes
This commit is contained in:
parent
6f3a71999e
commit
8113cfb257
|
@ -1,14 +1,17 @@
|
|||
# Listeners
|
||||
|
||||
1. [Overview](#1-overview)
|
||||
2. [Initialization](#2-initialization)
|
||||
- [A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
|
||||
- [B. Shape inference](#2b-shape-inference)
|
||||
3. [Internal communication](#3-internal-communication)
|
||||
- [A. During prediction](#3a-during-prediction)
|
||||
- [B. During training](#3b-during-training)
|
||||
- [C. Frozen components](#3c-frozen-components)
|
||||
4. [Replacing listener with standalone](#4-replacing-listener-with-standalone)
|
||||
- [1. Overview](#1-overview)
|
||||
- [2. Initialization](#2-initialization)
|
||||
- [2A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
|
||||
- [2B. Shape inference](#2b-shape-inference)
|
||||
- [3. Internal communication](#3-internal-communication)
|
||||
- [3A. During prediction](#3a-during-prediction)
|
||||
- [3B. During training](#3b-during-training)
|
||||
- [Training with multiple listeners](#training-with-multiple-listeners)
|
||||
- [3C. Frozen components](#3c-frozen-components)
|
||||
- [The Tok2Vec or Transformer is frozen](#the-tok2vec-or-transformer-is-frozen)
|
||||
- [The upstream component is frozen](#the-upstream-component-is-frozen)
|
||||
- [4. Replacing listener with standalone](#4-replacing-listener-with-standalone)
|
||||
|
||||
## 1. Overview
|
||||
|
||||
|
@ -218,3 +221,15 @@ new_model = tok2vec_model.attrs["replace_listener"](new_model)
|
|||
The new config and model are then properly stored on the `nlp` object.
|
||||
Note that this functionality (running the replacement for a transformer listener) was broken prior to
|
||||
`spacy-transformers` 1.0.5.
|
||||
|
||||
In spaCy 3.7, `Language.replace_listeners` was updated to pass the following additional arguments to the `replace_listener` callback:
|
||||
the listener to be replaced and the `tok2vec`/`transformer` pipe from which the new model was copied. To maintain backwards-compatiblity,
|
||||
the method only passes these extra arguments for callbacks that support them:
|
||||
|
||||
```
|
||||
def replace_listener_pre_37(copied_tok2vec_model):
|
||||
...
|
||||
|
||||
def replace_listener_post_37(copied_tok2vec_model, replaced_listener, tok2vec_pipe):
|
||||
...
|
||||
```
|
||||
|
|
|
@ -981,6 +981,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
" 'min_length': {min_length}, 'max_length': {max_length}")
|
||||
E1054 = ("The text, including whitespace, must match between reference and "
|
||||
"predicted docs when training {component}.")
|
||||
E1055 = ("The 'replace_listener' callback expects {num_params} parameters, "
|
||||
"but only callbacks with one or three parameters are supported")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
|
@ -2033,8 +2034,20 @@ class Language:
|
|||
# Go over the listener layers and replace them
|
||||
for listener in pipe_listeners:
|
||||
new_model = tok2vec_model.copy()
|
||||
if "replace_listener" in tok2vec_model.attrs:
|
||||
new_model = tok2vec_model.attrs["replace_listener"](new_model)
|
||||
replace_listener_func = tok2vec_model.attrs.get("replace_listener")
|
||||
if replace_listener_func is not None:
|
||||
# Pass the extra args to the callback without breaking compatibility with
|
||||
# old library versions that only expect a single parameter.
|
||||
num_params = len(
|
||||
inspect.signature(replace_listener_func).parameters
|
||||
)
|
||||
if num_params == 1:
|
||||
new_model = replace_listener_func(new_model)
|
||||
elif num_params == 3:
|
||||
new_model = replace_listener_func(new_model, listener, tok2vec)
|
||||
else:
|
||||
raise ValueError(Errors.E1055.format(num_params=num_params))
|
||||
|
||||
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
|
||||
tok2vec.remove_listener(listener, pipe_name)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user