call replace_listener_cfg attr if it's available

This commit is contained in:
svlandeg 2021-05-12 17:19:38 +02:00
parent 44a3a58599
commit 235e9f5488

View File

@ -1764,6 +1764,7 @@ class Language:
raise ValueError(err)
tok2vec = self.get_pipe(tok2vec_name)
tok2vec_cfg = self.get_pipe_config(tok2vec_name)
tok2vec_model = tok2vec.model
if (
not hasattr(tok2vec, "model")
or not hasattr(tok2vec, "listener_map")
@ -1772,6 +1773,7 @@ class Language:
):
raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
pipe = self.get_pipe(pipe_name)
pipe_cfg = self._pipe_configs[pipe_name]
if listeners:
util.logger.debug(f"Replacing listeners of component '{pipe_name}'")
@ -1786,7 +1788,6 @@ class Language:
n_listeners=len(pipe_listeners),
)
raise ValueError(err)
pipe = self.get_pipe(pipe_name)
# Update the config accordingly by copying the tok2vec model to all
# sections defined in the listener paths
for listener_path in listeners:
@ -1798,12 +1799,16 @@ class Language:
name=pipe_name, tok2vec=tok2vec_name, path=listener_path
)
raise ValueError(err)
util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"])
new_config = tok2vec_cfg["model"]
if "replace_listener_cfg" in tok2vec_model.attrs:
replace_func = tok2vec_model.attrs["replace_listener_cfg"]
new_config = replace_func(tok2vec_cfg["model"], pipe_cfg["model"]["tok2vec"])
util.set_dot_to_object(pipe_cfg, listener_path, new_config)
# Go over the listener layers and replace them
for listener in pipe_listeners:
new_model = tok2vec.model.copy()
if "replace_listener" in new_model.attrs:
new_model = new_model.attrs["replace_listener"](new_model)
new_model = tok2vec_model.copy()
if "replace_listener" in tok2vec_model.attrs:
new_model = tok2vec_model.attrs["replace_listener"](new_model)
util.replace_model_node(pipe.model, listener, new_model)
tok2vec.remove_listener(listener, pipe_name)