mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Simplify config overrides in CLI and deserialization (#5880)
This commit is contained in:
parent
50311a4d37
commit
5cc0d89fad
|
@ -8,6 +8,7 @@ warnings.filterwarnings("ignore", message="numpy.ufunc size changed") # noqa
|
|||
|
||||
# These are imported as part of the API
|
||||
from thinc.api import prefer_gpu, require_gpu # noqa: F401
|
||||
from thinc.api import Config
|
||||
|
||||
from . import pipeline # noqa: F401
|
||||
from .cli.info import info # noqa: F401
|
||||
|
@ -26,17 +27,17 @@ if sys.maxunicode == 65535:
|
|||
def load(
|
||||
name: Union[str, Path],
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = util.SimpleFrozenDict(),
|
||||
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
|
||||
) -> Language:
|
||||
"""Load a spaCy model from an installed package or a local path.
|
||||
|
||||
name (str): Package name or model path.
|
||||
disable (Iterable[str]): Names of pipeline components to disable.
|
||||
component_cfg (Dict[str, dict]): Config overrides for pipeline components,
|
||||
keyed by component names.
|
||||
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
|
||||
keyed by section values in dot notation.
|
||||
RETURNS (Language): The loaded nlp object.
|
||||
"""
|
||||
return util.load_model(name, disable=disable, component_cfg=component_cfg)
|
||||
return util.load_model(name, disable=disable, config=config)
|
||||
|
||||
|
||||
def blank(name: str, **overrides) -> Language:
|
||||
|
|
|
@ -49,11 +49,9 @@ def debug_config_cli(
|
|||
overrides = parse_config_overrides(ctx.args)
|
||||
import_code(code_path)
|
||||
with show_validation_error(config_path):
|
||||
config = Config().from_disk(config_path)
|
||||
config = Config().from_disk(config_path, overrides=overrides)
|
||||
try:
|
||||
nlp, _ = util.load_model_from_config(
|
||||
config, overrides=overrides, auto_fill=auto_fill
|
||||
)
|
||||
nlp, _ = util.load_model_from_config(config, auto_fill=auto_fill)
|
||||
except ValueError as e:
|
||||
msg.fail(str(e), exits=1)
|
||||
if auto_fill:
|
||||
|
@ -136,8 +134,8 @@ def debug_data(
|
|||
if not config_path.exists():
|
||||
msg.fail("Config file not found", config_path, exists=1)
|
||||
with show_validation_error(config_path):
|
||||
cfg = Config().from_disk(config_path)
|
||||
nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
|
||||
cfg = Config().from_disk(config_path, overrides=config_overrides)
|
||||
nlp, config = util.load_model_from_config(cfg)
|
||||
# Use original config here, not resolved version
|
||||
sourced_components = get_sourced_components(cfg)
|
||||
frozen_components = config["training"]["frozen_components"]
|
||||
|
|
|
@ -49,9 +49,9 @@ def debug_model_cli(
|
|||
}
|
||||
config_overrides = parse_config_overrides(ctx.args)
|
||||
with show_validation_error(config_path):
|
||||
cfg = Config().from_disk(config_path)
|
||||
cfg = Config().from_disk(config_path, overrides=config_overrides)
|
||||
try:
|
||||
nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
|
||||
nlp, config = util.load_model_from_config(cfg)
|
||||
except ValueError as e:
|
||||
msg.fail(str(e), exits=1)
|
||||
seed = config.get("training", {}).get("seed", None)
|
||||
|
|
|
@ -88,8 +88,8 @@ def pretrain(
|
|||
msg.info("Using CPU")
|
||||
msg.info(f"Loading config from: {config_path}")
|
||||
with show_validation_error(config_path):
|
||||
config = Config().from_disk(config_path)
|
||||
nlp, config = util.load_model_from_config(config, overrides=config_overrides)
|
||||
config = Config().from_disk(config_path, overrides=config_overrides)
|
||||
nlp, config = util.load_model_from_config(config)
|
||||
# TODO: validate that [pretraining] block exists
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir()
|
||||
|
|
|
@ -75,13 +75,13 @@ def train(
|
|||
msg.info("Using CPU")
|
||||
msg.info(f"Loading config and nlp from: {config_path}")
|
||||
with show_validation_error(config_path):
|
||||
config = Config().from_disk(config_path)
|
||||
config = Config().from_disk(config_path, overrides=config_overrides)
|
||||
if config.get("training", {}).get("seed") is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
# Use original config here before it's resolved to functions
|
||||
sourced_components = get_sourced_components(config)
|
||||
with show_validation_error(config_path):
|
||||
nlp, config = util.load_model_from_config(config, overrides=config_overrides)
|
||||
nlp, config = util.load_model_from_config(config)
|
||||
if config["training"]["vectors"] is not None:
|
||||
util.load_vectors_into_model(nlp, config["training"]["vectors"])
|
||||
verify_config(nlp)
|
||||
|
@ -144,7 +144,7 @@ def train(
|
|||
max_steps=T_cfg["max_steps"],
|
||||
eval_frequency=T_cfg["eval_frequency"],
|
||||
raw_text=None,
|
||||
exclude=frozen_components
|
||||
exclude=frozen_components,
|
||||
)
|
||||
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
||||
print_row = setup_printer(T_cfg, nlp)
|
||||
|
|
|
@ -558,7 +558,6 @@ class Language:
|
|||
name: Optional[str] = None,
|
||||
*,
|
||||
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||
validate: bool = True,
|
||||
) -> Callable[[Doc], Doc]:
|
||||
"""Create a pipeline component. Mostly used internally. To create and
|
||||
|
@ -569,8 +568,6 @@ class Language:
|
|||
Defaults to factory name if not set.
|
||||
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
||||
component. Will be merged with default config, if available.
|
||||
overrides (Optional[Dict[str, Any]]): Config overrides, typically
|
||||
passed in via the CLI.
|
||||
validate (bool): Whether to validate the component config against the
|
||||
arguments and types expected by the factory.
|
||||
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||
|
@ -613,7 +610,7 @@ class Language:
|
|||
# registered functions twice
|
||||
# TODO: customize validation to make it more readable / relate it to
|
||||
# pipeline component and why it failed, explain default config
|
||||
resolved, filled = registry.resolve(cfg, validate=validate, overrides=overrides)
|
||||
resolved, filled = registry.resolve(cfg, validate=validate)
|
||||
filled = filled[factory_name]
|
||||
filled["factory"] = factory_name
|
||||
filled.pop("@factories", None)
|
||||
|
@ -657,7 +654,6 @@ class Language:
|
|||
last: Optional[bool] = None,
|
||||
source: Optional["Language"] = None,
|
||||
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||
validate: bool = True,
|
||||
) -> Callable[[Doc], Doc]:
|
||||
"""Add a component to the processing pipeline. Valid components are
|
||||
|
@ -679,8 +675,6 @@ class Language:
|
|||
component from.
|
||||
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
||||
component. Will be merged with default config, if available.
|
||||
overrides (Optional[Dict[str, Any]]): Config overrides, typically
|
||||
passed in via the CLI.
|
||||
validate (bool): Whether to validate the component config against the
|
||||
arguments and types expected by the factory.
|
||||
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||
|
@ -710,11 +704,7 @@ class Language:
|
|||
lang_code=self.lang,
|
||||
)
|
||||
pipe_component = self.create_pipe(
|
||||
factory_name,
|
||||
name=name,
|
||||
config=config,
|
||||
overrides=overrides,
|
||||
validate=validate,
|
||||
factory_name, name=name, config=config, validate=validate,
|
||||
)
|
||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
||||
|
@ -1416,7 +1406,6 @@ class Language:
|
|||
*,
|
||||
vocab: Union[Vocab, bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
overrides: Dict[str, Any] = {},
|
||||
auto_fill: bool = True,
|
||||
validate: bool = True,
|
||||
) -> "Language":
|
||||
|
@ -1456,9 +1445,8 @@ class Language:
|
|||
config = util.copy_config(config)
|
||||
orig_pipeline = config.pop("components", {})
|
||||
config["components"] = {}
|
||||
non_pipe_overrides, pipe_overrides = _get_config_overrides(overrides)
|
||||
resolved, filled = registry.resolve(
|
||||
config, validate=validate, schema=ConfigSchema, overrides=non_pipe_overrides
|
||||
config, validate=validate, schema=ConfigSchema
|
||||
)
|
||||
filled["components"] = orig_pipeline
|
||||
config["components"] = orig_pipeline
|
||||
|
@ -1507,11 +1495,7 @@ class Language:
|
|||
# The pipe name (key in the config) here is the unique name
|
||||
# of the component, not necessarily the factory
|
||||
nlp.add_pipe(
|
||||
factory,
|
||||
name=pipe_name,
|
||||
config=pipe_cfg,
|
||||
overrides=pipe_overrides,
|
||||
validate=validate,
|
||||
factory, name=pipe_name, config=pipe_cfg, validate=validate,
|
||||
)
|
||||
else:
|
||||
model = pipe_cfg["source"]
|
||||
|
@ -1696,15 +1680,6 @@ class FactoryMeta:
|
|||
default_score_weights: Optional[Dict[str, float]] = None # noqa: E704
|
||||
|
||||
|
||||
def _get_config_overrides(
|
||||
items: Dict[str, Any], prefix: str = "components"
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
prefix = f"{prefix}."
|
||||
non_pipe = {k: v for k, v in items.items() if not k.startswith(prefix)}
|
||||
pipe = {k.replace(prefix, ""): v for k, v in items.items() if k.startswith(prefix)}
|
||||
return non_pipe, pipe
|
||||
|
||||
|
||||
def _fix_pretrained_vectors_name(nlp: Language) -> None:
|
||||
# TODO: Replace this once we handle vectors consistently as static
|
||||
# data
|
||||
|
|
|
@ -27,6 +27,6 @@ def test_issue5137():
|
|||
|
||||
with make_tempdir() as tmpdir:
|
||||
nlp.to_disk(tmpdir)
|
||||
overrides = {"my_component": {"categories": "my_categories"}}
|
||||
nlp2 = spacy.load(tmpdir, component_cfg=overrides)
|
||||
overrides = {"components": {"my_component": {"categories": "my_categories"}}}
|
||||
nlp2 = spacy.load(tmpdir, config=overrides)
|
||||
assert nlp2.get_pipe("my_component").categories == "my_categories"
|
||||
|
|
|
@ -2,6 +2,7 @@ import pytest
|
|||
from thinc.config import Config, ConfigValidationError
|
||||
import spacy
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.de import German
|
||||
from spacy.language import Language
|
||||
from spacy.util import registry, deep_merge_configs, load_model_from_config
|
||||
from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
|
||||
|
@ -282,3 +283,33 @@ def test_serialize_config_missing_pipes():
|
|||
assert "tok2vec" not in config["components"]
|
||||
with pytest.raises(ValueError):
|
||||
load_model_from_config(config, auto_fill=True)
|
||||
|
||||
|
||||
def test_config_overrides():
|
||||
overrides_nested = {"nlp": {"lang": "de", "pipeline": ["tagger"]}}
|
||||
overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}
|
||||
# load_model from config with overrides passed directly to Config
|
||||
config = Config().from_str(nlp_config_string, overrides=overrides_dot)
|
||||
nlp, _ = load_model_from_config(config, auto_fill=True)
|
||||
assert isinstance(nlp, German)
|
||||
assert nlp.pipe_names == ["tagger"]
|
||||
# Serialized roundtrip with config passed in
|
||||
base_config = Config().from_str(nlp_config_string)
|
||||
base_nlp, _ = load_model_from_config(base_config, auto_fill=True)
|
||||
assert isinstance(base_nlp, English)
|
||||
assert base_nlp.pipe_names == ["tok2vec", "tagger"]
|
||||
with make_tempdir() as d:
|
||||
base_nlp.to_disk(d)
|
||||
nlp = spacy.load(d, config=overrides_nested)
|
||||
assert isinstance(nlp, German)
|
||||
assert nlp.pipe_names == ["tagger"]
|
||||
with make_tempdir() as d:
|
||||
base_nlp.to_disk(d)
|
||||
nlp = spacy.load(d, config=overrides_dot)
|
||||
assert isinstance(nlp, German)
|
||||
assert nlp.pipe_names == ["tagger"]
|
||||
with make_tempdir() as d:
|
||||
base_nlp.to_disk(d)
|
||||
nlp = spacy.load(d)
|
||||
assert isinstance(nlp, English)
|
||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||
|
|
|
@ -210,7 +210,7 @@ def load_model(
|
|||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
"""Load a model from a package or data path.
|
||||
|
||||
|
@ -218,11 +218,11 @@ def load_model(
|
|||
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
|
||||
a new Vocab object will be created.
|
||||
disable (Iterable[str]): Names of pipeline components to disable.
|
||||
component_cfg (Dict[str, dict]): Config overrides for pipeline components,
|
||||
keyed by component names.
|
||||
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
|
||||
keyed by section values in dot notation.
|
||||
RETURNS (Language): The loaded nlp object.
|
||||
"""
|
||||
kwargs = {"vocab": vocab, "disable": disable, "component_cfg": component_cfg}
|
||||
kwargs = {"vocab": vocab, "disable": disable, "config": config}
|
||||
if isinstance(name, str): # name or string path
|
||||
if name.startswith("blank:"): # shortcut for blank model
|
||||
return get_lang_class(name.replace("blank:", ""))()
|
||||
|
@ -240,11 +240,11 @@ def load_model_from_package(
|
|||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
"""Load a model from an installed package."""
|
||||
cls = importlib.import_module(name)
|
||||
return cls.load(vocab=vocab, disable=disable, component_cfg=component_cfg)
|
||||
return cls.load(vocab=vocab, disable=disable, config=config)
|
||||
|
||||
|
||||
def load_model_from_path(
|
||||
|
@ -253,7 +253,7 @@ def load_model_from_path(
|
|||
meta: Optional[Dict[str, Any]] = None,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
"""Load a model from a data directory path. Creates Language class with
|
||||
pipeline from config.cfg and then calls from_disk() with path."""
|
||||
|
@ -264,12 +264,8 @@ def load_model_from_path(
|
|||
config_path = model_path / "config.cfg"
|
||||
if not config_path.exists() or not config_path.is_file():
|
||||
raise IOError(Errors.E053.format(path=config_path, name="config.cfg"))
|
||||
config = Config().from_disk(config_path)
|
||||
override_cfg = {"components": {p: dict_to_dot(c) for p, c in component_cfg.items()}}
|
||||
overrides = dict_to_dot(override_cfg)
|
||||
nlp, _ = load_model_from_config(
|
||||
config, vocab=vocab, disable=disable, overrides=overrides
|
||||
)
|
||||
config = Config().from_disk(config_path, overrides=dict_to_dot(config))
|
||||
nlp, _ = load_model_from_config(config, vocab=vocab, disable=disable)
|
||||
return nlp.from_disk(model_path, exclude=disable)
|
||||
|
||||
|
||||
|
@ -278,7 +274,6 @@ def load_model_from_config(
|
|||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
overrides: Dict[str, Any] = {},
|
||||
auto_fill: bool = False,
|
||||
validate: bool = True,
|
||||
) -> Tuple["Language", Config]:
|
||||
|
@ -294,12 +289,7 @@ def load_model_from_config(
|
|||
# registry, including custom subclasses provided via entry points
|
||||
lang_cls = get_lang_class(nlp_config["lang"])
|
||||
nlp = lang_cls.from_config(
|
||||
config,
|
||||
vocab=vocab,
|
||||
disable=disable,
|
||||
overrides=overrides,
|
||||
auto_fill=auto_fill,
|
||||
validate=validate,
|
||||
config, vocab=vocab, disable=disable, auto_fill=auto_fill, validate=validate,
|
||||
)
|
||||
return nlp, nlp.resolved
|
||||
|
||||
|
@ -309,14 +299,10 @@ def load_model_from_init_py(
|
|||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
"""Helper function to use in the `load()` method of a model package's
|
||||
__init__.py.
|
||||
|
||||
init_file (str): Path to model's __init__.py, i.e. `__file__`.
|
||||
**overrides: Specific overrides, like pipeline components to disable.
|
||||
RETURNS (Language): `Language` class with loaded model.
|
||||
"""
|
||||
model_path = Path(init_file).parent
|
||||
meta = get_model_meta(model_path)
|
||||
|
@ -325,7 +311,7 @@ def load_model_from_init_py(
|
|||
if not model_path.exists():
|
||||
raise IOError(Errors.E052.format(path=data_path))
|
||||
return load_model_from_path(
|
||||
data_path, vocab=vocab, meta=meta, disable=disable, component_cfg=component_cfg
|
||||
data_path, vocab=vocab, meta=meta, disable=disable, config=config
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -33,11 +33,11 @@ loaded in via [`Language.from_disk`](/api/language#from_disk).
|
|||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------------------------------------ | ----------------- | --------------------------------------------------------------------------------- |
|
||||
| ----------------------------------- | ---------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `name` | str / `Path` | Model to load, i.e. package name or path. |
|
||||
| _keyword-only_ | | |
|
||||
| `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
|
||||
| `component_cfg` <Tag variant="new">3</Tag> | `Dict[str, dict]` | Optional config overrides for pipeline components, keyed by component names. |
|
||||
| `config` <Tag variant="new">3</Tag> | `Dict[str, Any]` / [`Config`](https://thinc.ai/docs/api-config#config) | Optional config overrides, either as nested dict or dict keyed by section value in dot notation, e.g. `"components.name.value"`. |
|
||||
| **RETURNS** | `Language` | A `Language` object with the loaded model. |
|
||||
|
||||
Essentially, `spacy.load()` is a convenience wrapper that reads the language ID
|
||||
|
|
Loading…
Reference in New Issue
Block a user