diff --git a/spacy/__init__.py b/spacy/__init__.py index ca47edc94..069215fda 100644 --- a/spacy/__init__.py +++ b/spacy/__init__.py @@ -32,6 +32,7 @@ def load( *, vocab: Union[Vocab, bool] = True, disable: Iterable[str] = util.SimpleFrozenList(), + enable: Iterable[str] = util.SimpleFrozenList(), exclude: Iterable[str] = util.SimpleFrozenList(), config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(), ) -> Language: @@ -42,6 +43,8 @@ def load( disable (Iterable[str]): Names of pipeline components to disable. Disabled pipes will be loaded but they won't be run unless you explicitly enable them by calling nlp.enable_pipe. + enable (Iterable[str]): Names of pipeline components to enable. All other + pipes will be disabled (but can be enabled later using nlp.enable_pipe). exclude (Iterable[str]): Names of pipeline components to exclude. Excluded components won't be loaded. config (Dict[str, Any] / Config): Config overrides as nested dict or dict @@ -49,7 +52,12 @@ def load( RETURNS (Language): The loaded nlp object. """ return util.load_model( - name, vocab=vocab, disable=disable, exclude=exclude, config=config + name, + vocab=vocab, + disable=disable, + enable=enable, + exclude=exclude, + config=config, ) diff --git a/spacy/errors.py b/spacy/errors.py index 384a6a4d2..14010565b 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -932,6 +932,8 @@ class Errors(metaclass=ErrorsWithCodes): E1040 = ("Doc.from_json requires all tokens to have the same attributes. " "Some tokens do not contain annotation for: {partial_attrs}") E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}") + E1042 = ("Function was called with `{arg1}`={arg1_values} and " + "`{arg2}`={arg2_values} but these arguments are conflicting.") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/language.py b/spacy/language.py index 42847823f..816bd6531 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1,4 +1,4 @@ -from typing import Iterator, Optional, Any, Dict, Callable, Iterable +from typing import Iterator, Optional, Any, Dict, Callable, Iterable, Collection from typing import Union, Tuple, List, Set, Pattern, Sequence from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload @@ -1694,6 +1694,7 @@ class Language: *, vocab: Union[Vocab, bool] = True, disable: Iterable[str] = SimpleFrozenList(), + enable: Iterable[str] = SimpleFrozenList(), exclude: Iterable[str] = SimpleFrozenList(), meta: Dict[str, Any] = SimpleFrozenDict(), auto_fill: bool = True, @@ -1708,6 +1709,8 @@ class Language: disable (Iterable[str]): Names of pipeline components to disable. Disabled pipes will be loaded but they won't be run unless you explicitly enable them by calling nlp.enable_pipe. + enable (Iterable[str]): Names of pipeline components to enable. All other + pipes will be disabled (and can be enabled using `nlp.enable_pipe`). exclude (Iterable[str]): Names of pipeline components to exclude. Excluded components won't be loaded. meta (Dict[str, Any]): Meta overrides for nlp.meta. @@ -1861,8 +1864,15 @@ class Language: # Restore the original vocab after sourcing if necessary if vocab_b is not None: nlp.vocab.from_bytes(vocab_b) - disabled_pipes = [*config["nlp"]["disabled"], *disable] + + # Resolve disabled/enabled settings. + disabled_pipes = cls._resolve_component_status( + [*config["nlp"]["disabled"], *disable], + [*config["nlp"].get("enabled", []), *enable], + config["nlp"]["pipeline"], + ) nlp._disabled = set(p for p in disabled_pipes if p not in exclude) + nlp.batch_size = config["nlp"]["batch_size"] nlp.config = filled if auto_fill else config if after_pipeline_creation is not None: @@ -2014,6 +2024,42 @@ class Language: serializers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude) util.to_disk(path, serializers, exclude) + @staticmethod + def _resolve_component_status( + disable: Iterable[str], enable: Iterable[str], pipe_names: Collection[str] + ) -> Tuple[str, ...]: + """Derives whether (1) `disable` and `enable` values are consistent and (2) + resolves those to a single set of disabled components. Raises an error in + case of inconsistency. + + disable (Iterable[str]): Names of components or serialization fields to disable. + enable (Iterable[str]): Names of pipeline components to enable. + pipe_names (Iterable[str]): Names of all pipeline components. + + RETURNS (Tuple[str, ...]): Names of components to exclude from pipeline w.r.t. + specified includes and excludes. + """ + + if disable is not None and isinstance(disable, str): + disable = [disable] + to_disable = disable + + if enable: + to_disable = [ + pipe_name for pipe_name in pipe_names if pipe_name not in enable + ] + if disable and disable != to_disable: + raise ValueError( + Errors.E1042.format( + arg1="enable", + arg2="disable", + arg1_values=enable, + arg2_values=disable, + ) + ) + + return tuple(to_disable) + def from_disk( self, path: Union[str, Path], diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index 4b8fb8ebc..6f00a1cd9 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -4,13 +4,14 @@ import numpy import pytest from thinc.api import get_current_ops +import spacy from spacy.lang.en import English from spacy.lang.en.syntax_iterators import noun_chunks from spacy.language import Language from spacy.pipeline import TrainablePipe from spacy.tokens import Doc from spacy.training import Example -from spacy.util import SimpleFrozenList, get_arg_names +from spacy.util import SimpleFrozenList, get_arg_names, make_tempdir from spacy.vocab import Vocab @@ -602,3 +603,52 @@ def test_update_with_annotates(): assert results[component] == "".join(eg.predicted.text for eg in examples) for component in components - set(components_to_annotate): assert results[component] == "" + + +def test_load_disable_enable() -> None: + """ + Tests spacy.load() with dis-/enabling components. + """ + + base_nlp = English() + for pipe in ("sentencizer", "tagger", "parser"): + base_nlp.add_pipe(pipe) + + with make_tempdir() as tmp_dir: + base_nlp.to_disk(tmp_dir) + to_disable = ["parser", "tagger"] + to_enable = ["tagger", "parser"] + + # Setting only `disable`. + nlp = spacy.load(tmp_dir, disable=to_disable) + assert all([comp_name in nlp.disabled for comp_name in to_disable]) + + # Setting only `enable`. + nlp = spacy.load(tmp_dir, enable=to_enable) + assert all( + [ + (comp_name in nlp.disabled) is (comp_name not in to_enable) + for comp_name in nlp.component_names + ] + ) + + # Testing consistent enable/disable combination. + nlp = spacy.load( + tmp_dir, + enable=to_enable, + disable=[ + comp_name + for comp_name in nlp.component_names + if comp_name not in to_enable + ], + ) + assert all( + [ + (comp_name in nlp.disabled) is (comp_name not in to_enable) + for comp_name in nlp.component_names + ] + ) + + # Inconsistent enable/disable combination. + with pytest.raises(ValueError): + spacy.load(tmp_dir, enable=to_enable, disable=["parser"]) diff --git a/spacy/util.py b/spacy/util.py index 0111c839e..9b871b87b 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -1,6 +1,6 @@ from typing import List, Mapping, NoReturn, Union, Dict, Any, Set, cast from typing import Optional, Iterable, Callable, Tuple, Type -from typing import Iterator, Type, Pattern, Generator, TYPE_CHECKING +from typing import Iterator, Pattern, Generator, TYPE_CHECKING from types import ModuleType import os import importlib @@ -12,7 +12,6 @@ from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer from thinc.api import ConfigValidationError, Model import functools import itertools -import numpy.random import numpy import srsly import catalogue @@ -400,6 +399,7 @@ def load_model( *, vocab: Union["Vocab", bool] = True, disable: Iterable[str] = SimpleFrozenList(), + enable: Iterable[str] = SimpleFrozenList(), exclude: Iterable[str] = SimpleFrozenList(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), ) -> "Language": @@ -409,11 +409,19 @@ def load_model( vocab (Vocab / True): Optional vocab to pass in on initialization. If True, a new Vocab object will be created. disable (Iterable[str]): Names of pipeline components to disable. + enable (Iterable[str]): Names of pipeline components to enable. All others will be disabled. + exclude (Iterable[str]): Names of pipeline components to exclude. config (Dict[str, Any] / Config): Config overrides as nested dict or dict keyed by section values in dot notation. RETURNS (Language): The loaded nlp object. """ - kwargs = {"vocab": vocab, "disable": disable, "exclude": exclude, "config": config} + kwargs = { + "vocab": vocab, + "disable": disable, + "enable": enable, + "exclude": exclude, + "config": config, + } if isinstance(name, str): # name or string path if name.startswith("blank:"): # shortcut for blank model return get_lang_class(name.replace("blank:", ""))() @@ -433,6 +441,7 @@ def load_model_from_package( *, vocab: Union["Vocab", bool] = True, disable: Iterable[str] = SimpleFrozenList(), + enable: Iterable[str] = SimpleFrozenList(), exclude: Iterable[str] = SimpleFrozenList(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), ) -> "Language": @@ -444,6 +453,8 @@ def load_model_from_package( disable (Iterable[str]): Names of pipeline components to disable. Disabled pipes will be loaded but they won't be run unless you explicitly enable them by calling nlp.enable_pipe. + enable (Iterable[str]): Names of pipeline components to enable. All other + pipes will be disabled (and can be enabled using `nlp.enable_pipe`). exclude (Iterable[str]): Names of pipeline components to exclude. Excluded components won't be loaded. config (Dict[str, Any] / Config): Config overrides as nested dict or dict @@ -451,7 +462,7 @@ def load_model_from_package( RETURNS (Language): The loaded nlp object. """ cls = importlib.import_module(name) - return cls.load(vocab=vocab, disable=disable, exclude=exclude, config=config) # type: ignore[attr-defined] + return cls.load(vocab=vocab, disable=disable, enable=enable, exclude=exclude, config=config) # type: ignore[attr-defined] def load_model_from_path( @@ -460,6 +471,7 @@ def load_model_from_path( meta: Optional[Dict[str, Any]] = None, vocab: Union["Vocab", bool] = True, disable: Iterable[str] = SimpleFrozenList(), + enable: Iterable[str] = SimpleFrozenList(), exclude: Iterable[str] = SimpleFrozenList(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), ) -> "Language": @@ -473,6 +485,8 @@ def load_model_from_path( disable (Iterable[str]): Names of pipeline components to disable. Disabled pipes will be loaded but they won't be run unless you explicitly enable them by calling nlp.enable_pipe. + enable (Iterable[str]): Names of pipeline components to enable. All other + pipes will be disabled (and can be enabled using `nlp.enable_pipe`). exclude (Iterable[str]): Names of pipeline components to exclude. Excluded components won't be loaded. config (Dict[str, Any] / Config): Config overrides as nested dict or dict @@ -487,7 +501,12 @@ def load_model_from_path( overrides = dict_to_dot(config) config = load_config(config_path, overrides=overrides) nlp = load_model_from_config( - config, vocab=vocab, disable=disable, exclude=exclude, meta=meta + config, + vocab=vocab, + disable=disable, + enable=enable, + exclude=exclude, + meta=meta, ) return nlp.from_disk(model_path, exclude=exclude, overrides=overrides) @@ -498,6 +517,7 @@ def load_model_from_config( meta: Dict[str, Any] = SimpleFrozenDict(), vocab: Union["Vocab", bool] = True, disable: Iterable[str] = SimpleFrozenList(), + enable: Iterable[str] = SimpleFrozenList(), exclude: Iterable[str] = SimpleFrozenList(), auto_fill: bool = False, validate: bool = True, @@ -512,6 +532,8 @@ def load_model_from_config( disable (Iterable[str]): Names of pipeline components to disable. Disabled pipes will be loaded but they won't be run unless you explicitly enable them by calling nlp.enable_pipe. + enable (Iterable[str]): Names of pipeline components to enable. All other + pipes will be disabled (and can be enabled using `nlp.enable_pipe`). exclude (Iterable[str]): Names of pipeline components to exclude. Excluded components won't be loaded. auto_fill (bool): Whether to auto-fill config with missing defaults. @@ -530,6 +552,7 @@ def load_model_from_config( config, vocab=vocab, disable=disable, + enable=enable, exclude=exclude, auto_fill=auto_fill, validate=validate, @@ -594,6 +617,7 @@ def load_model_from_init_py( *, vocab: Union["Vocab", bool] = True, disable: Iterable[str] = SimpleFrozenList(), + enable: Iterable[str] = SimpleFrozenList(), exclude: Iterable[str] = SimpleFrozenList(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), ) -> "Language": @@ -605,6 +629,8 @@ def load_model_from_init_py( disable (Iterable[str]): Names of pipeline components to disable. Disabled pipes will be loaded but they won't be run unless you explicitly enable them by calling nlp.enable_pipe. + enable (Iterable[str]): Names of pipeline components to enable. All other + pipes will be disabled (and can be enabled using `nlp.enable_pipe`). exclude (Iterable[str]): Names of pipeline components to exclude. Excluded components won't be loaded. config (Dict[str, Any] / Config): Config overrides as nested dict or dict @@ -622,6 +648,7 @@ def load_model_from_init_py( vocab=vocab, meta=meta, disable=disable, + enable=enable, exclude=exclude, config=config, ) diff --git a/website/docs/api/top-level.md b/website/docs/api/top-level.md index 889c6437c..c96c571e9 100644 --- a/website/docs/api/top-level.md +++ b/website/docs/api/top-level.md @@ -51,6 +51,7 @@ specified separately using the new `exclude` keyword argument. | _keyword-only_ | | | `vocab` | Optional shared vocab to pass in on initialization. If `True` (default), a new `Vocab` object will be created. ~~Union[Vocab, bool]~~ | | `disable` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). Disabled pipes will be loaded but they won't be run unless you explicitly enable them by calling [nlp.enable_pipe](/api/language#enable_pipe). ~~List[str]~~ | +| `enable` | Names of pipeline components to [enable](/usage/processing-pipelines#disabling). All other pipes will be disabled. ~~List[str]~~ | | `exclude` 3 | Names of pipeline components to [exclude](/usage/processing-pipelines#disabling). Excluded components won't be loaded. ~~List[str]~~ | | `config` 3 | Optional config overrides, either as nested dict or dict keyed by section value in dot notation, e.g. `"components.name.value"`. ~~Union[Dict[str, Any], Config]~~ | | **RETURNS** | A `Language` object with the loaded pipeline. ~~Language~~ | diff --git a/website/docs/usage/processing-pipelines.md b/website/docs/usage/processing-pipelines.md index 4f75b5193..bd28810ae 100644 --- a/website/docs/usage/processing-pipelines.md +++ b/website/docs/usage/processing-pipelines.md @@ -362,6 +362,18 @@ nlp = spacy.load("en_core_web_sm", disable=["tagger", "parser"]) nlp.enable_pipe("tagger") ``` +In addition to `disable`, `spacy.load()` also accepts `enable`. If `enable` is +set, all components except for those in `enable` are disabled. + +```python +# Load the complete pipeline, but disable all components except for tok2vec and tagger +nlp = spacy.load("en_core_web_sm", enable=["tok2vec", "tagger"]) +# Has the same effect, as NER is already not part of enabled set of components +nlp = spacy.load("en_core_web_sm", enable=["tok2vec", "tagger"], disable=["ner"]) +# Will raise an error, as the sets of enabled and disabled components are conflicting +nlp = spacy.load("en_core_web_sm", enable=["ner"], disable=["ner"]) +``` + As of v3.0, the `disable` keyword argument specifies components to load but