Always redo listener state on pipeline modification

* Modify `Language._link_components` to reset listener map and re-add
  all components from scratch.
* Run `Language._link_components` when pipes are added or removed.
* Fix replace listeners for sourced components:
  * Make sure that the source pipeline has the listener state
    corresponding to the source pipeline and not the new pipeline when
    listeners are replaced.
  * Remove removal of unused listeners (this is now always updated by
    `_link_components` at the point where pipes are added).
  * Remove incorrect `replace listeners` after the pipeline is created.
      * For components where `replace_listeners` was specified, the
        listeners have already been replaced when the component was
        sourced+added.
      * This incorrectly ran `replace_listeners` twice for components
        that had `replace_listeners` AND the listened-to component was
        also sourced into the pipeline (but at this point the
        listened-to component is irrelevant for those components).
This commit is contained in:
Adriane Boyd 2023-04-04 17:59:59 +02:00
parent de32011e4c
commit 219410facb
2 changed files with 24 additions and 23 deletions

View File

@ -793,6 +793,7 @@ class Language:
pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name)
self._components.insert(pipe_index, (name, pipe_component))
self._link_components()
return pipe_component
def _get_pipe_index(
@ -951,6 +952,7 @@ class Language:
# Make sure the name is also removed from the set of disabled components
if name in self.disabled:
self._disabled.remove(name)
self._link_components()
return removed
def disable_pipe(self, name: str) -> None:
@ -1675,6 +1677,7 @@ class Language:
# here :(
for i, (name1, proc1) in enumerate(self.pipeline):
if isinstance(proc1, ty.ListenedToComponent):
proc1.listener_map = {}
for name2, proc2 in self.pipeline[i + 1 :]:
proc1.find_listeners(proc2)
@ -1827,6 +1830,12 @@ class Language:
source_name = pipe_cfg.get("component", pipe_name)
listeners_replaced = False
if "replace_listeners" in pipe_cfg:
# HACK: Reset any listened-to components to the listener
# state of the source pipeline for the purpose of
# replacing listeners. The add_pipe below will set
# the state back to the listener state for the new
# pipeline.
source_nlps[model]._link_components()
for name, proc in source_nlps[model].pipeline:
if source_name in getattr(proc, "listening_components", []):
source_nlps[model].replace_listeners(
@ -1892,27 +1901,6 @@ class Language:
raise ValueError(
Errors.E942.format(name="pipeline_creation", value=type(nlp))
)
# Detect components with listeners that are not frozen consistently
for name, proc in nlp.pipeline:
if isinstance(proc, ty.ListenedToComponent):
# Remove listeners not in the pipeline
listener_names = proc.listening_components
unused_listener_names = [
ll for ll in listener_names if ll not in nlp.pipe_names
]
for listener_name in unused_listener_names:
for listener in proc.listener_map.get(listener_name, []):
proc.remove_listener(listener, listener_name)
for listener_name in proc.listening_components:
# e.g. tok2vec/transformer
# If it's a component sourced from another pipeline, we check if
# the tok2vec listeners should be replaced with standalone tok2vec
# models (e.g. so component can be frozen without its performance
# degrading when other components/tok2vec are updated)
paths = sourced.get(listener_name, {}).get("replace_listeners", [])
if paths:
nlp.replace_listeners(name, listener_name, paths)
return nlp
def replace_listeners(

View File

@ -189,9 +189,8 @@ def test_tok2vec_listener(with_vectors):
tagger.add_label(tag)
# Check that the Tok2Vec component finds it listeners
assert tok2vec.listeners == []
optimizer = nlp.initialize(lambda: train_examples)
assert tok2vec.listeners == [tagger_tok2vec]
optimizer = nlp.initialize(lambda: train_examples)
for i in range(5):
losses = {}
@ -540,3 +539,17 @@ def test_tok2vec_listeners_textcat():
assert cats1["imperative"] < 0.9
assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
def test_tok2vec_listener_source_link():
orig_config = Config().from_str(cfg_string_multi)
nlp1 = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
assert list(nlp1.get_pipe("tok2vec").listener_map.keys()) == ["tagger", "ner"]
nlp2 = English()
nlp2.add_pipe("tok2vec", source=nlp1)
assert nlp2.get_pipe("tok2vec").listener_map == {}
nlp2.add_pipe("tagger", source=nlp1)
assert list(nlp2.get_pipe("tok2vec").listener_map.keys()) == ["tagger"]
nlp2.add_pipe("ner", source=nlp1)
assert list(nlp2.get_pipe("tok2vec").listener_map.keys()) == ["tagger", "ner"]