diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 6538994c0..0073ba9f0 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -156,8 +156,7 @@ def train( "`lang` argument ('{}') ".format(nlp.lang, lang), exits=1, ) - other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipeline] - nlp.disable_pipes(*other_pipes) + nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline]) for pipe in pipeline: if pipe not in nlp.pipe_names: if pipe == "parser": diff --git a/spacy/language.py b/spacy/language.py index 330852741..5f0e632ae 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -448,6 +448,8 @@ class Language(object): DOCS: https://spacy.io/api/language#disable_pipes """ + if len(names) == 1 and isinstance(names[0], (list, tuple)): + names = names[0] # support list of names instead of spread return DisabledPipes(self, *names) def make_doc(self, text): diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 9078f387e..882e87547 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -187,7 +187,7 @@ class EntityRuler(object): ] except ValueError: subsequent_pipes = [] - with self.nlp.disable_pipes(*subsequent_pipes): + with self.nlp.disable_pipes(subsequent_pipes): for entry in patterns: label = entry["label"] if "id" in entry: diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index 5f1fa5cfe..27fb57b18 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -105,6 +105,16 @@ def test_disable_pipes_context(nlp, name): assert nlp.has_pipe(name) +def test_disable_pipes_list_arg(nlp): + for name in ["c1", "c2", "c3"]: + nlp.add_pipe(new_pipe, name=name) + assert nlp.has_pipe(name) + with nlp.disable_pipes(["c1", "c2"]): + assert not nlp.has_pipe("c1") + assert not nlp.has_pipe("c2") + assert nlp.has_pipe("c3") + + @pytest.mark.parametrize("n_pipes", [100]) def test_add_lots_of_pipes(nlp, n_pipes): for i in range(n_pipes): diff --git a/spacy/tests/regression/test_issue3611.py b/spacy/tests/regression/test_issue3611.py index c0ee83e1b..3c4836264 100644 --- a/spacy/tests/regression/test_issue3611.py +++ b/spacy/tests/regression/test_issue3611.py @@ -34,8 +34,7 @@ def test_issue3611(): nlp.add_pipe(textcat, last=True) # training the network - other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"] - with nlp.disable_pipes(*other_pipes): + with nlp.disable_pipes([p for p in nlp.pipe_names if p != "textcat"]): optimizer = nlp.begin_training() for i in range(3): losses = {} diff --git a/spacy/tests/regression/test_issue4030.py b/spacy/tests/regression/test_issue4030.py index c331fa1d2..ed219573f 100644 --- a/spacy/tests/regression/test_issue4030.py +++ b/spacy/tests/regression/test_issue4030.py @@ -34,8 +34,7 @@ def test_issue4030(): nlp.add_pipe(textcat, last=True) # training the network - other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"] - with nlp.disable_pipes(*other_pipes): + with nlp.disable_pipes([p for p in nlp.pipe_names if p != "textcat"]): optimizer = nlp.begin_training() for i in range(3): losses = {} diff --git a/website/docs/api/language.md b/website/docs/api/language.md index c44339ff5..6e7f6be3e 100644 --- a/website/docs/api/language.md +++ b/website/docs/api/language.md @@ -323,18 +323,38 @@ you can use to undo your changes. > #### Example > > ```python -> with nlp.disable_pipes('tagger', 'parser'): +> # New API as of v2.2.2 +> with nlp.disable_pipes(["tagger", "parser"]): +> nlp.begin_training() +> +> with nlp.disable_pipes("tagger", "parser"): > nlp.begin_training() > -> disabled = nlp.disable_pipes('tagger', 'parser') +> disabled = nlp.disable_pipes("tagger", "parser") > nlp.begin_training() > disabled.restore() > ``` -| Name | Type | Description | -| ----------- | --------------- | ------------------------------------------------------------------------------------ | -| `*disabled` | unicode | Names of pipeline components to disable. | -| **RETURNS** | `DisabledPipes` | The disabled pipes that can be restored by calling the object's `.restore()` method. | +| Name | Type | Description | +| ----------------------------------------- | --------------- | ------------------------------------------------------------------------------------ | +| `disabled` 2.2.2 | list | Names of pipeline components to disable. | +| `*disabled` | unicode | Names of pipeline components to disable. | +| **RETURNS** | `DisabledPipes` | The disabled pipes that can be restored by calling the object's `.restore()` method. | + + + +As of spaCy v2.2.2, the `Language.disable_pipes` method can also take a list of +component names as its first argument (instead of a variable number of +arguments). This is especially useful if you're generating the component names +to disable programmatically. The new syntax will become the default in the +future. + +```diff +- disabled = nlp.disable_pipes("tagger", "parser") ++ disabled = nlp.disable_pipes(["tagger", "parser"]) +``` + + ## Language.to_disk {#to_disk tag="method" new="2"}