From 8f018e47f84264ca852c67578af1ab95cbd74be3 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Sun, 4 Oct 2020 14:43:45 +0200 Subject: [PATCH] Adjust [initialize.components] on Language.remove_pipe and Language.rename_pipe --- spacy/language.py | 7 +++++++ spacy/tests/pipeline/test_pipe_methods.py | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/spacy/language.py b/spacy/language.py index d76741da3..9fdde03d5 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -896,6 +896,10 @@ class Language: self._components[i] = (new_name, self._components[i][1]) self._pipe_meta[new_name] = self._pipe_meta.pop(old_name) self._pipe_configs[new_name] = self._pipe_configs.pop(old_name) + # Make sure [initialize] config is adjusted + if old_name in self._config["initialize"]["components"]: + init_cfg = self._config["initialize"]["components"].pop(old_name) + self._config["initialize"]["components"][new_name] = init_cfg def remove_pipe(self, name: str) -> Tuple[str, Callable[[Doc], Doc]]: """Remove a component from the pipeline. @@ -912,6 +916,9 @@ class Language: # because factory may be used for something else self._pipe_meta.pop(name) self._pipe_configs.pop(name) + # Make sure name is removed from the [initialize] config + if name in self._config["initialize"]["components"]: + self._config["initialize"]["components"].pop(name) # Make sure the name is also removed from the set of disabled components if name in self.disabled: self._disabled.remove(name) diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index e647ba440..a4297a1d1 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -382,3 +382,25 @@ def test_warning_pipe_begin_training(): def begin_training(*args, **kwargs): ... + + +def test_pipe_methods_initialize(): + """Test that the [initialize] config reflects the components correctly.""" + nlp = Language() + nlp.add_pipe("tagger") + assert "tagger" not in nlp.config["initialize"]["components"] + nlp.config["initialize"]["components"]["tagger"] = {"labels": ["hello"]} + assert nlp.config["initialize"]["components"]["tagger"] == {"labels": ["hello"]} + nlp.remove_pipe("tagger") + assert "tagger" not in nlp.config["initialize"]["components"] + nlp.add_pipe("tagger") + assert "tagger" not in nlp.config["initialize"]["components"] + nlp.config["initialize"]["components"]["tagger"] = {"labels": ["hello"]} + nlp.rename_pipe("tagger", "my_tagger") + assert "tagger" not in nlp.config["initialize"]["components"] + assert nlp.config["initialize"]["components"]["my_tagger"] == {"labels": ["hello"]} + nlp.config["initialize"]["components"]["test"] = {"foo": "bar"} + nlp.add_pipe("ner", name="test") + assert "test" in nlp.config["initialize"]["components"] + nlp.remove_pipe("test") + assert "test" not in nlp.config["initialize"]["components"]