From 219410facb31ccacc3109060407668762e62c492 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 4 Apr 2023 17:59:59 +0200 Subject: [PATCH] Always redo listener state on pipeline modification * Modify `Language._link_components` to reset listener map and re-add all components from scratch. * Run `Language._link_components` when pipes are added or removed. * Fix replace listeners for sourced components: * Make sure that the source pipeline has the listener state corresponding to the source pipeline and not the new pipeline when listeners are replaced. * Remove removal of unused listeners (this is now always updated by `_link_components` at the point where pipes are added). * Remove incorrect `replace listeners` after the pipeline is created. * For components where `replace_listeners` was specified, the listeners have already been replaced when the component was sourced+added. * This incorrectly ran `replace_listeners` twice for components that had `replace_listeners` AND the listened-to component was also sourced into the pipeline (but at this point the listened-to component is irrelevant for those components). --- spacy/language.py | 30 +++++++++------------------- spacy/tests/pipeline/test_tok2vec.py | 17 ++++++++++++++-- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 9fdcf6328..9cc35cac0 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -793,6 +793,7 @@ class Language: pipe_index = self._get_pipe_index(before, after, first, last) self._pipe_meta[name] = self.get_factory_meta(factory_name) self._components.insert(pipe_index, (name, pipe_component)) + self._link_components() return pipe_component def _get_pipe_index( @@ -951,6 +952,7 @@ class Language: # Make sure the name is also removed from the set of disabled components if name in self.disabled: self._disabled.remove(name) + self._link_components() return removed def disable_pipe(self, name: str) -> None: @@ -1675,6 +1677,7 @@ class Language: # here :( for i, (name1, proc1) in enumerate(self.pipeline): if isinstance(proc1, ty.ListenedToComponent): + proc1.listener_map = {} for name2, proc2 in self.pipeline[i + 1 :]: proc1.find_listeners(proc2) @@ -1827,6 +1830,12 @@ class Language: source_name = pipe_cfg.get("component", pipe_name) listeners_replaced = False if "replace_listeners" in pipe_cfg: + # HACK: Reset any listened-to components to the listener + # state of the source pipeline for the purpose of + # replacing listeners. The add_pipe below will set + # the state back to the listener state for the new + # pipeline. + source_nlps[model]._link_components() for name, proc in source_nlps[model].pipeline: if source_name in getattr(proc, "listening_components", []): source_nlps[model].replace_listeners( @@ -1892,27 +1901,6 @@ class Language: raise ValueError( Errors.E942.format(name="pipeline_creation", value=type(nlp)) ) - # Detect components with listeners that are not frozen consistently - for name, proc in nlp.pipeline: - if isinstance(proc, ty.ListenedToComponent): - # Remove listeners not in the pipeline - listener_names = proc.listening_components - unused_listener_names = [ - ll for ll in listener_names if ll not in nlp.pipe_names - ] - for listener_name in unused_listener_names: - for listener in proc.listener_map.get(listener_name, []): - proc.remove_listener(listener, listener_name) - - for listener_name in proc.listening_components: - # e.g. tok2vec/transformer - # If it's a component sourced from another pipeline, we check if - # the tok2vec listeners should be replaced with standalone tok2vec - # models (e.g. so component can be frozen without its performance - # degrading when other components/tok2vec are updated) - paths = sourced.get(listener_name, {}).get("replace_listeners", []) - if paths: - nlp.replace_listeners(name, listener_name, paths) return nlp def replace_listeners( diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py index e423d9a19..d6f5459c4 100644 --- a/spacy/tests/pipeline/test_tok2vec.py +++ b/spacy/tests/pipeline/test_tok2vec.py @@ -189,9 +189,8 @@ def test_tok2vec_listener(with_vectors): tagger.add_label(tag) # Check that the Tok2Vec component finds it listeners - assert tok2vec.listeners == [] - optimizer = nlp.initialize(lambda: train_examples) assert tok2vec.listeners == [tagger_tok2vec] + optimizer = nlp.initialize(lambda: train_examples) for i in range(5): losses = {} @@ -540,3 +539,17 @@ def test_tok2vec_listeners_textcat(): assert cats1["imperative"] < 0.9 assert [t.tag_ for t in docs[0]] == ["V", "J", "N"] assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"] + + +def test_tok2vec_listener_source_link(): + orig_config = Config().from_str(cfg_string_multi) + nlp1 = util.load_model_from_config(orig_config, auto_fill=True, validate=True) + assert list(nlp1.get_pipe("tok2vec").listener_map.keys()) == ["tagger", "ner"] + + nlp2 = English() + nlp2.add_pipe("tok2vec", source=nlp1) + assert nlp2.get_pipe("tok2vec").listener_map == {} + nlp2.add_pipe("tagger", source=nlp1) + assert list(nlp2.get_pipe("tok2vec").listener_map.keys()) == ["tagger"] + nlp2.add_pipe("ner", source=nlp1) + assert list(nlp2.get_pipe("tok2vec").listener_map.keys()) == ["tagger", "ner"]