diff --git a/spacy/__init__.py b/spacy/__init__.py
index ca47edc94..069215fda 100644
--- a/spacy/__init__.py
+++ b/spacy/__init__.py
@@ -32,6 +32,7 @@ def load(
*,
vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = util.SimpleFrozenList(),
+ enable: Iterable[str] = util.SimpleFrozenList(),
exclude: Iterable[str] = util.SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
) -> Language:
@@ -42,6 +43,8 @@ def load(
disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
+ enable (Iterable[str]): Names of pipeline components to enable. All other
+ pipes will be disabled (but can be enabled later using nlp.enable_pipe).
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
@@ -49,7 +52,12 @@ def load(
RETURNS (Language): The loaded nlp object.
"""
return util.load_model(
- name, vocab=vocab, disable=disable, exclude=exclude, config=config
+ name,
+ vocab=vocab,
+ disable=disable,
+ enable=enable,
+ exclude=exclude,
+ config=config,
)
diff --git a/spacy/errors.py b/spacy/errors.py
index 384a6a4d2..14010565b 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -932,6 +932,8 @@ class Errors(metaclass=ErrorsWithCodes):
E1040 = ("Doc.from_json requires all tokens to have the same attributes. "
"Some tokens do not contain annotation for: {partial_attrs}")
E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
+ E1042 = ("Function was called with `{arg1}`={arg1_values} and "
+ "`{arg2}`={arg2_values} but these arguments are conflicting.")
# Deprecated model shortcuts, only used in errors and warnings
diff --git a/spacy/language.py b/spacy/language.py
index 42847823f..816bd6531 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -1,4 +1,4 @@
-from typing import Iterator, Optional, Any, Dict, Callable, Iterable
+from typing import Iterator, Optional, Any, Dict, Callable, Iterable, Collection
from typing import Union, Tuple, List, Set, Pattern, Sequence
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload
@@ -1694,6 +1694,7 @@ class Language:
*,
vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = SimpleFrozenList(),
+ enable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
meta: Dict[str, Any] = SimpleFrozenDict(),
auto_fill: bool = True,
@@ -1708,6 +1709,8 @@ class Language:
disable (Iterable[str]): Names of pipeline components to disable.
Disabled pipes will be loaded but they won't be run unless you
explicitly enable them by calling nlp.enable_pipe.
+ enable (Iterable[str]): Names of pipeline components to enable. All other
+ pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
exclude (Iterable[str]): Names of pipeline components to exclude.
Excluded components won't be loaded.
meta (Dict[str, Any]): Meta overrides for nlp.meta.
@@ -1861,8 +1864,15 @@ class Language:
# Restore the original vocab after sourcing if necessary
if vocab_b is not None:
nlp.vocab.from_bytes(vocab_b)
- disabled_pipes = [*config["nlp"]["disabled"], *disable]
+
+ # Resolve disabled/enabled settings.
+ disabled_pipes = cls._resolve_component_status(
+ [*config["nlp"]["disabled"], *disable],
+ [*config["nlp"].get("enabled", []), *enable],
+ config["nlp"]["pipeline"],
+ )
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
+
nlp.batch_size = config["nlp"]["batch_size"]
nlp.config = filled if auto_fill else config
if after_pipeline_creation is not None:
@@ -2014,6 +2024,42 @@ class Language:
serializers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
util.to_disk(path, serializers, exclude)
+ @staticmethod
+ def _resolve_component_status(
+ disable: Iterable[str], enable: Iterable[str], pipe_names: Collection[str]
+ ) -> Tuple[str, ...]:
+ """Derives whether (1) `disable` and `enable` values are consistent and (2)
+ resolves those to a single set of disabled components. Raises an error in
+ case of inconsistency.
+
+ disable (Iterable[str]): Names of components or serialization fields to disable.
+ enable (Iterable[str]): Names of pipeline components to enable.
+ pipe_names (Iterable[str]): Names of all pipeline components.
+
+ RETURNS (Tuple[str, ...]): Names of components to exclude from pipeline w.r.t.
+ specified includes and excludes.
+ """
+
+        # A single string is accepted as shorthand for a one-element list.
+        if disable is not None and isinstance(disable, str):
+            disable = [disable]
+        to_disable = disable
+
+        if enable:
+            # If `enable` is set, every component that isn't explicitly enabled is disabled.
+            to_disable = [
+                pipe_name for pipe_name in pipe_names if pipe_name not in enable
+            ]
+            # An explicitly passed `disable` has to match that derived set, otherwise
+            # the two arguments conflict.
+            if disable and disable != to_disable:
+ raise ValueError(
+ Errors.E1042.format(
+ arg1="enable",
+ arg2="disable",
+ arg1_values=enable,
+ arg2_values=disable,
+ )
+ )
+
+ return tuple(to_disable)
+
def from_disk(
self,
path: Union[str, Path],
diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py
index 4b8fb8ebc..6f00a1cd9 100644
--- a/spacy/tests/pipeline/test_pipe_methods.py
+++ b/spacy/tests/pipeline/test_pipe_methods.py
@@ -4,13 +4,14 @@ import numpy
import pytest
from thinc.api import get_current_ops
+import spacy
from spacy.lang.en import English
from spacy.lang.en.syntax_iterators import noun_chunks
from spacy.language import Language
from spacy.pipeline import TrainablePipe
from spacy.tokens import Doc
from spacy.training import Example
-from spacy.util import SimpleFrozenList, get_arg_names
+from spacy.util import SimpleFrozenList, get_arg_names, make_tempdir
from spacy.vocab import Vocab
@@ -602,3 +603,52 @@ def test_update_with_annotates():
assert results[component] == "".join(eg.predicted.text for eg in examples)
for component in components - set(components_to_annotate):
assert results[component] == ""
+
+
+def test_load_disable_enable() -> None:
+ """
+ Tests spacy.load() with dis-/enabling components.
+ """
+
+ base_nlp = English()
+ for pipe in ("sentencizer", "tagger", "parser"):
+ base_nlp.add_pipe(pipe)
+
+ with make_tempdir() as tmp_dir:
+ base_nlp.to_disk(tmp_dir)
+ to_disable = ["parser", "tagger"]
+ to_enable = ["tagger", "parser"]
+
+ # Setting only `disable`.
+ nlp = spacy.load(tmp_dir, disable=to_disable)
+ assert all([comp_name in nlp.disabled for comp_name in to_disable])
+
+ # Setting only `enable`.
+ nlp = spacy.load(tmp_dir, enable=to_enable)
+ assert all(
+ [
+ (comp_name in nlp.disabled) is (comp_name not in to_enable)
+ for comp_name in nlp.component_names
+ ]
+ )
+
+ # Testing consistent enable/disable combination.
+ nlp = spacy.load(
+ tmp_dir,
+ enable=to_enable,
+ disable=[
+ comp_name
+ for comp_name in nlp.component_names
+ if comp_name not in to_enable
+ ],
+ )
+ assert all(
+ [
+ (comp_name in nlp.disabled) is (comp_name not in to_enable)
+ for comp_name in nlp.component_names
+ ]
+ )
+
+ # Inconsistent enable/disable combination.
+ with pytest.raises(ValueError):
+ spacy.load(tmp_dir, enable=to_enable, disable=["parser"])
diff --git a/spacy/util.py b/spacy/util.py
index 0111c839e..9b871b87b 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -1,6 +1,6 @@
from typing import List, Mapping, NoReturn, Union, Dict, Any, Set, cast
from typing import Optional, Iterable, Callable, Tuple, Type
-from typing import Iterator, Type, Pattern, Generator, TYPE_CHECKING
+from typing import Iterator, Pattern, Generator, TYPE_CHECKING
from types import ModuleType
import os
import importlib
@@ -12,7 +12,6 @@ from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
from thinc.api import ConfigValidationError, Model
import functools
import itertools
-import numpy.random
import numpy
import srsly
import catalogue
@@ -400,6 +399,7 @@ def load_model(
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = SimpleFrozenList(),
+ enable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
@@ -409,11 +409,19 @@ def load_model(
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable.
+ enable (Iterable[str]): Names of pipeline components to enable. All others will be disabled.
+ exclude (Iterable[str]): Names of pipeline components to exclude.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object.
"""
- kwargs = {"vocab": vocab, "disable": disable, "exclude": exclude, "config": config}
+ kwargs = {
+ "vocab": vocab,
+ "disable": disable,
+ "enable": enable,
+ "exclude": exclude,
+ "config": config,
+ }
if isinstance(name, str): # name or string path
if name.startswith("blank:"): # shortcut for blank model
return get_lang_class(name.replace("blank:", ""))()
@@ -433,6 +441,7 @@ def load_model_from_package(
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = SimpleFrozenList(),
+ enable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
@@ -444,6 +453,8 @@ def load_model_from_package(
disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
+ enable (Iterable[str]): Names of pipeline components to enable. All other
+ pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
@@ -451,7 +462,7 @@ def load_model_from_package(
RETURNS (Language): The loaded nlp object.
"""
cls = importlib.import_module(name)
- return cls.load(vocab=vocab, disable=disable, exclude=exclude, config=config) # type: ignore[attr-defined]
+ return cls.load(vocab=vocab, disable=disable, enable=enable, exclude=exclude, config=config) # type: ignore[attr-defined]
def load_model_from_path(
@@ -460,6 +471,7 @@ def load_model_from_path(
meta: Optional[Dict[str, Any]] = None,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = SimpleFrozenList(),
+ enable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
@@ -473,6 +485,8 @@ def load_model_from_path(
disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
+ enable (Iterable[str]): Names of pipeline components to enable. All other
+ pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
@@ -487,7 +501,12 @@ def load_model_from_path(
overrides = dict_to_dot(config)
config = load_config(config_path, overrides=overrides)
nlp = load_model_from_config(
- config, vocab=vocab, disable=disable, exclude=exclude, meta=meta
+ config,
+ vocab=vocab,
+ disable=disable,
+ enable=enable,
+ exclude=exclude,
+ meta=meta,
)
return nlp.from_disk(model_path, exclude=exclude, overrides=overrides)
@@ -498,6 +517,7 @@ def load_model_from_config(
meta: Dict[str, Any] = SimpleFrozenDict(),
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = SimpleFrozenList(),
+ enable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
auto_fill: bool = False,
validate: bool = True,
@@ -512,6 +532,8 @@ def load_model_from_config(
disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
+ enable (Iterable[str]): Names of pipeline components to enable. All other
+ pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
auto_fill (bool): Whether to auto-fill config with missing defaults.
@@ -530,6 +552,7 @@ def load_model_from_config(
config,
vocab=vocab,
disable=disable,
+ enable=enable,
exclude=exclude,
auto_fill=auto_fill,
validate=validate,
@@ -594,6 +617,7 @@ def load_model_from_init_py(
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = SimpleFrozenList(),
+ enable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
@@ -605,6 +629,8 @@ def load_model_from_init_py(
disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
+ enable (Iterable[str]): Names of pipeline components to enable. All other
+ pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
@@ -622,6 +648,7 @@ def load_model_from_init_py(
vocab=vocab,
meta=meta,
disable=disable,
+ enable=enable,
exclude=exclude,
config=config,
)
diff --git a/website/docs/api/top-level.md b/website/docs/api/top-level.md
index 889c6437c..c96c571e9 100644
--- a/website/docs/api/top-level.md
+++ b/website/docs/api/top-level.md
@@ -51,6 +51,7 @@ specified separately using the new `exclude` keyword argument.
| _keyword-only_ | |
| `vocab` | Optional shared vocab to pass in on initialization. If `True` (default), a new `Vocab` object will be created. ~~Union[Vocab, bool]~~ |
| `disable` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). Disabled pipes will be loaded but they won't be run unless you explicitly enable them by calling [nlp.enable_pipe](/api/language#enable_pipe). ~~List[str]~~ |
+| `enable` | Names of pipeline components to [enable](/usage/processing-pipelines#disabling). All other pipes will be disabled. ~~List[str]~~ |
| `exclude` <Tag variant="new">3</Tag> | Names of pipeline components to [exclude](/usage/processing-pipelines#disabling). Excluded components won't be loaded. ~~List[str]~~ |
| `config` <Tag variant="new">3</Tag> | Optional config overrides, either as nested dict or dict keyed by section value in dot notation, e.g. `"components.name.value"`. ~~Union[Dict[str, Any], Config]~~ |
| **RETURNS** | A `Language` object with the loaded pipeline. ~~Language~~ |
diff --git a/website/docs/usage/processing-pipelines.md b/website/docs/usage/processing-pipelines.md
index 4f75b5193..bd28810ae 100644
--- a/website/docs/usage/processing-pipelines.md
+++ b/website/docs/usage/processing-pipelines.md
@@ -362,6 +362,18 @@ nlp = spacy.load("en_core_web_sm", disable=["tagger", "parser"])
nlp.enable_pipe("tagger")
```
+In addition to `disable`, `spacy.load()` also accepts `enable`. If `enable` is
+set, all components except for those in `enable` are disabled.
+
+```python
+# Load the complete pipeline, but disable all components except for tok2vec and tagger
+nlp = spacy.load("en_core_web_sm", enable=["tok2vec", "tagger"])
+# Has the same effect, as NER is already not part of the enabled set of components
+nlp = spacy.load("en_core_web_sm", enable=["tok2vec", "tagger"], disable=["ner"])
+# Will raise an error, as the sets of enabled and disabled components are conflicting
+nlp = spacy.load("en_core_web_sm", enable=["ner"], disable=["ner"])
+```
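+
+Components that end up disabled this way behave just like components listed in
+`disable`: they are loaded but won't run, and you can switch them back on later
+by calling [`nlp.enable_pipe`](/api/language#enable_pipe). A minimal sketch of
+that round trip (assuming the `en_core_web_sm` pipeline is installed):
+
+```python
+nlp = spacy.load("en_core_web_sm", enable=["tok2vec", "tagger"])
+# The parser is not part of the enabled set, so it starts out disabled ...
+assert "parser" in nlp.disabled
+# ... but it can still be enabled explicitly later on.
+nlp.enable_pipe("parser")
+```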
+
As of v3.0, the `disable` keyword argument specifies components to load but