From fb70f8813ddf2ce71eed9cf763d2aadfb3ad2620 Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Wed, 31 May 2023 14:21:18 +0200
Subject: [PATCH] link components in add and remove pipe

---
 spacy/language.py                    |  5 ++++-
 spacy/tests/pipeline/test_tok2vec.py | 18 ++++++++++++++++--
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/spacy/language.py b/spacy/language.py
index 289e6dd2c..1a45de7ec 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -793,6 +793,7 @@ class Language:
         pipe_index = self._get_pipe_index(before, after, first, last)
         self._pipe_meta[name] = self.get_factory_meta(factory_name)
         self._components.insert(pipe_index, (name, pipe_component))
+        self._link_components()
         return pipe_component
 
     def _get_pipe_index(
@@ -951,6 +952,7 @@ class Language:
         # Make sure the name is also removed from the set of disabled components
         if name in self.disabled:
             self._disabled.remove(name)
+        self._link_components()
         return removed
 
     def disable_pipe(self, name: str) -> None:
@@ -1310,7 +1312,6 @@ class Language:
         if pretrain_cfg:
             P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain)
             init_tok2vec(self, P, I)
-        self._link_components()
         self._optimizer = sgd
         if sgd is not None:
             self._optimizer = sgd
@@ -1678,6 +1679,7 @@ class Language:
         # here :(
         for i, (name1, proc1) in enumerate(self.pipeline):
             if isinstance(proc1, ty.ListenedToComponent):
+                proc1.listener_map = {}
                 for name2, proc2 in self.pipeline[i + 1 :]:
                     proc1.find_listeners(proc2)
 
@@ -1811,6 +1813,7 @@ class Language:
                         raw_config=raw_config,
                     )
                 else:
+                    assert "source" in pipe_cfg
                     # We need the sourced components to reference the same
                     # vocab without modifying the current vocab state **AND**
                     # we still want to load the source model vectors to perform
diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py
index e423d9a19..19046fc50 100644
--- a/spacy/tests/pipeline/test_tok2vec.py
+++ b/spacy/tests/pipeline/test_tok2vec.py
@@ -189,10 +189,10 @@ def test_tok2vec_listener(with_vectors):
         tagger.add_label(tag)
 
     # Check that the Tok2Vec component finds it listeners
-    assert tok2vec.listeners == []
-    optimizer = nlp.initialize(lambda: train_examples)
     assert tok2vec.listeners == [tagger_tok2vec]
 
+    # Initialize and train
+    optimizer = nlp.initialize(lambda: train_examples)
     for i in range(5):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
@@ -540,3 +540,17 @@ def test_tok2vec_listeners_textcat():
     assert cats1["imperative"] < 0.9
     assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
     assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
+
+
+def test_tok2vec_listener_source_link():
+    orig_config = Config().from_str(cfg_string_multi)
+    nlp1 = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
+    assert list(nlp1.get_pipe("tok2vec").listener_map.keys()) == ["tagger", "ner"]
+
+    nlp2 = English()
+    nlp2.add_pipe("tok2vec", source=nlp1)
+    assert nlp2.get_pipe("tok2vec").listener_map == {}
+    nlp2.add_pipe("tagger", source=nlp1)
+    assert list(nlp2.get_pipe("tok2vec").listener_map.keys()) == ["tagger"]
+    nlp2.add_pipe("ner", source=nlp1)
+    assert list(nlp2.get_pipe("tok2vec").listener_map.keys()) == ["tagger", "ner"]
\ No newline at end of file