Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-25 09:26:27 +03:00)
Simplify config overrides in CLI and deserialization (#5880)
This commit is contained in:
parent 50311a4d37 · commit 5cc0d89fad
@@ -8,6 +8,7 @@ warnings.filterwarnings("ignore", message="numpy.ufunc size changed")  # noqa
 # These are imported as part of the API
 from thinc.api import prefer_gpu, require_gpu  # noqa: F401
+from thinc.api import Config
 
 from . import pipeline  # noqa: F401
 from .cli.info import info  # noqa: F401
@@ -26,17 +27,17 @@ if sys.maxunicode == 65535:
 def load(
     name: Union[str, Path],
     disable: Iterable[str] = tuple(),
-    component_cfg: Dict[str, Dict[str, Any]] = util.SimpleFrozenDict(),
+    config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
 ) -> Language:
     """Load a spaCy model from an installed package or a local path.
 
     name (str): Package name or model path.
     disable (Iterable[str]): Names of pipeline components to disable.
-    component_cfg (Dict[str, dict]): Config overrides for pipeline components,
-        keyed by component names.
+    config (Dict[str, Any] / Config): Config overrides as nested dict or dict
+        keyed by section values in dot notation.
     RETURNS (Language): The loaded nlp object.
     """
-    return util.load_model(name, disable=disable, component_cfg=component_cfg)
+    return util.load_model(name, disable=disable, config=config)
 
 
 def blank(name: str, **overrides) -> Language:
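In practice, the reworked `config` argument accepts either a nested dict or a flat dict with dot-notation keys. A minimal sketch of the intended call pattern; the package name, component name and setting are placeholders, not taken from this diff:

```python
import spacy

# Equivalent ways to express the same override: nested vs. dot notation
nested = {"components": {"my_component": {"some_setting": True}}}
dotted = {"components.my_component.some_setting": True}

# Hypothetical package name; an installed package or a local path works the same way
nlp = spacy.load("some_model", config=nested)
nlp = spacy.load("some_model", config=dotted)
```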
@@ -49,11 +49,9 @@ def debug_config_cli(
     overrides = parse_config_overrides(ctx.args)
     import_code(code_path)
     with show_validation_error(config_path):
-        config = Config().from_disk(config_path)
+        config = Config().from_disk(config_path, overrides=overrides)
     try:
-        nlp, _ = util.load_model_from_config(
-            config, overrides=overrides, auto_fill=auto_fill
-        )
+        nlp, _ = util.load_model_from_config(config, auto_fill=auto_fill)
     except ValueError as e:
         msg.fail(str(e), exits=1)
     if auto_fill:
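The CLI commands now hand the parsed overrides straight to Thinc's `Config.from_disk` instead of threading them through `load_model_from_config`. A small sketch of that pattern, with a placeholder path and illustrative keys:

```python
from thinc.api import Config

# Dot-notation overrides, e.g. as produced by parse_config_overrides
# from extra CLI arguments (values here are illustrative)
overrides = {"training.seed": 1, "nlp.lang": "de"}
config = Config().from_disk("path/to/config.cfg", overrides=overrides)
```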
@@ -136,8 +134,8 @@ def debug_data(
     if not config_path.exists():
         msg.fail("Config file not found", config_path, exists=1)
     with show_validation_error(config_path):
-        cfg = Config().from_disk(config_path)
-    nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
+        cfg = Config().from_disk(config_path, overrides=config_overrides)
+    nlp, config = util.load_model_from_config(cfg)
     # Use original config here, not resolved version
     sourced_components = get_sourced_components(cfg)
     frozen_components = config["training"]["frozen_components"]

@@ -49,9 +49,9 @@ def debug_model_cli(
     }
     config_overrides = parse_config_overrides(ctx.args)
     with show_validation_error(config_path):
-        cfg = Config().from_disk(config_path)
+        cfg = Config().from_disk(config_path, overrides=config_overrides)
     try:
-        nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
+        nlp, config = util.load_model_from_config(cfg)
     except ValueError as e:
         msg.fail(str(e), exits=1)
     seed = config.get("training", {}).get("seed", None)

@@ -88,8 +88,8 @@ def pretrain(
     msg.info("Using CPU")
     msg.info(f"Loading config from: {config_path}")
     with show_validation_error(config_path):
-        config = Config().from_disk(config_path)
-    nlp, config = util.load_model_from_config(config, overrides=config_overrides)
+        config = Config().from_disk(config_path, overrides=config_overrides)
+    nlp, config = util.load_model_from_config(config)
     # TODO: validate that [pretraining] block exists
     if not output_dir.exists():
         output_dir.mkdir()
@@ -75,13 +75,13 @@ def train(
     msg.info("Using CPU")
     msg.info(f"Loading config and nlp from: {config_path}")
     with show_validation_error(config_path):
-        config = Config().from_disk(config_path)
+        config = Config().from_disk(config_path, overrides=config_overrides)
     if config.get("training", {}).get("seed") is not None:
         fix_random_seed(config["training"]["seed"])
     # Use original config here before it's resolved to functions
     sourced_components = get_sourced_components(config)
     with show_validation_error(config_path):
-        nlp, config = util.load_model_from_config(config, overrides=config_overrides)
+        nlp, config = util.load_model_from_config(config)
     if config["training"]["vectors"] is not None:
         util.load_vectors_into_model(nlp, config["training"]["vectors"])
     verify_config(nlp)
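The dot-notation overrides themselves come from extra arguments on the command line: `parse_config_overrides` turns trailing `--section.key value` pairs into a flat dict. A simplified sketch of that behavior, not the actual implementation (which also interprets value types):

```python
def parse_overrides_sketch(args):
    # e.g. ["--training.seed", "1", "--nlp.lang", "de"]
    overrides = {}
    for key, value in zip(args[::2], args[1::2]):
        overrides[key.lstrip("-")] = value
    return overrides

parse_overrides_sketch(["--training.seed", "1", "--nlp.lang", "de"])
# -> {"training.seed": "1", "nlp.lang": "de"}
```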
@@ -144,7 +144,7 @@ def train(
         max_steps=T_cfg["max_steps"],
         eval_frequency=T_cfg["eval_frequency"],
         raw_text=None,
-        exclude=frozen_components
+        exclude=frozen_components,
     )
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
     print_row = setup_printer(T_cfg, nlp)
@@ -558,7 +558,6 @@ class Language:
         name: Optional[str] = None,
         *,
         config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
-        overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
         validate: bool = True,
     ) -> Callable[[Doc], Doc]:
         """Create a pipeline component. Mostly used internally. To create and

@@ -569,8 +568,6 @@ class Language:
             Defaults to factory name if not set.
         config (Optional[Dict[str, Any]]): Config parameters to use for this
             component. Will be merged with default config, if available.
-        overrides (Optional[Dict[str, Any]]): Config overrides, typically
-            passed in via the CLI.
         validate (bool): Whether to validate the component config against the
             arguments and types expected by the factory.
         RETURNS (Callable[[Doc], Doc]): The pipeline component.

@@ -613,7 +610,7 @@ class Language:
         # registered functions twice
         # TODO: customize validation to make it more readable / relate it to
         # pipeline component and why it failed, explain default config
-        resolved, filled = registry.resolve(cfg, validate=validate, overrides=overrides)
+        resolved, filled = registry.resolve(cfg, validate=validate)
         filled = filled[factory_name]
         filled["factory"] = factory_name
         filled.pop("@factories", None)

@@ -657,7 +654,6 @@ class Language:
         last: Optional[bool] = None,
         source: Optional["Language"] = None,
         config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
-        overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
         validate: bool = True,
     ) -> Callable[[Doc], Doc]:
         """Add a component to the processing pipeline. Valid components are

@@ -679,8 +675,6 @@ class Language:
             component from.
         config (Optional[Dict[str, Any]]): Config parameters to use for this
             component. Will be merged with default config, if available.
-        overrides (Optional[Dict[str, Any]]): Config overrides, typically
-            passed in via the CLI.
         validate (bool): Whether to validate the component config against the
             arguments and types expected by the factory.
         RETURNS (Callable[[Doc], Doc]): The pipeline component.
@@ -710,11 +704,7 @@ class Language:
                 lang_code=self.lang,
             )
         pipe_component = self.create_pipe(
-            factory_name,
-            name=name,
-            config=config,
-            overrides=overrides,
-            validate=validate,
+            factory_name, name=name, config=config, validate=validate,
         )
         pipe_index = self._get_pipe_index(before, after, first, last)
         self._pipe_meta[name] = self.get_factory_meta(factory_name)
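With the separate `overrides` parameter gone, per-component settings enter `create_pipe` and `add_pipe` only through `config`, whether a component is added in code or overridden at load time. A brief usage sketch; the component and setting shown are just an illustration:

```python
import spacy

nlp = spacy.blank("en")
# Component settings are passed through `config` only
nlp.add_pipe("sentencizer", config={"punct_chars": [".", "!", "?"]})
```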
@@ -1416,7 +1406,6 @@ class Language:
         *,
         vocab: Union[Vocab, bool] = True,
         disable: Iterable[str] = tuple(),
-        overrides: Dict[str, Any] = {},
         auto_fill: bool = True,
         validate: bool = True,
     ) -> "Language":

@@ -1456,9 +1445,8 @@ class Language:
         config = util.copy_config(config)
         orig_pipeline = config.pop("components", {})
         config["components"] = {}
-        non_pipe_overrides, pipe_overrides = _get_config_overrides(overrides)
         resolved, filled = registry.resolve(
-            config, validate=validate, schema=ConfigSchema, overrides=non_pipe_overrides
+            config, validate=validate, schema=ConfigSchema
         )
         filled["components"] = orig_pipeline
         config["components"] = orig_pipeline

@@ -1507,11 +1495,7 @@ class Language:
                 # The pipe name (key in the config) here is the unique name
                 # of the component, not necessarily the factory
                 nlp.add_pipe(
-                    factory,
-                    name=pipe_name,
-                    config=pipe_cfg,
-                    overrides=pipe_overrides,
-                    validate=validate,
+                    factory, name=pipe_name, config=pipe_cfg, validate=validate,
                 )
             else:
                 model = pipe_cfg["source"]
@@ -1696,15 +1680,6 @@ class FactoryMeta:
     default_score_weights: Optional[Dict[str, float]] = None  # noqa: E704
 
 
-def _get_config_overrides(
-    items: Dict[str, Any], prefix: str = "components"
-) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-    prefix = f"{prefix}."
-    non_pipe = {k: v for k, v in items.items() if not k.startswith(prefix)}
-    pipe = {k.replace(prefix, ""): v for k, v in items.items() if k.startswith(prefix)}
-    return non_pipe, pipe
-
-
 def _fix_pretrained_vectors_name(nlp: Language) -> None:
     # TODO: Replace this once we handle vectors consistently as static
     # data
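For reference, the deleted helper split a flat dot-notation overrides dict into component and non-component keys before resolution; since overrides are now applied when the config is loaded, that split is no longer needed. Its old behavior, shown with illustrative input:

```python
# What the removed _get_config_overrides did, on made-up keys
items = {"components.tagger.some_setting": 1, "training.seed": 0}
prefix = "components."
non_pipe = {k: v for k, v in items.items() if not k.startswith(prefix)}
pipe = {k.replace(prefix, ""): v for k, v in items.items() if k.startswith(prefix)}
# non_pipe == {"training.seed": 0}
# pipe == {"tagger.some_setting": 1}
```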
@@ -27,6 +27,6 @@ def test_issue5137():
 
     with make_tempdir() as tmpdir:
         nlp.to_disk(tmpdir)
-        overrides = {"my_component": {"categories": "my_categories"}}
-        nlp2 = spacy.load(tmpdir, component_cfg=overrides)
+        overrides = {"components": {"my_component": {"categories": "my_categories"}}}
+        nlp2 = spacy.load(tmpdir, config=overrides)
         assert nlp2.get_pipe("my_component").categories == "my_categories"

@@ -2,6 +2,7 @@ import pytest
 from thinc.config import Config, ConfigValidationError
 import spacy
 from spacy.lang.en import English
+from spacy.lang.de import German
 from spacy.language import Language
 from spacy.util import registry, deep_merge_configs, load_model_from_config
 from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model

@@ -282,3 +283,33 @@ def test_serialize_config_missing_pipes():
     assert "tok2vec" not in config["components"]
     with pytest.raises(ValueError):
         load_model_from_config(config, auto_fill=True)
+
+
+def test_config_overrides():
+    overrides_nested = {"nlp": {"lang": "de", "pipeline": ["tagger"]}}
+    overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}
+    # load_model from config with overrides passed directly to Config
+    config = Config().from_str(nlp_config_string, overrides=overrides_dot)
+    nlp, _ = load_model_from_config(config, auto_fill=True)
+    assert isinstance(nlp, German)
+    assert nlp.pipe_names == ["tagger"]
+    # Serialized roundtrip with config passed in
+    base_config = Config().from_str(nlp_config_string)
+    base_nlp, _ = load_model_from_config(base_config, auto_fill=True)
+    assert isinstance(base_nlp, English)
+    assert base_nlp.pipe_names == ["tok2vec", "tagger"]
+    with make_tempdir() as d:
+        base_nlp.to_disk(d)
+        nlp = spacy.load(d, config=overrides_nested)
+        assert isinstance(nlp, German)
+        assert nlp.pipe_names == ["tagger"]
+    with make_tempdir() as d:
+        base_nlp.to_disk(d)
+        nlp = spacy.load(d, config=overrides_dot)
+        assert isinstance(nlp, German)
+        assert nlp.pipe_names == ["tagger"]
+    with make_tempdir() as d:
+        base_nlp.to_disk(d)
+        nlp = spacy.load(d)
+        assert isinstance(nlp, English)
+        assert nlp.pipe_names == ["tok2vec", "tagger"]
@@ -210,7 +210,7 @@ def load_model(
     *,
     vocab: Union["Vocab", bool] = True,
     disable: Iterable[str] = tuple(),
-    component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
+    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
 ) -> "Language":
     """Load a model from a package or data path.
 

@@ -218,11 +218,11 @@ def load_model(
     vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
         a new Vocab object will be created.
     disable (Iterable[str]): Names of pipeline components to disable.
-    component_cfg (Dict[str, dict]): Config overrides for pipeline components,
-        keyed by component names.
+    config (Dict[str, Any] / Config): Config overrides as nested dict or dict
+        keyed by section values in dot notation.
     RETURNS (Language): The loaded nlp object.
     """
-    kwargs = {"vocab": vocab, "disable": disable, "component_cfg": component_cfg}
+    kwargs = {"vocab": vocab, "disable": disable, "config": config}
     if isinstance(name, str):  # name or string path
         if name.startswith("blank:"):  # shortcut for blank model
             return get_lang_class(name.replace("blank:", ""))()

@@ -240,11 +240,11 @@ def load_model_from_package(
     *,
     vocab: Union["Vocab", bool] = True,
     disable: Iterable[str] = tuple(),
-    component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
+    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
 ) -> "Language":
     """Load a model from an installed package."""
     cls = importlib.import_module(name)
-    return cls.load(vocab=vocab, disable=disable, component_cfg=component_cfg)
+    return cls.load(vocab=vocab, disable=disable, config=config)
 
 
 def load_model_from_path(

@@ -253,7 +253,7 @@ def load_model_from_path(
     meta: Optional[Dict[str, Any]] = None,
     vocab: Union["Vocab", bool] = True,
     disable: Iterable[str] = tuple(),
-    component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
+    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
 ) -> "Language":
     """Load a model from a data directory path. Creates Language class with
     pipeline from config.cfg and then calls from_disk() with path."""
@@ -264,12 +264,8 @@ def load_model_from_path(
     config_path = model_path / "config.cfg"
     if not config_path.exists() or not config_path.is_file():
         raise IOError(Errors.E053.format(path=config_path, name="config.cfg"))
-    config = Config().from_disk(config_path)
-    override_cfg = {"components": {p: dict_to_dot(c) for p, c in component_cfg.items()}}
-    overrides = dict_to_dot(override_cfg)
-    nlp, _ = load_model_from_config(
-        config, vocab=vocab, disable=disable, overrides=overrides
-    )
+    config = Config().from_disk(config_path, overrides=dict_to_dot(config))
+    nlp, _ = load_model_from_config(config, vocab=vocab, disable=disable)
     return nlp.from_disk(model_path, exclude=disable)
 
 
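In the new `load_model_from_path`, the user-supplied `config` dict (which may be nested) is flattened with `dict_to_dot` before being applied as overrides. A quick illustration, reusing the override from the regression test above:

```python
from spacy.util import dict_to_dot

nested = {"components": {"my_component": {"categories": "my_categories"}}}
dict_to_dot(nested)
# -> {"components.my_component.categories": "my_categories"}
```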
@@ -278,7 +274,6 @@ def load_model_from_config(
     *,
     vocab: Union["Vocab", bool] = True,
     disable: Iterable[str] = tuple(),
-    overrides: Dict[str, Any] = {},
     auto_fill: bool = False,
     validate: bool = True,
 ) -> Tuple["Language", Config]:

@@ -294,12 +289,7 @@ def load_model_from_config(
     # registry, including custom subclasses provided via entry points
     lang_cls = get_lang_class(nlp_config["lang"])
     nlp = lang_cls.from_config(
-        config,
-        vocab=vocab,
-        disable=disable,
-        overrides=overrides,
-        auto_fill=auto_fill,
-        validate=validate,
+        config, vocab=vocab, disable=disable, auto_fill=auto_fill, validate=validate,
     )
     return nlp, nlp.resolved
 

@@ -309,14 +299,10 @@ def load_model_from_init_py(
     *,
     vocab: Union["Vocab", bool] = True,
     disable: Iterable[str] = tuple(),
-    component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
+    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
 ) -> "Language":
     """Helper function to use in the `load()` method of a model package's
     __init__.py.
-
-    init_file (str): Path to model's __init__.py, i.e. `__file__`.
-    **overrides: Specific overrides, like pipeline components to disable.
-    RETURNS (Language): `Language` class with loaded model.
     """
     model_path = Path(init_file).parent
     meta = get_model_meta(model_path)

@@ -325,7 +311,7 @@ def load_model_from_init_py(
     if not model_path.exists():
         raise IOError(Errors.E052.format(path=data_path))
     return load_model_from_path(
-        data_path, vocab=vocab, meta=meta, disable=disable, component_cfg=component_cfg
+        data_path, vocab=vocab, meta=meta, disable=disable, config=config
    )
 
 
@@ -32,13 +32,13 @@ loaded in via [`Language.from_disk`](/api/language#from_disk).
 > nlp = spacy.load("en_core_web_sm", disable=["parser", "tagger"])
 > ```
 
 | Name | Type | Description |
-| ------------------------------------------ | ----------------- | --------------------------------------------------------------------------------- |
+| ----------------------------------- | ---------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------- |
 | `name` | str / `Path` | Model to load, i.e. package name or path. |
 | _keyword-only_ | | |
 | `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
-| `component_cfg` <Tag variant="new">3</Tag> | `Dict[str, dict]` | Optional config overrides for pipeline components, keyed by component names. |
+| `config` <Tag variant="new">3</Tag> | `Dict[str, Any]` / [`Config`](https://thinc.ai/docs/api-config#config) | Optional config overrides, either as nested dict or dict keyed by section value in dot notation, e.g. `"components.name.value"`. |
 | **RETURNS** | `Language` | A `Language` object with the loaded model. |
 
 Essentially, `spacy.load()` is a convenience wrapper that reads the language ID
 and pipeline components from a model's `meta.json`, initializes the `Language`
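As the updated table notes, the `config` argument also accepts a full Thinc `Config` object in place of a plain dict. A short sketch under that assumption; the overridden value is illustrative and a no-op for an English model:

```python
import spacy
from thinc.api import Config

overrides = Config({"nlp": {"lang": "en"}})
nlp = spacy.load("en_core_web_sm", config=overrides)
```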