mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Simplify and clarify enable/disable behavior of spacy.load() (#11459)
* Change enable/disable behavior so that arguments take precedence over config options. Extend error message on conflict. Add warning message in case of overwriting config option with arguments.
* Fix tests in test_serialize_pipeline.py to reflect changes to handling of enable/disable.
* Fix type issue.
* Move comment.
* Move comment.
* Issue UserWarning instead of printing wasabi message. Adjust test.
* Added pytest.warns(UserWarning) for expected warning to fix tests.
* Update warning message.
* Move type handling out of fetch_pipes_status().
* Add global variable for default value. Use id() to determine whether used values are default value.
* Fix default value for disable.
* Rename DEFAULT_PIPE_STATUS to _DEFAULT_EMPTY_PIPES.
This commit is contained in:
parent
9557b0fb01
commit
aea16719be
|
@ -31,9 +31,9 @@ def load(
|
||||||
name: Union[str, Path],
|
name: Union[str, Path],
|
||||||
*,
|
*,
|
||||||
vocab: Union[Vocab, bool] = True,
|
vocab: Union[Vocab, bool] = True,
|
||||||
disable: Union[str, Iterable[str]] = util.SimpleFrozenList(),
|
disable: Union[str, Iterable[str]] = util._DEFAULT_EMPTY_PIPES,
|
||||||
enable: Union[str, Iterable[str]] = util.SimpleFrozenList(),
|
enable: Union[str, Iterable[str]] = util._DEFAULT_EMPTY_PIPES,
|
||||||
exclude: Union[str, Iterable[str]] = util.SimpleFrozenList(),
|
exclude: Union[str, Iterable[str]] = util._DEFAULT_EMPTY_PIPES,
|
||||||
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
|
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
|
||||||
) -> Language:
|
) -> Language:
|
||||||
"""Load a spaCy model from an installed package or a local path.
|
"""Load a spaCy model from an installed package or a local path.
|
||||||
|
|
|
@ -212,6 +212,8 @@ class Warnings(metaclass=ErrorsWithCodes):
|
||||||
W121 = ("Attempting to trace non-existent method '{method}' in pipe '{pipe}'")
|
W121 = ("Attempting to trace non-existent method '{method}' in pipe '{pipe}'")
|
||||||
W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
|
W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
|
||||||
"is a Cython extension type.")
|
"is a Cython extension type.")
|
||||||
|
W123 = ("Argument {arg} with value {arg_value} is used instead of {config_value} as specified in the config. Be "
|
||||||
|
"aware that this might affect other components in your pipeline.")
|
||||||
|
|
||||||
|
|
||||||
class Errors(metaclass=ErrorsWithCodes):
|
class Errors(metaclass=ErrorsWithCodes):
|
||||||
|
@ -937,8 +939,9 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
E1040 = ("Doc.from_json requires all tokens to have the same attributes. "
|
E1040 = ("Doc.from_json requires all tokens to have the same attributes. "
|
||||||
"Some tokens do not contain annotation for: {partial_attrs}")
|
"Some tokens do not contain annotation for: {partial_attrs}")
|
||||||
E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
|
E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
|
||||||
E1042 = ("Function was called with `{arg1}`={arg1_values} and "
|
E1042 = ("`enable={enable}` and `disable={disable}` are inconsistent with each other.\nIf you only passed "
|
||||||
"`{arg2}`={arg2_values} but these arguments are conflicting.")
|
"one of `enable` or `disable`, the other argument is specified in your pipeline's configuration.\nIn that "
|
||||||
|
"case pass an empty list for the previously not specified argument to avoid this error.")
|
||||||
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
|
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
|
||||||
"{value}.")
|
"{value}.")
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Iterator, Optional, Any, Dict, Callable, Iterable, Collection
|
from typing import Iterator, Optional, Any, Dict, Callable, Iterable
|
||||||
from typing import Union, Tuple, List, Set, Pattern, Sequence
|
from typing import Union, Tuple, List, Set, Pattern, Sequence
|
||||||
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload
|
from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from contextlib import contextmanager
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from thinc.api import get_current_ops, Config, CupyOps, Optimizer
|
from thinc.api import get_current_ops, Config, CupyOps, Optimizer
|
||||||
import srsly
|
import srsly
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
@ -24,7 +25,7 @@ from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
|
||||||
from .training import Example, validate_examples
|
from .training import Example, validate_examples
|
||||||
from .training.initialize import init_vocab, init_tok2vec
|
from .training.initialize import init_vocab, init_tok2vec
|
||||||
from .scorer import Scorer
|
from .scorer import Scorer
|
||||||
from .util import registry, SimpleFrozenList, _pipe, raise_error
|
from .util import registry, SimpleFrozenList, _pipe, raise_error, _DEFAULT_EMPTY_PIPES
|
||||||
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
|
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
|
||||||
from .util import warn_if_jupyter_cupy
|
from .util import warn_if_jupyter_cupy
|
||||||
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
|
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
|
||||||
|
@ -1698,9 +1699,9 @@ class Language:
|
||||||
config: Union[Dict[str, Any], Config] = {},
|
config: Union[Dict[str, Any], Config] = {},
|
||||||
*,
|
*,
|
||||||
vocab: Union[Vocab, bool] = True,
|
vocab: Union[Vocab, bool] = True,
|
||||||
disable: Union[str, Iterable[str]] = SimpleFrozenList(),
|
disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
enable: Union[str, Iterable[str]] = SimpleFrozenList(),
|
enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
exclude: Union[str, Iterable[str]] = SimpleFrozenList(),
|
exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
meta: Dict[str, Any] = SimpleFrozenDict(),
|
meta: Dict[str, Any] = SimpleFrozenDict(),
|
||||||
auto_fill: bool = True,
|
auto_fill: bool = True,
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
|
@ -1727,12 +1728,6 @@ class Language:
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#from_config
|
DOCS: https://spacy.io/api/language#from_config
|
||||||
"""
|
"""
|
||||||
if isinstance(disable, str):
|
|
||||||
disable = [disable]
|
|
||||||
if isinstance(enable, str):
|
|
||||||
enable = [enable]
|
|
||||||
if isinstance(exclude, str):
|
|
||||||
exclude = [exclude]
|
|
||||||
if auto_fill:
|
if auto_fill:
|
||||||
config = Config(
|
config = Config(
|
||||||
cls.default_config, section_order=CONFIG_SECTION_ORDER
|
cls.default_config, section_order=CONFIG_SECTION_ORDER
|
||||||
|
@ -1877,9 +1872,38 @@ class Language:
|
||||||
nlp.vocab.from_bytes(vocab_b)
|
nlp.vocab.from_bytes(vocab_b)
|
||||||
|
|
||||||
# Resolve disabled/enabled settings.
|
# Resolve disabled/enabled settings.
|
||||||
|
if isinstance(disable, str):
|
||||||
|
disable = [disable]
|
||||||
|
if isinstance(enable, str):
|
||||||
|
enable = [enable]
|
||||||
|
if isinstance(exclude, str):
|
||||||
|
exclude = [exclude]
|
||||||
|
|
||||||
|
def fetch_pipes_status(value: Iterable[str], key: str) -> Iterable[str]:
|
||||||
|
"""Fetch value for `enable` or `disable` w.r.t. the specified config and the arguments passed to
|
||||||
|
.load(). If both arguments and config specified values for this field, the passed arguments take precedence
|
||||||
|
and a UserWarning is issued.
|
||||||
|
value (Iterable[str]): Passed value for `enable` or `disable`.
|
||||||
|
key (str): Key for field in config (either "enabled" or "disabled").
|
||||||
|
RETURN (Iterable[str]):
|
||||||
|
"""
|
||||||
|
# We assume that no argument was passed if the value is the specified default value.
|
||||||
|
if id(value) == id(_DEFAULT_EMPTY_PIPES):
|
||||||
|
return config["nlp"].get(key, [])
|
||||||
|
else:
|
||||||
|
if len(config["nlp"].get(key, [])):
|
||||||
|
warnings.warn(
|
||||||
|
Warnings.W123.format(
|
||||||
|
arg=key[:-1],
|
||||||
|
arg_value=value,
|
||||||
|
config_value=config["nlp"][key],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
disabled_pipes = cls._resolve_component_status(
|
disabled_pipes = cls._resolve_component_status(
|
||||||
[*config["nlp"]["disabled"], *disable],
|
fetch_pipes_status(disable, "disabled"),
|
||||||
[*config["nlp"].get("enabled", []), *enable],
|
fetch_pipes_status(enable, "enabled"),
|
||||||
config["nlp"]["pipeline"],
|
config["nlp"]["pipeline"],
|
||||||
)
|
)
|
||||||
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
|
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
|
||||||
|
@ -2064,14 +2088,7 @@ class Language:
|
||||||
pipe_name for pipe_name in pipe_names if pipe_name not in enable
|
pipe_name for pipe_name in pipe_names if pipe_name not in enable
|
||||||
]
|
]
|
||||||
if disable and disable != to_disable:
|
if disable and disable != to_disable:
|
||||||
raise ValueError(
|
raise ValueError(Errors.E1042.format(enable=enable, disable=disable))
|
||||||
Errors.E1042.format(
|
|
||||||
arg1="enable",
|
|
||||||
arg2="disable",
|
|
||||||
arg1_values=enable,
|
|
||||||
arg2_values=disable,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return tuple(to_disable)
|
return tuple(to_disable)
|
||||||
|
|
||||||
|
|
|
@ -605,10 +605,35 @@ def test_update_with_annotates():
|
||||||
assert results[component] == ""
|
assert results[component] == ""
|
||||||
|
|
||||||
|
|
||||||
def test_load_disable_enable() -> None:
|
@pytest.mark.issue(11443)
|
||||||
"""
|
def test_enable_disable_conflict_with_config():
|
||||||
Tests spacy.load() with dis-/enabling components.
|
"""Test conflict between enable/disable w.r.t. `nlp.disabled` set in the config."""
|
||||||
"""
|
nlp = English()
|
||||||
|
nlp.add_pipe("tagger")
|
||||||
|
nlp.add_pipe("senter")
|
||||||
|
nlp.add_pipe("sentencizer")
|
||||||
|
|
||||||
|
with make_tempdir() as tmp_dir:
|
||||||
|
nlp.to_disk(tmp_dir)
|
||||||
|
# Expected to fail, as config and arguments conflict.
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
spacy.load(
|
||||||
|
tmp_dir, enable=["tagger"], config={"nlp": {"disabled": ["senter"]}}
|
||||||
|
)
|
||||||
|
# Expected to succeed without warning due to the lack of a conflicting config option.
|
||||||
|
spacy.load(tmp_dir, enable=["tagger"])
|
||||||
|
# Expected to succeed with a warning, as disable=[] should override the config setting.
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
spacy.load(
|
||||||
|
tmp_dir,
|
||||||
|
enable=["tagger"],
|
||||||
|
disable=[],
|
||||||
|
config={"nlp": {"disabled": ["senter"]}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_disable_enable():
|
||||||
|
"""Tests spacy.load() with dis-/enabling components."""
|
||||||
|
|
||||||
base_nlp = English()
|
base_nlp = English()
|
||||||
for pipe in ("sentencizer", "tagger", "parser"):
|
for pipe in ("sentencizer", "tagger", "parser"):
|
||||||
|
|
|
@ -404,10 +404,11 @@ def test_serialize_pipeline_disable_enable():
|
||||||
assert nlp3.component_names == ["ner", "tagger"]
|
assert nlp3.component_names == ["ner", "tagger"]
|
||||||
with make_tempdir() as d:
|
with make_tempdir() as d:
|
||||||
nlp3.to_disk(d)
|
nlp3.to_disk(d)
|
||||||
nlp4 = spacy.load(d, disable=["ner"])
|
with pytest.warns(UserWarning):
|
||||||
assert nlp4.pipe_names == []
|
nlp4 = spacy.load(d, disable=["ner"])
|
||||||
|
assert nlp4.pipe_names == ["tagger"]
|
||||||
assert nlp4.component_names == ["ner", "tagger"]
|
assert nlp4.component_names == ["ner", "tagger"]
|
||||||
assert nlp4.disabled == ["ner", "tagger"]
|
assert nlp4.disabled == ["ner"]
|
||||||
with make_tempdir() as d:
|
with make_tempdir() as d:
|
||||||
nlp.to_disk(d)
|
nlp.to_disk(d)
|
||||||
nlp5 = spacy.load(d, exclude=["tagger"])
|
nlp5 = spacy.load(d, exclude=["tagger"])
|
||||||
|
|
|
@ -67,7 +67,6 @@ LEXEME_NORM_LANGS = ["cs", "da", "de", "el", "en", "id", "lb", "mk", "pt", "ru",
|
||||||
CONFIG_SECTION_ORDER = ["paths", "variables", "system", "nlp", "components", "corpora", "training", "pretraining", "initialize"]
|
CONFIG_SECTION_ORDER = ["paths", "variables", "system", "nlp", "components", "corpora", "training", "pretraining", "initialize"]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("spacy")
|
logger = logging.getLogger("spacy")
|
||||||
logger_stream_handler = logging.StreamHandler()
|
logger_stream_handler = logging.StreamHandler()
|
||||||
logger_stream_handler.setFormatter(
|
logger_stream_handler.setFormatter(
|
||||||
|
@ -394,13 +393,17 @@ def get_module_path(module: ModuleType) -> Path:
|
||||||
return file_path.parent
|
return file_path.parent
|
||||||
|
|
||||||
|
|
||||||
|
# Default value for passed enable/disable values.
|
||||||
|
_DEFAULT_EMPTY_PIPES = SimpleFrozenList()
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
name: Union[str, Path],
|
name: Union[str, Path],
|
||||||
*,
|
*,
|
||||||
vocab: Union["Vocab", bool] = True,
|
vocab: Union["Vocab", bool] = True,
|
||||||
disable: Union[str, Iterable[str]] = SimpleFrozenList(),
|
disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
enable: Union[str, Iterable[str]] = SimpleFrozenList(),
|
enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
exclude: Union[str, Iterable[str]] = SimpleFrozenList(),
|
exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||||
) -> "Language":
|
) -> "Language":
|
||||||
"""Load a model from a package or data path.
|
"""Load a model from a package or data path.
|
||||||
|
@ -470,9 +473,9 @@ def load_model_from_path(
|
||||||
*,
|
*,
|
||||||
meta: Optional[Dict[str, Any]] = None,
|
meta: Optional[Dict[str, Any]] = None,
|
||||||
vocab: Union["Vocab", bool] = True,
|
vocab: Union["Vocab", bool] = True,
|
||||||
disable: Union[str, Iterable[str]] = SimpleFrozenList(),
|
disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
enable: Union[str, Iterable[str]] = SimpleFrozenList(),
|
enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
exclude: Union[str, Iterable[str]] = SimpleFrozenList(),
|
exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||||
) -> "Language":
|
) -> "Language":
|
||||||
"""Load a model from a data directory path. Creates Language class with
|
"""Load a model from a data directory path. Creates Language class with
|
||||||
|
@ -516,9 +519,9 @@ def load_model_from_config(
|
||||||
*,
|
*,
|
||||||
meta: Dict[str, Any] = SimpleFrozenDict(),
|
meta: Dict[str, Any] = SimpleFrozenDict(),
|
||||||
vocab: Union["Vocab", bool] = True,
|
vocab: Union["Vocab", bool] = True,
|
||||||
disable: Union[str, Iterable[str]] = SimpleFrozenList(),
|
disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
enable: Union[str, Iterable[str]] = SimpleFrozenList(),
|
enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
exclude: Union[str, Iterable[str]] = SimpleFrozenList(),
|
exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
|
||||||
auto_fill: bool = False,
|
auto_fill: bool = False,
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
) -> "Language":
|
) -> "Language":
|
||||||
|
|
Loading…
Reference in New Issue
Block a user