mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-11 08:42:28 +03:00
Language.replace_listeners
: Pass the replaced listener and the tok2vec
pipe to the callback
This commit is contained in:
parent
57a230c6e4
commit
6711143f5b
|
@ -1,4 +1,5 @@
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import random
|
import random
|
||||||
|
@ -30,6 +31,7 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import srsly
|
import srsly
|
||||||
|
|
||||||
from thinc.api import Config, CupyOps, Optimizer, get_current_ops
|
from thinc.api import Config, CupyOps, Optimizer, get_current_ops
|
||||||
|
|
||||||
from . import about, ty, util
|
from . import about, ty, util
|
||||||
|
@ -2033,8 +2035,18 @@ class Language:
|
||||||
# Go over the listener layers and replace them
|
# Go over the listener layers and replace them
|
||||||
for listener in pipe_listeners:
|
for listener in pipe_listeners:
|
||||||
new_model = tok2vec_model.copy()
|
new_model = tok2vec_model.copy()
|
||||||
if "replace_listener" in tok2vec_model.attrs:
|
replace_listener_func = tok2vec_model.attrs.get("replace_listener")
|
||||||
new_model = tok2vec_model.attrs["replace_listener"](new_model)
|
if replace_listener_func is not None:
|
||||||
|
# Pass the extra args to the callback without breaking compatibility with
|
||||||
|
# old library versions that only expect a single parameter.
|
||||||
|
num_params = len(
|
||||||
|
inspect.signature(replace_listener_func).parameters
|
||||||
|
)
|
||||||
|
assert num_params in (1, 3)
|
||||||
|
if num_params == 1:
|
||||||
|
new_model = replace_listener_func(new_model)
|
||||||
|
else:
|
||||||
|
new_model = replace_listener_func(new_model, listener, tok2vec)
|
||||||
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
|
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
|
||||||
tok2vec.remove_listener(listener, pipe_name)
|
tok2vec.remove_listener(listener, pipe_name)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user