mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-21 09:31:59 +03:00
Language.replace_listeners
: Pass the replaced listener and the tok2vec
pipe to the callback
This commit is contained in:
parent
57a230c6e4
commit
6711143f5b
|
@ -1,4 +1,5 @@
|
|||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
|
@ -30,6 +31,7 @@ from typing import (
|
|||
)
|
||||
|
||||
import srsly
|
||||
|
||||
from thinc.api import Config, CupyOps, Optimizer, get_current_ops
|
||||
|
||||
from . import about, ty, util
|
||||
|
@ -2033,8 +2035,18 @@ class Language:
|
|||
# Go over the listener layers and replace them
|
||||
for listener in pipe_listeners:
|
||||
new_model = tok2vec_model.copy()
|
||||
if "replace_listener" in tok2vec_model.attrs:
|
||||
new_model = tok2vec_model.attrs["replace_listener"](new_model)
|
||||
replace_listener_func = tok2vec_model.attrs.get("replace_listener")
|
||||
if replace_listener_func is not None:
|
||||
# Pass the extra args to the callback without breaking compatibility with
|
||||
# old library versions that only expect a single parameter.
|
||||
num_params = len(
|
||||
inspect.signature(replace_listener_func).parameters
|
||||
)
|
||||
assert num_params in (1, 3)
|
||||
if num_params == 1:
|
||||
new_model = replace_listener_func(new_model)
|
||||
else:
|
||||
new_model = replace_listener_func(new_model, listener, tok2vec)
|
||||
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
|
||||
tok2vec.remove_listener(listener, pipe_name)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user