Language.replace_listeners: Pass the replaced listener and the tok2vec pipe to the callback (#12785)

* `Language.replace_listeners`: Pass the replaced listener and the `tok2vec` pipe to the callback

* Update developer docs

* `isort` fixes

* Add error message to assertion

* Add clarification to dev docs

* Replace assertion with exception

* Doc fixes
This commit is contained in:
Madeesh Kannan 2023-07-05 13:36:04 +02:00 committed by GitHub
parent 6f3a71999e
commit 8113cfb257
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 14 deletions

View File

@ -1,14 +1,17 @@
# Listeners
1. [Overview](#1-overview)
2. [Initialization](#2-initialization)
- [A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
- [B. Shape inference](#2b-shape-inference)
3. [Internal communication](#3-internal-communication)
- [A. During prediction](#3a-during-prediction)
- [B. During training](#3b-during-training)
- [C. Frozen components](#3c-frozen-components)
4. [Replacing listener with standalone](#4-replacing-listener-with-standalone)
- [1. Overview](#1-overview)
- [2. Initialization](#2-initialization)
- [2A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
- [2B. Shape inference](#2b-shape-inference)
- [3. Internal communication](#3-internal-communication)
- [3A. During prediction](#3a-during-prediction)
- [3B. During training](#3b-during-training)
- [Training with multiple listeners](#training-with-multiple-listeners)
- [3C. Frozen components](#3c-frozen-components)
- [The Tok2Vec or Transformer is frozen](#the-tok2vec-or-transformer-is-frozen)
- [The upstream component is frozen](#the-upstream-component-is-frozen)
- [4. Replacing listener with standalone](#4-replacing-listener-with-standalone)
## 1. Overview
@ -218,3 +221,15 @@ new_model = tok2vec_model.attrs["replace_listener"](new_model)
The new config and model are then properly stored on the `nlp` object.
Note that this functionality (running the replacement for a transformer listener) was broken prior to
`spacy-transformers` 1.0.5.
In spaCy 3.7, `Language.replace_listeners` was updated to pass the following additional arguments to the `replace_listener` callback:
the listener to be replaced and the `tok2vec`/`transformer` pipe from which the new model was copied. To maintain backwards-compatiblity,
the method only passes these extra arguments for callbacks that support them:
```
def replace_listener_pre_37(copied_tok2vec_model):
...
def replace_listener_post_37(copied_tok2vec_model, replaced_listener, tok2vec_pipe):
...
```

View File

@ -981,6 +981,8 @@ class Errors(metaclass=ErrorsWithCodes):
" 'min_length': {min_length}, 'max_length': {max_length}")
E1054 = ("The text, including whitespace, must match between reference and "
"predicted docs when training {component}.")
E1055 = ("The 'replace_listener' callback expects {num_params} parameters, "
"but only callbacks with one or three parameters are supported")
# Deprecated model shortcuts, only used in errors and warnings

View File

@ -1,4 +1,5 @@
import functools
import inspect
import itertools
import multiprocessing as mp
import random
@ -2033,8 +2034,20 @@ class Language:
# Go over the listener layers and replace them
for listener in pipe_listeners:
new_model = tok2vec_model.copy()
if "replace_listener" in tok2vec_model.attrs:
new_model = tok2vec_model.attrs["replace_listener"](new_model)
replace_listener_func = tok2vec_model.attrs.get("replace_listener")
if replace_listener_func is not None:
# Pass the extra args to the callback without breaking compatibility with
# old library versions that only expect a single parameter.
num_params = len(
inspect.signature(replace_listener_func).parameters
)
if num_params == 1:
new_model = replace_listener_func(new_model)
elif num_params == 3:
new_model = replace_listener_func(new_model, listener, tok2vec)
else:
raise ValueError(Errors.E1055.format(num_params=num_params))
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
tok2vec.remove_listener(listener, pipe_name)