mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
enable
argument for spacy.load() (#10784)
* Enable flag on spacy.load: foundation for include, enable arguments. * Enable flag on spacy.load: fixed tests. * Enable flag on spacy.load: switched from pretrained model to empty model with added pipes for tests. * Enable flag on spacy.load: switched to more consistent error on misspecification of component activity. Test refactoring. Added to default config. * Enable flag on spacy.load: added support for fields not in pipeline. * Enable flag on spacy.load: removed serialization fields from supported fields. * Enable flag on spacy.load: removed 'enable' from config again. * Enable flag on spacy.load: relaxed checks in _resolve_component_activation_status() to allow non-standard pipes. * Enable flag on spacy.load: fixed relaxed checks for _resolve_component_activation_status() to allow non-standard pipes. Extended tests. * Enable flag on spacy.load: comments w.r.t. resolution workarounds. * Enable flag on spacy.load: remove include fields. Update website docs. * Enable flag on spacy.load: updates w.r.t. changes in master. * Implement Doc.from_json(): update docstrings. Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Implement Doc.from_json(): remove newline. Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Implement Doc.from_json(): change error message for E1038. Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Enable flag on spacy.load: wrapped docstring for _resolve_component_status() at 80 chars. * Enable flag on spacy.load: changed exmples for enable flag. * Remove newline. Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Fix docstring for Language._resolve_component_status(). * Rename E1038 to E1042. Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
eaeca5eb6a
commit
4c058eb40a
|
@ -32,6 +32,7 @@ def load(
|
|||
*,
|
||||
vocab: Union[Vocab, bool] = True,
|
||||
disable: Iterable[str] = util.SimpleFrozenList(),
|
||||
enable: Iterable[str] = util.SimpleFrozenList(),
|
||||
exclude: Iterable[str] = util.SimpleFrozenList(),
|
||||
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
|
||||
) -> Language:
|
||||
|
@ -42,6 +43,8 @@ def load(
|
|||
disable (Iterable[str]): Names of pipeline components to disable. Disabled
|
||||
pipes will be loaded but they won't be run unless you explicitly
|
||||
enable them by calling nlp.enable_pipe.
|
||||
enable (Iterable[str]): Names of pipeline components to enable. All other
|
||||
pipes will be disabled (but can be enabled later using nlp.enable_pipe).
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
|
||||
components won't be loaded.
|
||||
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
|
||||
|
@ -49,7 +52,12 @@ def load(
|
|||
RETURNS (Language): The loaded nlp object.
|
||||
"""
|
||||
return util.load_model(
|
||||
name, vocab=vocab, disable=disable, exclude=exclude, config=config
|
||||
name,
|
||||
vocab=vocab,
|
||||
disable=disable,
|
||||
enable=enable,
|
||||
exclude=exclude,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -932,6 +932,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E1040 = ("Doc.from_json requires all tokens to have the same attributes. "
|
||||
"Some tokens do not contain annotation for: {partial_attrs}")
|
||||
E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
|
||||
E1042 = ("Function was called with `{arg1}`={arg1_values} and "
|
||||
"`{arg2}`={arg2_values} but these arguments are conflicting.")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Iterator, Optional, Any, Dict, Callable, Iterable
|
||||
from typing import Iterator, Optional, Any, Dict, Callable, Iterable, Collection
|
||||
from typing import Union, Tuple, List, Set, Pattern, Sequence
|
||||
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload
|
||||
|
||||
|
@ -1694,6 +1694,7 @@ class Language:
|
|||
*,
|
||||
vocab: Union[Vocab, bool] = True,
|
||||
disable: Iterable[str] = SimpleFrozenList(),
|
||||
enable: Iterable[str] = SimpleFrozenList(),
|
||||
exclude: Iterable[str] = SimpleFrozenList(),
|
||||
meta: Dict[str, Any] = SimpleFrozenDict(),
|
||||
auto_fill: bool = True,
|
||||
|
@ -1708,6 +1709,8 @@ class Language:
|
|||
disable (Iterable[str]): Names of pipeline components to disable.
|
||||
Disabled pipes will be loaded but they won't be run unless you
|
||||
explicitly enable them by calling nlp.enable_pipe.
|
||||
enable (Iterable[str]): Names of pipeline components to enable. All other
|
||||
pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude.
|
||||
Excluded components won't be loaded.
|
||||
meta (Dict[str, Any]): Meta overrides for nlp.meta.
|
||||
|
@ -1861,8 +1864,15 @@ class Language:
|
|||
# Restore the original vocab after sourcing if necessary
|
||||
if vocab_b is not None:
|
||||
nlp.vocab.from_bytes(vocab_b)
|
||||
disabled_pipes = [*config["nlp"]["disabled"], *disable]
|
||||
|
||||
# Resolve disabled/enabled settings.
|
||||
disabled_pipes = cls._resolve_component_status(
|
||||
[*config["nlp"]["disabled"], *disable],
|
||||
[*config["nlp"].get("enabled", []), *enable],
|
||||
config["nlp"]["pipeline"],
|
||||
)
|
||||
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
|
||||
|
||||
nlp.batch_size = config["nlp"]["batch_size"]
|
||||
nlp.config = filled if auto_fill else config
|
||||
if after_pipeline_creation is not None:
|
||||
|
@ -2014,6 +2024,42 @@ class Language:
|
|||
serializers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
|
||||
util.to_disk(path, serializers, exclude)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_component_status(
|
||||
disable: Iterable[str], enable: Iterable[str], pipe_names: Collection[str]
|
||||
) -> Tuple[str, ...]:
|
||||
"""Derives whether (1) `disable` and `enable` values are consistent and (2)
|
||||
resolves those to a single set of disabled components. Raises an error in
|
||||
case of inconsistency.
|
||||
|
||||
disable (Iterable[str]): Names of components or serialization fields to disable.
|
||||
enable (Iterable[str]): Names of pipeline components to enable.
|
||||
pipe_names (Iterable[str]): Names of all pipeline components.
|
||||
|
||||
RETURNS (Tuple[str, ...]): Names of components to exclude from pipeline w.r.t.
|
||||
specified includes and excludes.
|
||||
"""
|
||||
|
||||
if disable is not None and isinstance(disable, str):
|
||||
disable = [disable]
|
||||
to_disable = disable
|
||||
|
||||
if enable:
|
||||
to_disable = [
|
||||
pipe_name for pipe_name in pipe_names if pipe_name not in enable
|
||||
]
|
||||
if disable and disable != to_disable:
|
||||
raise ValueError(
|
||||
Errors.E1042.format(
|
||||
arg1="enable",
|
||||
arg2="disable",
|
||||
arg1_values=enable,
|
||||
arg2_values=disable,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(to_disable)
|
||||
|
||||
def from_disk(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
|
|
|
@ -4,13 +4,14 @@ import numpy
|
|||
import pytest
|
||||
from thinc.api import get_current_ops
|
||||
|
||||
import spacy
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.en.syntax_iterators import noun_chunks
|
||||
from spacy.language import Language
|
||||
from spacy.pipeline import TrainablePipe
|
||||
from spacy.tokens import Doc
|
||||
from spacy.training import Example
|
||||
from spacy.util import SimpleFrozenList, get_arg_names
|
||||
from spacy.util import SimpleFrozenList, get_arg_names, make_tempdir
|
||||
from spacy.vocab import Vocab
|
||||
|
||||
|
||||
|
@ -602,3 +603,52 @@ def test_update_with_annotates():
|
|||
assert results[component] == "".join(eg.predicted.text for eg in examples)
|
||||
for component in components - set(components_to_annotate):
|
||||
assert results[component] == ""
|
||||
|
||||
|
||||
def test_load_disable_enable() -> None:
    """
    Tests spacy.load() with dis-/enabling components.
    """

    source_nlp = English()
    for component in ("sentencizer", "tagger", "parser"):
        source_nlp.add_pipe(component)

    with make_tempdir() as model_dir:
        source_nlp.to_disk(model_dir)
        disabled = ["parser", "tagger"]
        enabled = ["tagger", "parser"]

        # Setting only `disable`.
        nlp = spacy.load(model_dir, disable=disabled)
        for name in disabled:
            assert name in nlp.disabled

        # Setting only `enable`.
        nlp = spacy.load(model_dir, enable=enabled)
        for name in nlp.component_names:
            assert (name in nlp.disabled) is (name not in enabled)

        # Testing consistent enable/disable combination: disabling exactly
        # the complement of the enabled set must not raise.
        complement = [
            name for name in nlp.component_names if name not in enabled
        ]
        nlp = spacy.load(model_dir, enable=enabled, disable=complement)
        for name in nlp.component_names:
            assert (name in nlp.disabled) is (name not in enabled)

        # Inconsistent enable/disable combination.
        with pytest.raises(ValueError):
            spacy.load(model_dir, enable=enabled, disable=["parser"])
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import List, Mapping, NoReturn, Union, Dict, Any, Set, cast
|
||||
from typing import Optional, Iterable, Callable, Tuple, Type
|
||||
from typing import Iterator, Type, Pattern, Generator, TYPE_CHECKING
|
||||
from typing import Iterator, Pattern, Generator, TYPE_CHECKING
|
||||
from types import ModuleType
|
||||
import os
|
||||
import importlib
|
||||
|
@ -12,7 +12,6 @@ from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
|
|||
from thinc.api import ConfigValidationError, Model
|
||||
import functools
|
||||
import itertools
|
||||
import numpy.random
|
||||
import numpy
|
||||
import srsly
|
||||
import catalogue
|
||||
|
@ -400,6 +399,7 @@ def load_model(
|
|||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = SimpleFrozenList(),
|
||||
enable: Iterable[str] = SimpleFrozenList(),
|
||||
exclude: Iterable[str] = SimpleFrozenList(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
|
@ -409,11 +409,19 @@ def load_model(
|
|||
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
|
||||
a new Vocab object will be created.
|
||||
disable (Iterable[str]): Names of pipeline components to disable.
|
||||
enable (Iterable[str]): Names of pipeline components to enable. All others will be disabled.
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude.
|
||||
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
|
||||
keyed by section values in dot notation.
|
||||
RETURNS (Language): The loaded nlp object.
|
||||
"""
|
||||
kwargs = {"vocab": vocab, "disable": disable, "exclude": exclude, "config": config}
|
||||
kwargs = {
|
||||
"vocab": vocab,
|
||||
"disable": disable,
|
||||
"enable": enable,
|
||||
"exclude": exclude,
|
||||
"config": config,
|
||||
}
|
||||
if isinstance(name, str): # name or string path
|
||||
if name.startswith("blank:"): # shortcut for blank model
|
||||
return get_lang_class(name.replace("blank:", ""))()
|
||||
|
@ -433,6 +441,7 @@ def load_model_from_package(
|
|||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = SimpleFrozenList(),
|
||||
enable: Iterable[str] = SimpleFrozenList(),
|
||||
exclude: Iterable[str] = SimpleFrozenList(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
|
@ -444,6 +453,8 @@ def load_model_from_package(
|
|||
disable (Iterable[str]): Names of pipeline components to disable. Disabled
|
||||
pipes will be loaded but they won't be run unless you explicitly
|
||||
enable them by calling nlp.enable_pipe.
|
||||
enable (Iterable[str]): Names of pipeline components to enable. All other
|
||||
pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
|
||||
components won't be loaded.
|
||||
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
|
||||
|
@ -451,7 +462,7 @@ def load_model_from_package(
|
|||
RETURNS (Language): The loaded nlp object.
|
||||
"""
|
||||
cls = importlib.import_module(name)
|
||||
return cls.load(vocab=vocab, disable=disable, exclude=exclude, config=config) # type: ignore[attr-defined]
|
||||
return cls.load(vocab=vocab, disable=disable, enable=enable, exclude=exclude, config=config) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def load_model_from_path(
|
||||
|
@ -460,6 +471,7 @@ def load_model_from_path(
|
|||
meta: Optional[Dict[str, Any]] = None,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = SimpleFrozenList(),
|
||||
enable: Iterable[str] = SimpleFrozenList(),
|
||||
exclude: Iterable[str] = SimpleFrozenList(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
|
@ -473,6 +485,8 @@ def load_model_from_path(
|
|||
disable (Iterable[str]): Names of pipeline components to disable. Disabled
|
||||
pipes will be loaded but they won't be run unless you explicitly
|
||||
enable them by calling nlp.enable_pipe.
|
||||
enable (Iterable[str]): Names of pipeline components to enable. All other
|
||||
pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
|
||||
components won't be loaded.
|
||||
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
|
||||
|
@ -487,7 +501,12 @@ def load_model_from_path(
|
|||
overrides = dict_to_dot(config)
|
||||
config = load_config(config_path, overrides=overrides)
|
||||
nlp = load_model_from_config(
|
||||
config, vocab=vocab, disable=disable, exclude=exclude, meta=meta
|
||||
config,
|
||||
vocab=vocab,
|
||||
disable=disable,
|
||||
enable=enable,
|
||||
exclude=exclude,
|
||||
meta=meta,
|
||||
)
|
||||
return nlp.from_disk(model_path, exclude=exclude, overrides=overrides)
|
||||
|
||||
|
@ -498,6 +517,7 @@ def load_model_from_config(
|
|||
meta: Dict[str, Any] = SimpleFrozenDict(),
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = SimpleFrozenList(),
|
||||
enable: Iterable[str] = SimpleFrozenList(),
|
||||
exclude: Iterable[str] = SimpleFrozenList(),
|
||||
auto_fill: bool = False,
|
||||
validate: bool = True,
|
||||
|
@ -512,6 +532,8 @@ def load_model_from_config(
|
|||
disable (Iterable[str]): Names of pipeline components to disable. Disabled
|
||||
pipes will be loaded but they won't be run unless you explicitly
|
||||
enable them by calling nlp.enable_pipe.
|
||||
enable (Iterable[str]): Names of pipeline components to enable. All other
|
||||
pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
|
||||
components won't be loaded.
|
||||
auto_fill (bool): Whether to auto-fill config with missing defaults.
|
||||
|
@ -530,6 +552,7 @@ def load_model_from_config(
|
|||
config,
|
||||
vocab=vocab,
|
||||
disable=disable,
|
||||
enable=enable,
|
||||
exclude=exclude,
|
||||
auto_fill=auto_fill,
|
||||
validate=validate,
|
||||
|
@ -594,6 +617,7 @@ def load_model_from_init_py(
|
|||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = SimpleFrozenList(),
|
||||
enable: Iterable[str] = SimpleFrozenList(),
|
||||
exclude: Iterable[str] = SimpleFrozenList(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
|
@ -605,6 +629,8 @@ def load_model_from_init_py(
|
|||
disable (Iterable[str]): Names of pipeline components to disable. Disabled
|
||||
pipes will be loaded but they won't be run unless you explicitly
|
||||
enable them by calling nlp.enable_pipe.
|
||||
enable (Iterable[str]): Names of pipeline components to enable. All other
|
||||
pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
|
||||
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
|
||||
components won't be loaded.
|
||||
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
|
||||
|
@ -622,6 +648,7 @@ def load_model_from_init_py(
|
|||
vocab=vocab,
|
||||
meta=meta,
|
||||
disable=disable,
|
||||
enable=enable,
|
||||
exclude=exclude,
|
||||
config=config,
|
||||
)
|
||||
|
|
|
@ -51,6 +51,7 @@ specified separately using the new `exclude` keyword argument.
|
|||
| _keyword-only_ | |
|
||||
| `vocab` | Optional shared vocab to pass in on initialization. If `True` (default), a new `Vocab` object will be created. ~~Union[Vocab, bool]~~ |
|
||||
| `disable` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). Disabled pipes will be loaded but they won't be run unless you explicitly enable them by calling [nlp.enable_pipe](/api/language#enable_pipe). ~~List[str]~~ |
|
||||
| `enable` | Names of pipeline components to [enable](/usage/processing-pipelines#disabling). All other pipes will be disabled. ~~List[str]~~ |
|
||||
| `exclude` <Tag variant="new">3</Tag> | Names of pipeline components to [exclude](/usage/processing-pipelines#disabling). Excluded components won't be loaded. ~~List[str]~~ |
|
||||
| `config` <Tag variant="new">3</Tag> | Optional config overrides, either as nested dict or dict keyed by section value in dot notation, e.g. `"components.name.value"`. ~~Union[Dict[str, Any], Config]~~ |
|
||||
| **RETURNS** | A `Language` object with the loaded pipeline. ~~Language~~ |
|
||||
|
|
|
@ -362,6 +362,18 @@ nlp = spacy.load("en_core_web_sm", disable=["tagger", "parser"])
|
|||
nlp.enable_pipe("tagger")
|
||||
```
|
||||
|
||||
In addition to `disable`, `spacy.load()` also accepts `enable`. If `enable` is
|
||||
set, all components except for those in `enable` are disabled.
|
||||
|
||||
```python
|
||||
# Load the complete pipeline, but disable all components except for tok2vec and tagger
|
||||
nlp = spacy.load("en_core_web_sm", enable=["tok2vec", "tagger"])
|
||||
# Has the same effect, as NER is already not part of enabled set of components
|
||||
nlp = spacy.load("en_core_web_sm", enable=["tok2vec", "tagger"], disable=["ner"])
|
||||
# Will raise an error, as the sets of enabled and disabled components are conflicting
|
||||
nlp = spacy.load("en_core_web_sm", enable=["ner"], disable=["ner"])
|
||||
```
|
||||
|
||||
<Infobox variant="warning" title="Changed in v3.0">
|
||||
|
||||
As of v3.0, the `disable` keyword argument specifies components to load but
|
||||
|
|
Loading…
Reference in New Issue
Block a user