Simplify and clarify enable/disable behavior of spacy.load() (#11459)

* Change enable/disable behavior so that arguments take precedence over config options. Extend error message on conflict. Add warning message in case of overwriting config option with arguments.

* Fix tests in test_serialize_pipeline.py to reflect changes to handling of enable/disable.

* Fix type issue.

* Move comment.

* Move comment.

* Issue UserWarning instead of printing wasabi message. Adjust test.

* Added pytest.warns(UserWarning) for expected warning to fix tests.

* Update warning message.

* Move type handling out of fetch_pipes_status().

* Add global variable for default value. Use id() to determine whether used values are default value.

* Fix default value for disable.

* Rename DEFAULT_PIPE_STATUS to _DEFAULT_EMPTY_PIPES.
This commit is contained in:
Raphael Mitsch 2022-09-27 14:22:36 +02:00 committed by GitHub
parent 9557b0fb01
commit aea16719be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 92 additions and 43 deletions

View File

@ -31,9 +31,9 @@ def load(
name: Union[str, Path], name: Union[str, Path],
*, *,
vocab: Union[Vocab, bool] = True, vocab: Union[Vocab, bool] = True,
disable: Union[str, Iterable[str]] = util.SimpleFrozenList(), disable: Union[str, Iterable[str]] = util._DEFAULT_EMPTY_PIPES,
enable: Union[str, Iterable[str]] = util.SimpleFrozenList(), enable: Union[str, Iterable[str]] = util._DEFAULT_EMPTY_PIPES,
exclude: Union[str, Iterable[str]] = util.SimpleFrozenList(), exclude: Union[str, Iterable[str]] = util._DEFAULT_EMPTY_PIPES,
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
) -> Language: ) -> Language:
"""Load a spaCy model from an installed package or a local path. """Load a spaCy model from an installed package or a local path.

View File

@ -212,6 +212,8 @@ class Warnings(metaclass=ErrorsWithCodes):
W121 = ("Attempting to trace non-existent method '{method}' in pipe '{pipe}'") W121 = ("Attempting to trace non-existent method '{method}' in pipe '{pipe}'")
W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class " W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
"is a Cython extension type.") "is a Cython extension type.")
W123 = ("Argument {arg} with value {arg_value} is used instead of {config_value} as specified in the config. Be "
"aware that this might affect other components in your pipeline.")
class Errors(metaclass=ErrorsWithCodes): class Errors(metaclass=ErrorsWithCodes):
@ -937,8 +939,9 @@ class Errors(metaclass=ErrorsWithCodes):
E1040 = ("Doc.from_json requires all tokens to have the same attributes. " E1040 = ("Doc.from_json requires all tokens to have the same attributes. "
"Some tokens do not contain annotation for: {partial_attrs}") "Some tokens do not contain annotation for: {partial_attrs}")
E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}") E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
E1042 = ("Function was called with `{arg1}`={arg1_values} and " E1042 = ("`enable={enable}` and `disable={disable}` are inconsistent with each other.\nIf you only passed "
"`{arg2}`={arg2_values} but these arguments are conflicting.") "one of `enable` or `disable`, the other argument is specified in your pipeline's configuration.\nIn that "
"case pass an empty list for the previously not specified argument to avoid this error.")
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got " E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
"{value}.") "{value}.")

View File

@ -1,4 +1,4 @@
from typing import Iterator, Optional, Any, Dict, Callable, Iterable, Collection from typing import Iterator, Optional, Any, Dict, Callable, Iterable
from typing import Union, Tuple, List, Set, Pattern, Sequence from typing import Union, Tuple, List, Set, Pattern, Sequence
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload
@ -10,6 +10,7 @@ from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
import warnings import warnings
from thinc.api import get_current_ops, Config, CupyOps, Optimizer from thinc.api import get_current_ops, Config, CupyOps, Optimizer
import srsly import srsly
import multiprocessing as mp import multiprocessing as mp
@ -24,7 +25,7 @@ from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .training import Example, validate_examples from .training import Example, validate_examples
from .training.initialize import init_vocab, init_tok2vec from .training.initialize import init_vocab, init_tok2vec
from .scorer import Scorer from .scorer import Scorer
from .util import registry, SimpleFrozenList, _pipe, raise_error from .util import registry, SimpleFrozenList, _pipe, raise_error, _DEFAULT_EMPTY_PIPES
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
from .util import warn_if_jupyter_cupy from .util import warn_if_jupyter_cupy
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
@ -1698,9 +1699,9 @@ class Language:
config: Union[Dict[str, Any], Config] = {}, config: Union[Dict[str, Any], Config] = {},
*, *,
vocab: Union[Vocab, bool] = True, vocab: Union[Vocab, bool] = True,
disable: Union[str, Iterable[str]] = SimpleFrozenList(), disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
enable: Union[str, Iterable[str]] = SimpleFrozenList(), enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
exclude: Union[str, Iterable[str]] = SimpleFrozenList(), exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
meta: Dict[str, Any] = SimpleFrozenDict(), meta: Dict[str, Any] = SimpleFrozenDict(),
auto_fill: bool = True, auto_fill: bool = True,
validate: bool = True, validate: bool = True,
@ -1727,12 +1728,6 @@ class Language:
DOCS: https://spacy.io/api/language#from_config DOCS: https://spacy.io/api/language#from_config
""" """
if isinstance(disable, str):
disable = [disable]
if isinstance(enable, str):
enable = [enable]
if isinstance(exclude, str):
exclude = [exclude]
if auto_fill: if auto_fill:
config = Config( config = Config(
cls.default_config, section_order=CONFIG_SECTION_ORDER cls.default_config, section_order=CONFIG_SECTION_ORDER
@ -1877,9 +1872,38 @@ class Language:
nlp.vocab.from_bytes(vocab_b) nlp.vocab.from_bytes(vocab_b)
# Resolve disabled/enabled settings. # Resolve disabled/enabled settings.
if isinstance(disable, str):
disable = [disable]
if isinstance(enable, str):
enable = [enable]
if isinstance(exclude, str):
exclude = [exclude]
def fetch_pipes_status(value: Iterable[str], key: str) -> Iterable[str]:
"""Fetch value for `enable` or `disable` w.r.t. the specified config and passed arguments passed to
.load(). If both arguments and config specified values for this field, the passed arguments take precedence
and a warning is printed.
value (Iterable[str]): Passed value for `enable` or `disable`.
key (str): Key for field in config (either "enabled" or "disabled").
RETURN (Iterable[str]):
"""
# We assume that no argument was passed if the value is the specified default value.
if id(value) == id(_DEFAULT_EMPTY_PIPES):
return config["nlp"].get(key, [])
else:
if len(config["nlp"].get(key, [])):
warnings.warn(
Warnings.W123.format(
arg=key[:-1],
arg_value=value,
config_value=config["nlp"][key],
)
)
return value
disabled_pipes = cls._resolve_component_status( disabled_pipes = cls._resolve_component_status(
[*config["nlp"]["disabled"], *disable], fetch_pipes_status(disable, "disabled"),
[*config["nlp"].get("enabled", []), *enable], fetch_pipes_status(enable, "enabled"),
config["nlp"]["pipeline"], config["nlp"]["pipeline"],
) )
nlp._disabled = set(p for p in disabled_pipes if p not in exclude) nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
@ -2064,14 +2088,7 @@ class Language:
pipe_name for pipe_name in pipe_names if pipe_name not in enable pipe_name for pipe_name in pipe_names if pipe_name not in enable
] ]
if disable and disable != to_disable: if disable and disable != to_disable:
raise ValueError( raise ValueError(Errors.E1042.format(enable=enable, disable=disable))
Errors.E1042.format(
arg1="enable",
arg2="disable",
arg1_values=enable,
arg2_values=disable,
)
)
return tuple(to_disable) return tuple(to_disable)

View File

@ -605,10 +605,35 @@ def test_update_with_annotates():
assert results[component] == "" assert results[component] == ""
def test_load_disable_enable() -> None: @pytest.mark.issue(11443)
""" def test_enable_disable_conflict_with_config():
Tests spacy.load() with dis-/enabling components. """Test conflict between enable/disable w.r.t. `nlp.disabled` set in the config."""
""" nlp = English()
nlp.add_pipe("tagger")
nlp.add_pipe("senter")
nlp.add_pipe("sentencizer")
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
# Expected to fail, as config and arguments conflict.
with pytest.raises(ValueError):
spacy.load(
tmp_dir, enable=["tagger"], config={"nlp": {"disabled": ["senter"]}}
)
# Expected to succeed without warning due to the lack of a conflicting config option.
spacy.load(tmp_dir, enable=["tagger"])
# Expected to succeed with a warning, as disable=[] should override the config setting.
with pytest.warns(UserWarning):
spacy.load(
tmp_dir,
enable=["tagger"],
disable=[],
config={"nlp": {"disabled": ["senter"]}},
)
def test_load_disable_enable():
"""Tests spacy.load() with dis-/enabling components."""
base_nlp = English() base_nlp = English()
for pipe in ("sentencizer", "tagger", "parser"): for pipe in ("sentencizer", "tagger", "parser"):

View File

@ -404,10 +404,11 @@ def test_serialize_pipeline_disable_enable():
assert nlp3.component_names == ["ner", "tagger"] assert nlp3.component_names == ["ner", "tagger"]
with make_tempdir() as d: with make_tempdir() as d:
nlp3.to_disk(d) nlp3.to_disk(d)
with pytest.warns(UserWarning):
nlp4 = spacy.load(d, disable=["ner"]) nlp4 = spacy.load(d, disable=["ner"])
assert nlp4.pipe_names == [] assert nlp4.pipe_names == ["tagger"]
assert nlp4.component_names == ["ner", "tagger"] assert nlp4.component_names == ["ner", "tagger"]
assert nlp4.disabled == ["ner", "tagger"] assert nlp4.disabled == ["ner"]
with make_tempdir() as d: with make_tempdir() as d:
nlp.to_disk(d) nlp.to_disk(d)
nlp5 = spacy.load(d, exclude=["tagger"]) nlp5 = spacy.load(d, exclude=["tagger"])

View File

@ -67,7 +67,6 @@ LEXEME_NORM_LANGS = ["cs", "da", "de", "el", "en", "id", "lb", "mk", "pt", "ru",
CONFIG_SECTION_ORDER = ["paths", "variables", "system", "nlp", "components", "corpora", "training", "pretraining", "initialize"] CONFIG_SECTION_ORDER = ["paths", "variables", "system", "nlp", "components", "corpora", "training", "pretraining", "initialize"]
# fmt: on # fmt: on
logger = logging.getLogger("spacy") logger = logging.getLogger("spacy")
logger_stream_handler = logging.StreamHandler() logger_stream_handler = logging.StreamHandler()
logger_stream_handler.setFormatter( logger_stream_handler.setFormatter(
@ -394,13 +393,17 @@ def get_module_path(module: ModuleType) -> Path:
return file_path.parent return file_path.parent
# Default value for passed enable/disable values.
_DEFAULT_EMPTY_PIPES = SimpleFrozenList()
def load_model( def load_model(
name: Union[str, Path], name: Union[str, Path],
*, *,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Union[str, Iterable[str]] = SimpleFrozenList(), disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
enable: Union[str, Iterable[str]] = SimpleFrozenList(), enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
exclude: Union[str, Iterable[str]] = SimpleFrozenList(), exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Load a model from a package or data path. """Load a model from a package or data path.
@ -470,9 +473,9 @@ def load_model_from_path(
*, *,
meta: Optional[Dict[str, Any]] = None, meta: Optional[Dict[str, Any]] = None,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Union[str, Iterable[str]] = SimpleFrozenList(), disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
enable: Union[str, Iterable[str]] = SimpleFrozenList(), enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
exclude: Union[str, Iterable[str]] = SimpleFrozenList(), exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Load a model from a data directory path. Creates Language class with """Load a model from a data directory path. Creates Language class with
@ -516,9 +519,9 @@ def load_model_from_config(
*, *,
meta: Dict[str, Any] = SimpleFrozenDict(), meta: Dict[str, Any] = SimpleFrozenDict(),
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Union[str, Iterable[str]] = SimpleFrozenList(), disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
enable: Union[str, Iterable[str]] = SimpleFrozenList(), enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
exclude: Union[str, Iterable[str]] = SimpleFrozenList(), exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
auto_fill: bool = False, auto_fill: bool = False,
validate: bool = True, validate: bool = True,
) -> "Language": ) -> "Language":