Language.replace_listeners: Pass the replaced listener and the tok2vec pipe to the callback

This commit is contained in:
shadeMe 2023-07-04 14:16:45 +02:00
parent 57a230c6e4
commit 6711143f5b
No known key found for this signature in database
GPG Key ID: 6FCA9FC635B2A402

View File

@ -1,4 +1,5 @@
import functools
import inspect
import itertools
import multiprocessing as mp
import random
@ -30,6 +31,7 @@ from typing import (
)
import srsly
from thinc.api import Config, CupyOps, Optimizer, get_current_ops
from . import about, ty, util
@ -2033,8 +2035,18 @@ class Language:
# Go over the listener layers and replace them
for listener in pipe_listeners:
new_model = tok2vec_model.copy()
if "replace_listener" in tok2vec_model.attrs:
new_model = tok2vec_model.attrs["replace_listener"](new_model)
replace_listener_func = tok2vec_model.attrs.get("replace_listener")
if replace_listener_func is not None:
# Pass the extra args to the callback without breaking compatibility with
# old library versions that only expect a single parameter.
num_params = len(
inspect.signature(replace_listener_func).parameters
)
assert num_params in (1, 3)
if num_params == 1:
new_model = replace_listener_func(new_model)
else:
new_model = replace_listener_func(new_model, listener, tok2vec)
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
tok2vec.remove_listener(listener, pipe_name)