Simplify config overrides in CLI and deserialization (#5880)

Ines Montani 2020-08-05 23:35:09 +02:00 committed by GitHub
parent 50311a4d37
commit 5cc0d89fad
10 changed files with 72 additions and 81 deletions

View File

@@ -8,6 +8,7 @@ warnings.filterwarnings("ignore", message="numpy.ufunc size changed")  # noqa
 # These are imported as part of the API
 from thinc.api import prefer_gpu, require_gpu  # noqa: F401
+from thinc.api import Config
 from . import pipeline  # noqa: F401
 from .cli.info import info  # noqa: F401
@@ -26,17 +27,17 @@ if sys.maxunicode == 65535:
 def load(
     name: Union[str, Path],
     disable: Iterable[str] = tuple(),
-    component_cfg: Dict[str, Dict[str, Any]] = util.SimpleFrozenDict(),
+    config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
 ) -> Language:
     """Load a spaCy model from an installed package or a local path.
 
     name (str): Package name or model path.
     disable (Iterable[str]): Names of pipeline components to disable.
-    component_cfg (Dict[str, dict]): Config overrides for pipeline components,
-        keyed by component names.
+    config (Dict[str, Any] / Config): Config overrides as nested dict or dict
+        keyed by section values in dot notation.
     RETURNS (Language): The loaded nlp object.
     """
-    return util.load_model(name, disable=disable, component_cfg=component_cfg)
+    return util.load_model(name, disable=disable, config=config)
 
 
 def blank(name: str, **overrides) -> Language:
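For illustration, the new `config` argument accepts the same overrides either as a nested dict or flattened into dot notation. A minimal sketch, assuming a pipeline saved at ./my_model with a custom "my_component" factory (the path, component name and setting are placeholders, not part of the diff's API surface beyond the `config` keyword itself):

    import spacy

    # Both forms are equivalent; the component name and setting are placeholders.
    nested = {"components": {"my_component": {"categories": "my_categories"}}}
    dotted = {"components.my_component.categories": "my_categories"}

    nlp = spacy.load("./my_model", config=nested)  # nested dict
    nlp = spacy.load("./my_model", config=dotted)  # dot notation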

View File

@@ -49,11 +49,9 @@ def debug_config_cli(
     overrides = parse_config_overrides(ctx.args)
     import_code(code_path)
     with show_validation_error(config_path):
-        config = Config().from_disk(config_path)
+        config = Config().from_disk(config_path, overrides=overrides)
     try:
-        nlp, _ = util.load_model_from_config(
-            config, overrides=overrides, auto_fill=auto_fill
-        )
+        nlp, _ = util.load_model_from_config(config, auto_fill=auto_fill)
     except ValueError as e:
         msg.fail(str(e), exits=1)
     if auto_fill:
@@ -136,8 +134,8 @@ def debug_data(
     if not config_path.exists():
         msg.fail("Config file not found", config_path, exists=1)
     with show_validation_error(config_path):
-        cfg = Config().from_disk(config_path)
-        nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
+        cfg = Config().from_disk(config_path, overrides=config_overrides)
+        nlp, config = util.load_model_from_config(cfg)
     # Use original config here, not resolved version
     sourced_components = get_sourced_components(cfg)
     frozen_components = config["training"]["frozen_components"]
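The CLI commands above all follow the same pattern now: dot-notation overrides parsed from the command line are applied while the config file is read, instead of being threaded through load_model_from_config. A rough sketch of that pattern (the file name and override keys are assumptions for illustration):

    from thinc.api import Config

    # Overrides in dot notation, as produced by parse_config_overrides,
    # are merged into the config at read time, before resolution.
    overrides = {"training.seed": 0, "paths.train": "./train.spacy"}
    config = Config().from_disk("config.cfg", overrides=overrides)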

View File

@@ -49,9 +49,9 @@ def debug_model_cli(
     }
     config_overrides = parse_config_overrides(ctx.args)
     with show_validation_error(config_path):
-        cfg = Config().from_disk(config_path)
+        cfg = Config().from_disk(config_path, overrides=config_overrides)
     try:
-        nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
+        nlp, config = util.load_model_from_config(cfg)
     except ValueError as e:
         msg.fail(str(e), exits=1)
     seed = config.get("training", {}).get("seed", None)

View File

@@ -88,8 +88,8 @@ def pretrain(
     msg.info("Using CPU")
     msg.info(f"Loading config from: {config_path}")
     with show_validation_error(config_path):
-        config = Config().from_disk(config_path)
-        nlp, config = util.load_model_from_config(config, overrides=config_overrides)
+        config = Config().from_disk(config_path, overrides=config_overrides)
+        nlp, config = util.load_model_from_config(config)
     # TODO: validate that [pretraining] block exists
     if not output_dir.exists():
         output_dir.mkdir()

View File

@@ -75,13 +75,13 @@ def train(
     msg.info("Using CPU")
     msg.info(f"Loading config and nlp from: {config_path}")
     with show_validation_error(config_path):
-        config = Config().from_disk(config_path)
+        config = Config().from_disk(config_path, overrides=config_overrides)
     if config.get("training", {}).get("seed") is not None:
         fix_random_seed(config["training"]["seed"])
     # Use original config here before it's resolved to functions
     sourced_components = get_sourced_components(config)
     with show_validation_error(config_path):
-        nlp, config = util.load_model_from_config(config, overrides=config_overrides)
+        nlp, config = util.load_model_from_config(config)
     if config["training"]["vectors"] is not None:
         util.load_vectors_into_model(nlp, config["training"]["vectors"])
     verify_config(nlp)
verify_config(nlp) verify_config(nlp)
@@ -144,7 +144,7 @@ def train(
         max_steps=T_cfg["max_steps"],
         eval_frequency=T_cfg["eval_frequency"],
         raw_text=None,
-        exclude=frozen_components
+        exclude=frozen_components,
     )
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
     print_row = setup_printer(T_cfg, nlp)

View File

@@ -558,7 +558,6 @@ class Language:
         name: Optional[str] = None,
         *,
         config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
-        overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
         validate: bool = True,
     ) -> Callable[[Doc], Doc]:
         """Create a pipeline component. Mostly used internally. To create and
@@ -569,8 +568,6 @@ class Language:
             Defaults to factory name if not set.
         config (Optional[Dict[str, Any]]): Config parameters to use for this
             component. Will be merged with default config, if available.
-        overrides (Optional[Dict[str, Any]]): Config overrides, typically
-            passed in via the CLI.
         validate (bool): Whether to validate the component config against the
             arguments and types expected by the factory.
         RETURNS (Callable[[Doc], Doc]): The pipeline component.
@@ -613,7 +610,7 @@ class Language:
         # registered functions twice
         # TODO: customize validation to make it more readable / relate it to
         # pipeline component and why it failed, explain default config
-        resolved, filled = registry.resolve(cfg, validate=validate, overrides=overrides)
+        resolved, filled = registry.resolve(cfg, validate=validate)
         filled = filled[factory_name]
         filled["factory"] = factory_name
         filled.pop("@factories", None)
@@ -657,7 +654,6 @@ class Language:
         last: Optional[bool] = None,
         source: Optional["Language"] = None,
         config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
-        overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
         validate: bool = True,
     ) -> Callable[[Doc], Doc]:
         """Add a component to the processing pipeline. Valid components are
@@ -679,8 +675,6 @@ class Language:
             component from.
         config (Optional[Dict[str, Any]]): Config parameters to use for this
             component. Will be merged with default config, if available.
-        overrides (Optional[Dict[str, Any]]): Config overrides, typically
-            passed in via the CLI.
         validate (bool): Whether to validate the component config against the
             arguments and types expected by the factory.
         RETURNS (Callable[[Doc], Doc]): The pipeline component.
@@ -710,11 +704,7 @@ class Language:
             lang_code=self.lang,
         )
         pipe_component = self.create_pipe(
-            factory_name,
-            name=name,
-            config=config,
-            overrides=overrides,
-            validate=validate,
+            factory_name, name=name, config=config, validate=validate,
         )
         pipe_index = self._get_pipe_index(before, after, first, last)
         self._pipe_meta[name] = self.get_factory_meta(factory_name)
@@ -1416,7 +1406,6 @@ class Language:
         *,
         vocab: Union[Vocab, bool] = True,
         disable: Iterable[str] = tuple(),
-        overrides: Dict[str, Any] = {},
         auto_fill: bool = True,
         validate: bool = True,
     ) -> "Language":
@@ -1456,9 +1445,8 @@ class Language:
         config = util.copy_config(config)
         orig_pipeline = config.pop("components", {})
         config["components"] = {}
-        non_pipe_overrides, pipe_overrides = _get_config_overrides(overrides)
         resolved, filled = registry.resolve(
-            config, validate=validate, schema=ConfigSchema, overrides=non_pipe_overrides
+            config, validate=validate, schema=ConfigSchema
         )
         filled["components"] = orig_pipeline
         config["components"] = orig_pipeline
@@ -1507,11 +1495,7 @@ class Language:
                 # The pipe name (key in the config) here is the unique name
                 # of the component, not necessarily the factory
                 nlp.add_pipe(
-                    factory,
-                    name=pipe_name,
-                    config=pipe_cfg,
-                    overrides=pipe_overrides,
-                    validate=validate,
+                    factory, name=pipe_name, config=pipe_cfg, validate=validate,
                 )
             else:
                 model = pipe_cfg["source"]
@@ -1696,15 +1680,6 @@ class FactoryMeta:
     default_score_weights: Optional[Dict[str, float]] = None  # noqa: E704
 
 
-def _get_config_overrides(
-    items: Dict[str, Any], prefix: str = "components"
-) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-    prefix = f"{prefix}."
-    non_pipe = {k: v for k, v in items.items() if not k.startswith(prefix)}
-    pipe = {k.replace(prefix, ""): v for k, v in items.items() if k.startswith(prefix)}
-    return non_pipe, pipe
-
-
 def _fix_pretrained_vectors_name(nlp: Language) -> None:
     # TODO: Replace this once we handle vectors consistently as static
     # data
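With the `overrides` keyword gone from create_pipe, add_pipe and from_config, and the _get_config_overrides helper removed, the intended flow is to fold overrides into the Config before it reaches Language.from_config. A minimal sketch of that flow under those assumptions (the file name and override key are placeholders):

    from thinc.api import Config
    from spacy import util

    # Overrides are applied when the config is read; from_config and
    # add_pipe then see a single, already-merged config.
    config = Config().from_disk("config.cfg", overrides={"nlp.pipeline": ["tagger"]})
    nlp, resolved = util.load_model_from_config(config, auto_fill=True)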

View File

@@ -27,6 +27,6 @@ def test_issue5137():
     with make_tempdir() as tmpdir:
         nlp.to_disk(tmpdir)
-        overrides = {"my_component": {"categories": "my_categories"}}
-        nlp2 = spacy.load(tmpdir, component_cfg=overrides)
+        overrides = {"components": {"my_component": {"categories": "my_categories"}}}
+        nlp2 = spacy.load(tmpdir, config=overrides)
         assert nlp2.get_pipe("my_component").categories == "my_categories"

View File

@@ -2,6 +2,7 @@ import pytest
 from thinc.config import Config, ConfigValidationError
 import spacy
 from spacy.lang.en import English
+from spacy.lang.de import German
 from spacy.language import Language
 from spacy.util import registry, deep_merge_configs, load_model_from_config
 from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
@@ -282,3 +283,33 @@ def test_serialize_config_missing_pipes():
     assert "tok2vec" not in config["components"]
     with pytest.raises(ValueError):
         load_model_from_config(config, auto_fill=True)
+
+
+def test_config_overrides():
+    overrides_nested = {"nlp": {"lang": "de", "pipeline": ["tagger"]}}
+    overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}
+    # load_model from config with overrides passed directly to Config
+    config = Config().from_str(nlp_config_string, overrides=overrides_dot)
+    nlp, _ = load_model_from_config(config, auto_fill=True)
+    assert isinstance(nlp, German)
+    assert nlp.pipe_names == ["tagger"]
+    # Serialized roundtrip with config passed in
+    base_config = Config().from_str(nlp_config_string)
+    base_nlp, _ = load_model_from_config(base_config, auto_fill=True)
+    assert isinstance(base_nlp, English)
+    assert base_nlp.pipe_names == ["tok2vec", "tagger"]
+    with make_tempdir() as d:
+        base_nlp.to_disk(d)
+        nlp = spacy.load(d, config=overrides_nested)
+    assert isinstance(nlp, German)
+    assert nlp.pipe_names == ["tagger"]
+    with make_tempdir() as d:
+        base_nlp.to_disk(d)
+        nlp = spacy.load(d, config=overrides_dot)
+    assert isinstance(nlp, German)
+    assert nlp.pipe_names == ["tagger"]
+    with make_tempdir() as d:
+        base_nlp.to_disk(d)
+        nlp = spacy.load(d)
+    assert isinstance(nlp, English)
+    assert nlp.pipe_names == ["tok2vec", "tagger"]

View File

@@ -210,7 +210,7 @@ def load_model(
     *,
     vocab: Union["Vocab", bool] = True,
     disable: Iterable[str] = tuple(),
-    component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
+    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
 ) -> "Language":
     """Load a model from a package or data path.
@@ -218,11 +218,11 @@ def load_model(
     vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
         a new Vocab object will be created.
     disable (Iterable[str]): Names of pipeline components to disable.
-    component_cfg (Dict[str, dict]): Config overrides for pipeline components,
-        keyed by component names.
+    config (Dict[str, Any] / Config): Config overrides as nested dict or dict
+        keyed by section values in dot notation.
     RETURNS (Language): The loaded nlp object.
     """
-    kwargs = {"vocab": vocab, "disable": disable, "component_cfg": component_cfg}
+    kwargs = {"vocab": vocab, "disable": disable, "config": config}
     if isinstance(name, str):  # name or string path
         if name.startswith("blank:"):  # shortcut for blank model
             return get_lang_class(name.replace("blank:", ""))()
@@ -240,11 +240,11 @@ def load_model_from_package(
     *,
     vocab: Union["Vocab", bool] = True,
     disable: Iterable[str] = tuple(),
-    component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
+    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
 ) -> "Language":
     """Load a model from an installed package."""
     cls = importlib.import_module(name)
-    return cls.load(vocab=vocab, disable=disable, component_cfg=component_cfg)
+    return cls.load(vocab=vocab, disable=disable, config=config)
 
 
 def load_model_from_path(
@@ -253,7 +253,7 @@ def load_model_from_path(
     meta: Optional[Dict[str, Any]] = None,
     vocab: Union["Vocab", bool] = True,
     disable: Iterable[str] = tuple(),
-    component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
+    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
 ) -> "Language":
     """Load a model from a data directory path. Creates Language class with
     pipeline from config.cfg and then calls from_disk() with path."""
@@ -264,12 +264,8 @@ def load_model_from_path(
     config_path = model_path / "config.cfg"
     if not config_path.exists() or not config_path.is_file():
         raise IOError(Errors.E053.format(path=config_path, name="config.cfg"))
-    config = Config().from_disk(config_path)
-    override_cfg = {"components": {p: dict_to_dot(c) for p, c in component_cfg.items()}}
-    overrides = dict_to_dot(override_cfg)
-    nlp, _ = load_model_from_config(
-        config, vocab=vocab, disable=disable, overrides=overrides
-    )
+    config = Config().from_disk(config_path, overrides=dict_to_dot(config))
+    nlp, _ = load_model_from_config(config, vocab=vocab, disable=disable)
     return nlp.from_disk(model_path, exclude=disable)
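load_model_from_path now flattens the user-supplied `config` dict with dict_to_dot before handing it to Config.from_disk, so nested and dot-notation overrides behave the same. A small sketch of what that flattening is expected to do (illustrative values):

    from spacy.util import dict_to_dot

    print(dict_to_dot({"nlp": {"lang": "de", "pipeline": ["tagger"]}}))
    # {'nlp.lang': 'de', 'nlp.pipeline': ['tagger']}
    print(dict_to_dot({"nlp.lang": "de"}))  # already-flat keys pass through
    # {'nlp.lang': 'de'}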
@@ -278,7 +274,6 @@ def load_model_from_config(
     *,
     vocab: Union["Vocab", bool] = True,
     disable: Iterable[str] = tuple(),
-    overrides: Dict[str, Any] = {},
     auto_fill: bool = False,
     validate: bool = True,
 ) -> Tuple["Language", Config]:
@@ -294,12 +289,7 @@ def load_model_from_config(
     # registry, including custom subclasses provided via entry points
     lang_cls = get_lang_class(nlp_config["lang"])
     nlp = lang_cls.from_config(
-        config,
-        vocab=vocab,
-        disable=disable,
-        overrides=overrides,
-        auto_fill=auto_fill,
-        validate=validate,
+        config, vocab=vocab, disable=disable, auto_fill=auto_fill, validate=validate,
     )
     return nlp, nlp.resolved
@@ -309,14 +299,10 @@ def load_model_from_init_py(
     *,
     vocab: Union["Vocab", bool] = True,
     disable: Iterable[str] = tuple(),
-    component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
+    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
 ) -> "Language":
     """Helper function to use in the `load()` method of a model package's
     __init__.py.
-
-    init_file (str): Path to model's __init__.py, i.e. `__file__`.
-    **overrides: Specific overrides, like pipeline components to disable.
-    RETURNS (Language): `Language` class with loaded model.
     """
     model_path = Path(init_file).parent
     meta = get_model_meta(model_path)
@@ -325,7 +311,7 @@ def load_model_from_init_py(
     if not model_path.exists():
         raise IOError(Errors.E052.format(path=data_path))
     return load_model_from_path(
-        data_path, vocab=vocab, meta=meta, disable=disable, component_cfg=component_cfg
+        data_path, vocab=vocab, meta=meta, disable=disable, config=config
     )

View File

@@ -32,13 +32,13 @@ loaded in via [`Language.from_disk`](/api/language#from_disk).
 > nlp = spacy.load("en_core_web_sm", disable=["parser", "tagger"])
 > ```
 
 | Name | Type | Description |
 | --- | --- | --- |
 | `name` | str / `Path` | Model to load, i.e. package name or path. |
 | _keyword-only_ | | |
 | `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
-| `component_cfg` <Tag variant="new">3</Tag> | `Dict[str, dict]` | Optional config overrides for pipeline components, keyed by component names. |
+| `config` <Tag variant="new">3</Tag> | `Dict[str, Any]` / [`Config`](https://thinc.ai/docs/api-config#config) | Optional config overrides, either as nested dict or dict keyed by section value in dot notation, e.g. `"components.name.value"`. |
 | **RETURNS** | `Language` | A `Language` object with the loaded model. |
 
 Essentially, `spacy.load()` is a convenience wrapper that reads the language ID
 and pipeline components from a model's `meta.json`, initializes the `Language`