Merge disable with disabled. Adjust warnings, errors and tests.

This commit is contained in:
Raphael Mitsch 2022-11-03 13:31:42 +01:00
parent 40e1000db0
commit 13a6324673
4 changed files with 32 additions and 42 deletions

View File

@ -212,8 +212,8 @@ class Warnings(metaclass=ErrorsWithCodes):
W121 = ("Attempting to trace non-existent method '{method}' in pipe '{pipe}'")
W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
"is a Cython extension type.")
W123 = ("Argument {arg} with value {arg_value} is used instead of {config_value} as specified in the config. Be "
"aware that this might affect other components in your pipeline.")
W123 = ("Argument `enable` with value {enable} does not contain all values specified in the config option "
"`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.")
class Errors(metaclass=ErrorsWithCodes):

View File

@ -1879,31 +1879,22 @@ class Language:
if isinstance(exclude, str):
exclude = [exclude]
def fetch_pipes_status(value: Iterable[str], key: str) -> Iterable[str]:
"""Fetch value for `enable` or `disable` w.r.t. the specified config and passed arguments passed to
.load(). If both arguments and config specified values for this field, the passed arguments take precedence
and a warning is printed.
value (Iterable[str]): Passed value for `enable` or `disable`.
key (str): Key for field in config (either "enabled" or "disabled").
RETURN (Iterable[str]):
"""
# We assume that no argument was passed if the value is the specified default value.
if id(value) == id(_DEFAULT_EMPTY_PIPES):
return config["nlp"].get(key, [])
else:
if len(config["nlp"].get(key, [])):
warnings.warn(
Warnings.W123.format(
arg=key[:-1],
arg_value=value,
config_value=config["nlp"][key],
)
# `enable` should not be merged with `enabled` (the opposite is true for `disable`/`disabled`). If the config
# specifies values for `enabled` not included in `enable`, emit warning.
if id(enable) != id(_DEFAULT_EMPTY_PIPES):
enabled = config["nlp"].get("enabled", [])
if len(enabled) and not set(enabled).issubset(enable):
warnings.warn(
Warnings.W123.format(
enable=enable,
enabled=enabled,
)
return value
)
# Ensure sets of disabled/enabled pipe names are not contradictory.
disabled_pipes = cls._resolve_component_status(
fetch_pipes_status(disable, "disabled"),
fetch_pipes_status(enable, "enabled"),
list({*disable, *config["nlp"].get("disabled", [])}),
enable,
config["nlp"]["pipeline"],
)
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
@ -2084,10 +2075,12 @@ class Language:
if enable:
if isinstance(enable, str):
enable = [enable]
to_disable = [
pipe_name for pipe_name in pipe_names if pipe_name not in enable
]
if disable and disable != to_disable:
to_disable = {
*[pipe_name for pipe_name in pipe_names if pipe_name not in enable],
*disable,
}
# If any pipe to be enabled is in to_disable, the specification is inconsistent.
if any([pipe_to_enable in to_disable for pipe_to_enable in enable]):
raise ValueError(Errors.E1042.format(enable=enable, disable=disable))
return tuple(to_disable)

View File

@ -615,20 +615,18 @@ def test_enable_disable_conflict_with_config():
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
# Expected to fail, as config and arguments conflict.
with pytest.raises(ValueError):
spacy.load(
tmp_dir, enable=["tagger"], config={"nlp": {"disabled": ["senter"]}}
)
# Expected to succeed, as config and arguments conflict.
assert spacy.load(
tmp_dir, enable=["tagger"], config={"nlp": {"disabled": ["senter"]}}
).disabled == ["senter", "sentencizer"]
# Expected to succeed without warning due to the lack of a conflicting config option.
spacy.load(tmp_dir, enable=["tagger"])
# Expected to succeed with a warning, as disable=[] should override the config setting.
with pytest.warns(UserWarning):
# Expected to fail due to conflict between enable and disabled.
with pytest.raises(ValueError):
spacy.load(
tmp_dir,
enable=["tagger"],
disable=[],
config={"nlp": {"disabled": ["senter"]}},
enable=["senter"],
config={"nlp": {"disabled": ["senter", "tagger"]}},
)

View File

@ -404,11 +404,10 @@ def test_serialize_pipeline_disable_enable():
assert nlp3.component_names == ["ner", "tagger"]
with make_tempdir() as d:
nlp3.to_disk(d)
with pytest.warns(UserWarning):
nlp4 = spacy.load(d, disable=["ner"])
assert nlp4.pipe_names == ["tagger"]
nlp4 = spacy.load(d, disable=["ner"])
assert nlp4.pipe_names == []
assert nlp4.component_names == ["ner", "tagger"]
assert nlp4.disabled == ["ner"]
assert nlp4.disabled == ["ner", "tagger"]
with make_tempdir() as d:
nlp.to_disk(d)
nlp5 = spacy.load(d, exclude=["tagger"])