Simplify config overrides in CLI and deserialization (#5880)

This commit is contained in:
Ines Montani 2020-08-05 23:35:09 +02:00 committed by GitHub
parent 50311a4d37
commit 5cc0d89fad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 72 additions and 81 deletions

View File

@ -8,6 +8,7 @@ warnings.filterwarnings("ignore", message="numpy.ufunc size changed") # noqa
# These are imported as part of the API
from thinc.api import prefer_gpu, require_gpu # noqa: F401
from thinc.api import Config
from . import pipeline # noqa: F401
from .cli.info import info # noqa: F401
@ -26,17 +27,17 @@ if sys.maxunicode == 65535:
def load(
name: Union[str, Path],
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = util.SimpleFrozenDict(),
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
) -> Language:
"""Load a spaCy model from an installed package or a local path.
name (str): Package name or model path.
disable (Iterable[str]): Names of pipeline components to disable.
component_cfg (Dict[str, dict]): Config overrides for pipeline components,
keyed by component names.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object.
"""
return util.load_model(name, disable=disable, component_cfg=component_cfg)
return util.load_model(name, disable=disable, config=config)
def blank(name: str, **overrides) -> Language:

View File

@ -49,11 +49,9 @@ def debug_config_cli(
overrides = parse_config_overrides(ctx.args)
import_code(code_path)
with show_validation_error(config_path):
config = Config().from_disk(config_path)
config = Config().from_disk(config_path, overrides=overrides)
try:
nlp, _ = util.load_model_from_config(
config, overrides=overrides, auto_fill=auto_fill
)
nlp, _ = util.load_model_from_config(config, auto_fill=auto_fill)
except ValueError as e:
msg.fail(str(e), exits=1)
if auto_fill:
@ -136,8 +134,8 @@ def debug_data(
if not config_path.exists():
msg.fail("Config file not found", config_path, exists=1)
with show_validation_error(config_path):
cfg = Config().from_disk(config_path)
nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
cfg = Config().from_disk(config_path, overrides=config_overrides)
nlp, config = util.load_model_from_config(cfg)
# Use original config here, not resolved version
sourced_components = get_sourced_components(cfg)
frozen_components = config["training"]["frozen_components"]

View File

@ -49,9 +49,9 @@ def debug_model_cli(
}
config_overrides = parse_config_overrides(ctx.args)
with show_validation_error(config_path):
cfg = Config().from_disk(config_path)
cfg = Config().from_disk(config_path, overrides=config_overrides)
try:
nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
nlp, config = util.load_model_from_config(cfg)
except ValueError as e:
msg.fail(str(e), exits=1)
seed = config.get("training", {}).get("seed", None)

View File

@ -88,8 +88,8 @@ def pretrain(
msg.info("Using CPU")
msg.info(f"Loading config from: {config_path}")
with show_validation_error(config_path):
config = Config().from_disk(config_path)
nlp, config = util.load_model_from_config(config, overrides=config_overrides)
config = Config().from_disk(config_path, overrides=config_overrides)
nlp, config = util.load_model_from_config(config)
# TODO: validate that [pretraining] block exists
if not output_dir.exists():
output_dir.mkdir()

View File

@ -75,13 +75,13 @@ def train(
msg.info("Using CPU")
msg.info(f"Loading config and nlp from: {config_path}")
with show_validation_error(config_path):
config = Config().from_disk(config_path)
config = Config().from_disk(config_path, overrides=config_overrides)
if config.get("training", {}).get("seed") is not None:
fix_random_seed(config["training"]["seed"])
# Use original config here before it's resolved to functions
sourced_components = get_sourced_components(config)
with show_validation_error(config_path):
nlp, config = util.load_model_from_config(config, overrides=config_overrides)
nlp, config = util.load_model_from_config(config)
if config["training"]["vectors"] is not None:
util.load_vectors_into_model(nlp, config["training"]["vectors"])
verify_config(nlp)
@ -144,7 +144,7 @@ def train(
max_steps=T_cfg["max_steps"],
eval_frequency=T_cfg["eval_frequency"],
raw_text=None,
exclude=frozen_components
exclude=frozen_components,
)
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
print_row = setup_printer(T_cfg, nlp)

View File

@ -558,7 +558,6 @@ class Language:
name: Optional[str] = None,
*,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
validate: bool = True,
) -> Callable[[Doc], Doc]:
"""Create a pipeline component. Mostly used internally. To create and
@ -569,8 +568,6 @@ class Language:
Defaults to factory name if not set.
config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available.
overrides (Optional[Dict[str, Any]]): Config overrides, typically
passed in via the CLI.
validate (bool): Whether to validate the component config against the
arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The pipeline component.
@ -613,7 +610,7 @@ class Language:
# registered functions twice
# TODO: customize validation to make it more readable / relate it to
# pipeline component and why it failed, explain default config
resolved, filled = registry.resolve(cfg, validate=validate, overrides=overrides)
resolved, filled = registry.resolve(cfg, validate=validate)
filled = filled[factory_name]
filled["factory"] = factory_name
filled.pop("@factories", None)
@ -657,7 +654,6 @@ class Language:
last: Optional[bool] = None,
source: Optional["Language"] = None,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
validate: bool = True,
) -> Callable[[Doc], Doc]:
"""Add a component to the processing pipeline. Valid components are
@ -679,8 +675,6 @@ class Language:
component from.
config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available.
overrides (Optional[Dict[str, Any]]): Config overrides, typically
passed in via the CLI.
validate (bool): Whether to validate the component config against the
arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The pipeline component.
@ -710,11 +704,7 @@ class Language:
lang_code=self.lang,
)
pipe_component = self.create_pipe(
factory_name,
name=name,
config=config,
overrides=overrides,
validate=validate,
factory_name, name=name, config=config, validate=validate,
)
pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name)
@ -1416,7 +1406,6 @@ class Language:
*,
vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = tuple(),
overrides: Dict[str, Any] = {},
auto_fill: bool = True,
validate: bool = True,
) -> "Language":
@ -1456,9 +1445,8 @@ class Language:
config = util.copy_config(config)
orig_pipeline = config.pop("components", {})
config["components"] = {}
non_pipe_overrides, pipe_overrides = _get_config_overrides(overrides)
resolved, filled = registry.resolve(
config, validate=validate, schema=ConfigSchema, overrides=non_pipe_overrides
config, validate=validate, schema=ConfigSchema
)
filled["components"] = orig_pipeline
config["components"] = orig_pipeline
@ -1507,11 +1495,7 @@ class Language:
# The pipe name (key in the config) here is the unique name
# of the component, not necessarily the factory
nlp.add_pipe(
factory,
name=pipe_name,
config=pipe_cfg,
overrides=pipe_overrides,
validate=validate,
factory, name=pipe_name, config=pipe_cfg, validate=validate,
)
else:
model = pipe_cfg["source"]
@ -1696,15 +1680,6 @@ class FactoryMeta:
default_score_weights: Optional[Dict[str, float]] = None # noqa: E704
def _get_config_overrides(
items: Dict[str, Any], prefix: str = "components"
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
prefix = f"{prefix}."
non_pipe = {k: v for k, v in items.items() if not k.startswith(prefix)}
pipe = {k.replace(prefix, ""): v for k, v in items.items() if k.startswith(prefix)}
return non_pipe, pipe
def _fix_pretrained_vectors_name(nlp: Language) -> None:
# TODO: Replace this once we handle vectors consistently as static
# data

View File

@ -27,6 +27,6 @@ def test_issue5137():
with make_tempdir() as tmpdir:
nlp.to_disk(tmpdir)
overrides = {"my_component": {"categories": "my_categories"}}
nlp2 = spacy.load(tmpdir, component_cfg=overrides)
overrides = {"components": {"my_component": {"categories": "my_categories"}}}
nlp2 = spacy.load(tmpdir, config=overrides)
assert nlp2.get_pipe("my_component").categories == "my_categories"

View File

@ -2,6 +2,7 @@ import pytest
from thinc.config import Config, ConfigValidationError
import spacy
from spacy.lang.en import English
from spacy.lang.de import German
from spacy.language import Language
from spacy.util import registry, deep_merge_configs, load_model_from_config
from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
@ -282,3 +283,33 @@ def test_serialize_config_missing_pipes():
assert "tok2vec" not in config["components"]
with pytest.raises(ValueError):
load_model_from_config(config, auto_fill=True)
def test_config_overrides():
    """Check that config overrides work in dot notation and as a nested dict,
    both when parsing a config string and when loading a saved model."""
    dot_overrides = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}
    nested_overrides = {"nlp": {"lang": "de", "pipeline": ["tagger"]}}
    # Overrides handed straight to Config while parsing the config string
    overridden_config = Config().from_str(nlp_config_string, overrides=dot_overrides)
    nlp, _ = load_model_from_config(overridden_config, auto_fill=True)
    assert isinstance(nlp, German)
    assert nlp.pipe_names == ["tagger"]
    # Baseline model created from the same string without any overrides
    base_config = Config().from_str(nlp_config_string)
    base_nlp, _ = load_model_from_config(base_config, auto_fill=True)
    assert isinstance(base_nlp, English)
    assert base_nlp.pipe_names == ["tok2vec", "tagger"]
    # Serialized roundtrip: both override formats should give the same result
    for overrides in (nested_overrides, dot_overrides):
        with make_tempdir() as d:
            base_nlp.to_disk(d)
            nlp = spacy.load(d, config=overrides)
            assert isinstance(nlp, German)
            assert nlp.pipe_names == ["tagger"]
    # Serialized roundtrip without overrides keeps the saved settings
    with make_tempdir() as d:
        base_nlp.to_disk(d)
        nlp = spacy.load(d)
        assert isinstance(nlp, English)
        assert nlp.pipe_names == ["tok2vec", "tagger"]

View File

@ -210,7 +210,7 @@ def load_model(
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Load a model from a package or data path.
@ -218,11 +218,11 @@ def load_model(
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable.
component_cfg (Dict[str, dict]): Config overrides for pipeline components,
keyed by component names.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object.
"""
kwargs = {"vocab": vocab, "disable": disable, "component_cfg": component_cfg}
kwargs = {"vocab": vocab, "disable": disable, "config": config}
if isinstance(name, str): # name or string path
if name.startswith("blank:"): # shortcut for blank model
return get_lang_class(name.replace("blank:", ""))()
@ -240,11 +240,11 @@ def load_model_from_package(
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Load a model from an installed package."""
cls = importlib.import_module(name)
return cls.load(vocab=vocab, disable=disable, component_cfg=component_cfg)
return cls.load(vocab=vocab, disable=disable, config=config)
def load_model_from_path(
@ -253,7 +253,7 @@ def load_model_from_path(
meta: Optional[Dict[str, Any]] = None,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Load a model from a data directory path. Creates Language class with
pipeline from config.cfg and then calls from_disk() with path."""
@ -264,12 +264,8 @@ def load_model_from_path(
config_path = model_path / "config.cfg"
if not config_path.exists() or not config_path.is_file():
raise IOError(Errors.E053.format(path=config_path, name="config.cfg"))
config = Config().from_disk(config_path)
override_cfg = {"components": {p: dict_to_dot(c) for p, c in component_cfg.items()}}
overrides = dict_to_dot(override_cfg)
nlp, _ = load_model_from_config(
config, vocab=vocab, disable=disable, overrides=overrides
)
config = Config().from_disk(config_path, overrides=dict_to_dot(config))
nlp, _ = load_model_from_config(config, vocab=vocab, disable=disable)
return nlp.from_disk(model_path, exclude=disable)
@ -278,7 +274,6 @@ def load_model_from_config(
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
overrides: Dict[str, Any] = {},
auto_fill: bool = False,
validate: bool = True,
) -> Tuple["Language", Config]:
@ -294,12 +289,7 @@ def load_model_from_config(
# registry, including custom subclasses provided via entry points
lang_cls = get_lang_class(nlp_config["lang"])
nlp = lang_cls.from_config(
config,
vocab=vocab,
disable=disable,
overrides=overrides,
auto_fill=auto_fill,
validate=validate,
config, vocab=vocab, disable=disable, auto_fill=auto_fill, validate=validate,
)
return nlp, nlp.resolved
@ -309,14 +299,10 @@ def load_model_from_init_py(
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Helper function to use in the `load()` method of a model package's
__init__.py.
init_file (str): Path to model's __init__.py, i.e. `__file__`.
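vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation.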
RETURNS (Language): `Language` class with loaded model.
"""
model_path = Path(init_file).parent
meta = get_model_meta(model_path)
@ -325,7 +311,7 @@ def load_model_from_init_py(
if not model_path.exists():
raise IOError(Errors.E052.format(path=data_path))
return load_model_from_path(
data_path, vocab=vocab, meta=meta, disable=disable, component_cfg=component_cfg
data_path, vocab=vocab, meta=meta, disable=disable, config=config
)

View File

@ -32,13 +32,13 @@ loaded in via [`Language.from_disk`](/api/language#from_disk).
> nlp = spacy.load("en_core_web_sm", disable=["parser", "tagger"])
> ```
| Name | Type | Description |
| ------------------------------------------ | ----------------- | --------------------------------------------------------------------------------- |
| `name` | str / `Path` | Model to load, i.e. package name or path. |
| _keyword-only_ | | |
| `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
| `component_cfg` <Tag variant="new">3</Tag> | `Dict[str, dict]` | Optional config overrides for pipeline components, keyed by component names. |
| **RETURNS** | `Language` | A `Language` object with the loaded model. |
| Name | Type | Description |
| ----------------------------------- | ---------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------- |
| `name` | str / `Path` | Model to load, i.e. package name or path. |
| _keyword-only_ | | |
| `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
| `config` <Tag variant="new">3</Tag> | `Dict[str, Any]` / [`Config`](https://thinc.ai/docs/api-config#config) | Optional config overrides, either as nested dict or dict keyed by section value in dot notation, e.g. `"components.name.value"`. |
| **RETURNS** | `Language` | A `Language` object with the loaded model. |
Essentially, `spacy.load()` is a convenience wrapper that reads the language ID
and pipeline components from a model's `meta.json`, initializes the `Language`