mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-13 01:32:32 +03:00
Merge pull request #8112 from svlandeg/bugfix/replace-trf
This commit is contained in:
commit
5957ab74f7
|
@ -1769,6 +1769,7 @@ class Language:
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
tok2vec = self.get_pipe(tok2vec_name)
|
tok2vec = self.get_pipe(tok2vec_name)
|
||||||
tok2vec_cfg = self.get_pipe_config(tok2vec_name)
|
tok2vec_cfg = self.get_pipe_config(tok2vec_name)
|
||||||
|
tok2vec_model = tok2vec.model
|
||||||
if (
|
if (
|
||||||
not hasattr(tok2vec, "model")
|
not hasattr(tok2vec, "model")
|
||||||
or not hasattr(tok2vec, "listener_map")
|
or not hasattr(tok2vec, "listener_map")
|
||||||
|
@ -1777,6 +1778,7 @@ class Language:
|
||||||
):
|
):
|
||||||
raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
|
raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
|
||||||
pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
|
pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
|
||||||
|
pipe = self.get_pipe(pipe_name)
|
||||||
pipe_cfg = self._pipe_configs[pipe_name]
|
pipe_cfg = self._pipe_configs[pipe_name]
|
||||||
if listeners:
|
if listeners:
|
||||||
util.logger.debug(f"Replacing listeners of component '{pipe_name}'")
|
util.logger.debug(f"Replacing listeners of component '{pipe_name}'")
|
||||||
|
@ -1791,7 +1793,6 @@ class Language:
|
||||||
n_listeners=len(pipe_listeners),
|
n_listeners=len(pipe_listeners),
|
||||||
)
|
)
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
pipe = self.get_pipe(pipe_name)
|
|
||||||
# Update the config accordingly by copying the tok2vec model to all
|
# Update the config accordingly by copying the tok2vec model to all
|
||||||
# sections defined in the listener paths
|
# sections defined in the listener paths
|
||||||
for listener_path in listeners:
|
for listener_path in listeners:
|
||||||
|
@ -1803,10 +1804,17 @@ class Language:
|
||||||
name=pipe_name, tok2vec=tok2vec_name, path=listener_path
|
name=pipe_name, tok2vec=tok2vec_name, path=listener_path
|
||||||
)
|
)
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"])
|
new_config = tok2vec_cfg["model"]
|
||||||
|
if "replace_listener_cfg" in tok2vec_model.attrs:
|
||||||
|
replace_func = tok2vec_model.attrs["replace_listener_cfg"]
|
||||||
|
new_config = replace_func(tok2vec_cfg["model"], pipe_cfg["model"]["tok2vec"])
|
||||||
|
util.set_dot_to_object(pipe_cfg, listener_path, new_config)
|
||||||
# Go over the listener layers and replace them
|
# Go over the listener layers and replace them
|
||||||
for listener in pipe_listeners:
|
for listener in pipe_listeners:
|
||||||
util.replace_model_node(pipe.model, listener, tok2vec.model.copy())
|
new_model = tok2vec_model.copy()
|
||||||
|
if "replace_listener" in tok2vec_model.attrs:
|
||||||
|
new_model = tok2vec_model.attrs["replace_listener"](new_model)
|
||||||
|
util.replace_model_node(pipe.model, listener, new_model)
|
||||||
tok2vec.remove_listener(listener, pipe_name)
|
tok2vec.remove_listener(listener, pipe_name)
|
||||||
|
|
||||||
def to_disk(
|
def to_disk(
|
||||||
|
|
|
@ -218,6 +218,13 @@ def test_replace_listeners():
|
||||||
nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"])
|
nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"])
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"])
|
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"])
|
||||||
|
# attempt training with the new pipeline
|
||||||
|
optimizer = nlp.initialize(lambda: examples)
|
||||||
|
for i in range(2):
|
||||||
|
losses = {}
|
||||||
|
nlp.update(examples, sgd=optimizer, losses=losses)
|
||||||
|
assert losses["tok2vec"] == 0.0
|
||||||
|
assert losses["tagger"] > 0.0
|
||||||
|
|
||||||
|
|
||||||
cfg_string_multi = """
|
cfg_string_multi = """
|
||||||
|
|
Loading…
Reference in New Issue
Block a user