Allow adding pipeline components from source model (#5857)

* Allow adding pipeline components from source model

* Config: name -> component

* Improve error messages

* Fix error and test

* Add frozen components and exclude logic

* Remove exclude from Language.evaluate

* Init sourced components with current vocab

* Fix error codes
This commit is contained in:
Ines Montani 2020-08-04 23:39:19 +02:00 committed by GitHub
parent 34873c4911
commit b795f02fbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 247 additions and 80 deletions

View File

@ -19,6 +19,7 @@ max_epochs = 0
patience = 10000
eval_frequency = 200
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
frozen_components = []
[training.train_corpus]
@readers = "spacy.Corpus.v1"

View File

@ -6,7 +6,7 @@ import hashlib
import typer
from typer.main import get_command
from contextlib import contextmanager
from thinc.config import ConfigValidationError
from thinc.config import Config, ConfigValidationError
from configparser import InterpolationError
import sys
@ -217,3 +217,15 @@ def import_code(code_path: Optional[Union[Path, str]]) -> None:
import_file("python_code", code_path)
except Exception as e:
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
"""RETURNS (List[str]): All sourced components in the original config,
e.g. {"source": "en_core_web_sm"}. If the config contains a key
"factory", we assume it refers to a component factory.
"""
return [
name
for name, cfg in config.get("components", {}).items()
if "factory" not in cfg and "source" in cfg
]

View File

@ -8,7 +8,7 @@ import typer
from thinc.api import Config
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli
from ._util import import_code, debug_cli, get_sourced_components
from ..gold import Corpus, Example
from ..pipeline._parser_internals import nonproj
from ..language import Language
@ -138,9 +138,10 @@ def debug_data(
with show_validation_error(config_path):
cfg = Config().from_disk(config_path)
nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
# TODO: handle base model
lang = config["nlp"]["lang"]
base_model = config["training"]["base_model"]
# Use original config here, not resolved version
sourced_components = get_sourced_components(cfg)
frozen_components = config["training"]["frozen_components"]
resume_components = [p for p in sourced_components if p not in frozen_components]
pipeline = nlp.pipe_names
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
tag_map_path = util.ensure_path(config["training"]["tag_map"])
@ -187,13 +188,15 @@ def debug_data(
train_texts = gold_train_data["texts"]
dev_texts = gold_dev_data["texts"]
frozen_components = config["training"]["frozen_components"]
msg.divider("Training stats")
msg.text(f"Language: {config['nlp']['lang']}")
msg.text(f"Training pipeline: {', '.join(pipeline)}")
if base_model:
msg.text(f"Starting with base model '{base_model}'")
else:
msg.text(f"Starting with blank model '{lang}'")
if resume_components:
msg.text(f"Components from other models: {', '.join(resume_components)}")
if frozen_components:
msg.text(f"Frozen components: {', '.join(frozen_components)}")
msg.text(f"{len(train_dataset)} training docs")
msg.text(f"{len(dev_dataset)} evaluation docs")
@ -204,7 +207,9 @@ def debug_data(
msg.warn(f"{overlap} training examples also in evaluation data")
else:
msg.good("No overlap between training and evaluation data")
if not base_model and len(train_dataset) < BLANK_MODEL_THRESHOLD:
# TODO: make this feedback more fine-grained and report on updated
# components vs. blank components
if not resume_components and len(train_dataset) < BLANK_MODEL_THRESHOLD:
text = (
f"Low number of examples to train from a blank model ({len(train_dataset)})"
)

View File

@ -11,7 +11,7 @@ import random
import typer
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
from ._util import import_code
from ._util import import_code, get_sourced_components
from ..language import Language
from .. import util
from ..gold.example import Example
@ -78,6 +78,8 @@ def train(
config = Config().from_disk(config_path)
if config.get("training", {}).get("seed") is not None:
fix_random_seed(config["training"]["seed"])
# Use original config here before it's resolved to functions
sourced_components = get_sourced_components(config)
with show_validation_error(config_path):
nlp, config = util.load_model_from_config(config, overrides=config_overrides)
if config["training"]["vectors"] is not None:
@ -92,11 +94,16 @@ def train(
train_corpus = T_cfg["train_corpus"]
dev_corpus = T_cfg["dev_corpus"]
batcher = T_cfg["batcher"]
if resume_training:
msg.info("Resuming training")
nlp.resume_training()
else:
msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}")
# Components that shouldn't be updated during training
frozen_components = T_cfg["frozen_components"]
# Sourced components that require resume_training
resume_components = [p for p in sourced_components if p not in frozen_components]
msg.info(f"Pipeline: {nlp.pipe_names}")
if resume_components:
with nlp.select_pipes(enable=resume_components):
msg.info(f"Resuming training for: {resume_components}")
nlp.resume_training()
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
nlp.begin_training(lambda: train_corpus(nlp))
if tag_map:
@ -136,7 +143,7 @@ def train(
patience=T_cfg["patience"],
max_steps=T_cfg["max_steps"],
eval_frequency=T_cfg["eval_frequency"],
raw_text=None
raw_text=None,
)
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
print_row = setup_printer(T_cfg, nlp)
@ -192,9 +199,7 @@ def create_train_batches(iterator, batcher, max_epochs: int):
def create_evaluation_callback(
nlp: Language,
dev_corpus: Callable,
weights: Dict[str, float],
nlp: Language, dev_corpus: Callable, weights: Dict[str, float],
) -> Callable[[], Tuple[float, Dict[str, float]]]:
def evaluate() -> Tuple[float, Dict[str, float]]:
dev_examples = list(dev_corpus(nlp))
@ -223,6 +228,7 @@ def train_while_improving(
patience: int,
max_steps: int,
raw_text: List[Dict[str, str]],
exclude: List[str],
):
"""Train until an evaluation stops improving. Works as a generator,
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
@ -268,8 +274,6 @@ def train_while_improving(
dropouts = dropout
results = []
losses = {}
to_enable = [name for name, proc in nlp.pipeline if hasattr(proc, "model")]
if raw_text:
random.shuffle(raw_text)
raw_examples = [
@ -279,17 +283,19 @@ def train_while_improving(
for step, (epoch, batch) in enumerate(train_data):
dropout = next(dropouts)
with nlp.select_pipes(enable=to_enable):
for subbatch in subdivide_batch(batch, accumulate_gradient):
nlp.update(subbatch, drop=dropout, losses=losses, sgd=False)
if raw_text:
# If raw text is available, perform 'rehearsal' updates,
# which use unlabelled data to reduce overfitting.
raw_batch = list(next(raw_batches))
nlp.rehearse(raw_batch, sgd=optimizer, losses=losses)
for name, proc in nlp.pipeline:
if hasattr(proc, "model"):
proc.model.finish_update(optimizer)
for subbatch in subdivide_batch(batch, accumulate_gradient):
nlp.update(
subbatch, drop=dropout, losses=losses, sgd=False, exclude=exclude
)
if raw_text:
# If raw text is available, perform 'rehearsal' updates,
# which use unlabelled data to reduce overfitting.
raw_batch = list(next(raw_batches))
nlp.rehearse(raw_batch, sgd=optimizer, losses=losses, exclude=exclude)
# TODO: refactor this so we don't have to run it separately in here
for name, proc in nlp.pipeline:
if name not in exclude and hasattr(proc, "model"):
proc.model.finish_update(optimizer)
optimizer.step_schedules()
if not (step % eval_frequency):
if optimizer.averages:
@ -418,10 +424,7 @@ def load_from_paths(
return raw_text, tag_map, morph_rules, weights_data
def verify_cli_args(
config_path: Path,
output_path: Optional[Path] = None,
) -> None:
def verify_cli_args(config_path: Path, output_path: Optional[Path] = None,) -> None:
# Make sure all files and paths exists if they are needed
if not config_path or not config_path.exists():
msg.fail("Config file not found", config_path, exits=1)

View File

@ -37,6 +37,8 @@ max_steps = 20000
eval_frequency = 200
# Control how scores are printed and checkpoints are evaluated.
score_weights = {}
# Names of pipeline components that shouldn't be updated during training
frozen_components = []
[training.train_corpus]
@readers = "spacy.Corpus.v1"

View File

@ -482,6 +482,10 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master
E944 = ("Can't copy pipeline component '{name}' from source model '{model}': "
"not found in pipeline. Available components: {opts}")
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
"nlp object, but got: {source}")
E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
"call kb.initialize()?")
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
@ -571,11 +575,13 @@ class Errors:
"into {values}, but found {value}.")
E983 = ("Invalid key for '{dict}': {key}. Available keys: "
"{keys}")
E984 = ("Invalid component config for '{name}': no 'factory' key "
"specifying the registered function used to initialize the "
"component. For example, factory = \"ner\" will use the 'ner' "
"factory and all other settings in the block will be passed "
"to it as arguments.\n\n{config}")
E984 = ("Invalid component config for '{name}': component block needs either "
"a key 'factory' specifying the registered function used to "
"initialize the component, or a key 'source' key specifying a "
"spaCy model to copy the component from. For example, factory = "
"\"ner\" will use the 'ner' factory and all other settings in the "
"block will be passed to it as arguments. Alternatively, source = "
"\"en_core_web_sm\" will copy the component from that model.\n\n{config}")
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
E986 = ("Could not create any training batches: check your input. "
"Perhaps discard_oversize should be set to False ?")

View File

@ -620,6 +620,32 @@ class Language:
self._pipe_configs[name] = filled
return resolved[factory_name]
def create_pipe_from_source(
self, source_name: str, source: "Language", *, name: str,
) -> Tuple[Callable[[Doc], Doc], str]:
"""Create a pipeline component by copying it from an existing model.
source_name (str): Name of the component in the source pipeline.
source (Language): The source nlp object to copy from.
name (str): Optional alternative name to use in current pipeline.
RETURNS (Tuple[Callable, str]): The component and its factory name.
"""
# TODO: handle errors and mismatches (vectors etc.)
if not isinstance(source, self.__class__):
raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
if not source.has_pipe(source_name):
raise KeyError(
Errors.E944.format(
name=source_name,
model=f"{source.meta['lang']}_{source.meta['name']}",
opts=", ".join(source.pipe_names),
)
)
pipe = source.get_pipe(source_name)
pipe_config = util.copy_config(source.config["components"][source_name])
self._pipe_configs[name] = pipe_config
return pipe, pipe_config["factory"]
def add_pipe(
self,
factory_name: str,
@ -629,6 +655,7 @@ class Language:
after: Optional[Union[str, int]] = None,
first: Optional[bool] = None,
last: Optional[bool] = None,
source: Optional["Language"] = None,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
validate: bool = True,
@ -648,6 +675,8 @@ class Language:
component directly after.
first (bool): If True, insert component first in the pipeline.
last (bool): If True, insert component last in the pipeline.
source (Language): Optional loaded nlp object to copy the pipeline
component from.
config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available.
overrides (Optional[Dict[str, Any]]): Config overrides, typically
@ -662,24 +691,31 @@ class Language:
bad_val = repr(factory_name)
err = Errors.E966.format(component=bad_val, name=name)
raise ValueError(err)
if not self.has_factory(factory_name):
err = Errors.E002.format(
name=factory_name,
opts=", ".join(self.factory_names),
method="add_pipe",
lang=util.get_object_name(self),
lang_code=self.lang,
)
name = name if name is not None else factory_name
if name in self.pipe_names:
raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names))
pipe_component = self.create_pipe(
factory_name,
name=name,
config=config,
overrides=overrides,
validate=validate,
)
if source is not None:
# We're loading the component from a model. After loading the
# component, we know its real factory name
pipe_component, factory_name = self.create_pipe_from_source(
factory_name, source, name=name
)
else:
if not self.has_factory(factory_name):
err = Errors.E002.format(
name=factory_name,
opts=", ".join(self.factory_names),
method="add_pipe",
lang=util.get_object_name(self),
lang_code=self.lang,
)
pipe_component = self.create_pipe(
factory_name,
name=name,
config=config,
overrides=overrides,
validate=validate,
)
pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name)
self.pipeline.insert(pipe_index, (name, pipe_component))
@ -911,6 +947,7 @@ class Language:
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = tuple(),
):
"""Update the models in the pipeline.
@ -921,6 +958,7 @@ class Language:
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
components, keyed by component name.
exclude (Iterable[str]): Names of components that shouldn't be updated.
RETURNS (Dict[str, float]): The updated losses dictionary
DOCS: https://spacy.io/api/language#update
@ -953,12 +991,12 @@ class Language:
component_cfg[name].setdefault("drop", drop)
component_cfg[name].setdefault("set_annotations", False)
for name, proc in self.pipeline:
if not hasattr(proc, "update"):
if name in exclude or not hasattr(proc, "update"):
continue
proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
if sgd not in (None, False):
for name, proc in self.pipeline:
if hasattr(proc, "model"):
if name not in exclude and hasattr(proc, "model"):
proc.model.finish_update(sgd)
return losses
@ -969,6 +1007,7 @@ class Language:
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = tuple(),
) -> Dict[str, float]:
"""Make a "rehearsal" update to the models in the pipeline, to prevent
forgetting. Rehearsal updates run an initial copy of the model over some
@ -980,6 +1019,7 @@ class Language:
sgd (Optional[Optimizer]): An optimizer.
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
components, keyed by component name.
exclude (Iterable[str]): Names of components that shouldn't be updated.
RETURNS (dict): Results from the update.
EXAMPLE:
@ -1023,7 +1063,7 @@ class Language:
get_grads.b1 = sgd.b1
get_grads.b2 = sgd.b2
for name, proc in pipes:
if not hasattr(proc, "rehearse"):
if name in exclude or not hasattr(proc, "rehearse"):
continue
grads = {}
proc.rehearse(
@ -1074,7 +1114,7 @@ class Language:
return self._optimizer
def resume_training(
self, *, sgd: Optional[Optimizer] = None, device: int = -1
self, *, sgd: Optional[Optimizer] = None, device: int = -1,
) -> Optimizer:
"""Continue training a pretrained model.
@ -1373,6 +1413,7 @@ class Language:
cls,
config: Union[Dict[str, Any], Config] = {},
*,
vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = tuple(),
overrides: Dict[str, Any] = {},
auto_fill: bool = True,
@ -1383,6 +1424,7 @@ class Language:
the default config of the given language is used.
config (Dict[str, Any] / Config): The loaded config.
vocab (Vocab): A Vocab object. If True, a vocab is created.
disable (Iterable[str]): List of pipeline component names to disable.
auto_fill (bool): Automatically fill in missing values in config based
on defaults and function argument annotations.
@ -1422,32 +1464,48 @@ class Language:
create_tokenizer = resolved["nlp"]["tokenizer"]
create_lemmatizer = resolved["nlp"]["lemmatizer"]
nlp = cls(
create_tokenizer=create_tokenizer, create_lemmatizer=create_lemmatizer,
vocab=vocab,
create_tokenizer=create_tokenizer,
create_lemmatizer=create_lemmatizer,
)
# Note that we don't load vectors here, instead they get loaded explicitly
# inside stuff like the spacy train function. If we loaded them here,
# then we would load them twice at runtime: once when we make from config,
# and then again when we load from disk.
pipeline = config.get("components", {})
# If components are loaded from a source (existing models), we cache
# them here so they're only loaded once
source_nlps = {}
for pipe_name in config["nlp"]["pipeline"]:
if pipe_name not in pipeline:
opts = ", ".join(pipeline.keys())
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
pipe_cfg = util.copy_config(pipeline[pipe_name])
if pipe_name not in disable:
if "factory" not in pipe_cfg:
if "factory" not in pipe_cfg and "source" not in pipe_cfg:
err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
raise ValueError(err)
factory = pipe_cfg.pop("factory")
# The pipe name (key in the config) here is the unique name of the
# component, not necessarily the factory
nlp.add_pipe(
factory,
name=pipe_name,
config=pipe_cfg,
overrides=pipe_overrides,
validate=validate,
)
if "factory" in pipe_cfg:
factory = pipe_cfg.pop("factory")
# The pipe name (key in the config) here is the unique name
# of the component, not necessarily the factory
nlp.add_pipe(
factory,
name=pipe_name,
config=pipe_cfg,
overrides=pipe_overrides,
validate=validate,
)
else:
model = pipe_cfg["source"]
if model not in source_nlps:
# We only need the components here and we need to init
# model with the same vocab as the current nlp object
source_nlps[model] = util.load_model(
model, vocab=nlp.vocab, disable=["vocab", "tokenizer"]
)
source_name = pipe_cfg.get("component", pipe_name)
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
nlp.config = filled if auto_fill else config
nlp.resolved = resolved
return nlp

View File

@ -202,6 +202,7 @@ class ConfigSchemaTraining(BaseModel):
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
optimizer: Optimizer = Field(..., title="The optimizer to use")
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
# fmt: on
class Config:

View File

@ -8,6 +8,8 @@ from thinc.api import Model, Linear
from thinc.config import ConfigValidationError
from pydantic import StrictInt, StrictStr
from ..util import make_tempdir
def test_pipe_function_component():
name = "test_component"
@ -374,3 +376,65 @@ def test_language_factories_scores():
cfg = nlp.config["training"]
expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05}
assert cfg["score_weights"] == expected_weights
def test_pipe_factories_from_source():
"""Test adding components from a source model."""
source_nlp = English()
source_nlp.add_pipe("tagger", name="my_tagger")
nlp = English()
with pytest.raises(ValueError):
nlp.add_pipe("my_tagger", source="en_core_web_sm")
nlp.add_pipe("my_tagger", source=source_nlp)
assert "my_tagger" in nlp.pipe_names
with pytest.raises(KeyError):
nlp.add_pipe("custom", source=source_nlp)
def test_pipe_factories_from_source_custom():
"""Test adding components from a source model with custom components."""
name = "test_pipe_factories_from_source_custom"
@Language.factory(name, default_config={"arg": "hello"})
def test_factory(nlp, name, arg: str):
return lambda doc: doc
source_nlp = English()
source_nlp.add_pipe("tagger")
source_nlp.add_pipe(name, config={"arg": "world"})
nlp = English()
nlp.add_pipe(name, source=source_nlp)
assert name in nlp.pipe_names
assert nlp.get_pipe_meta(name).default_config["arg"] == "hello"
config = nlp.config["components"][name]
assert config["factory"] == name
assert config["arg"] == "world"
def test_pipe_factories_from_source_config():
name = "test_pipe_factories_from_source_config"
@Language.factory(name, default_config={"arg": "hello"})
def test_factory(nlp, name, arg: str):
return lambda doc: doc
source_nlp = English()
source_nlp.add_pipe("tagger")
source_nlp.add_pipe(name, name="yolo", config={"arg": "world"})
dest_nlp_cfg = {"lang": "en", "pipeline": ["parser", "custom"]}
with make_tempdir() as tempdir:
source_nlp.to_disk(tempdir)
dest_components_cfg = {
"parser": {"factory": "parser"},
"custom": {"source": str(tempdir), "component": "yolo"},
}
dest_config = {"nlp": dest_nlp_cfg, "components": dest_components_cfg}
nlp = English.from_config(dest_config)
assert nlp.pipe_names == ["parser", "custom"]
assert nlp.pipe_factories == {"parser": "parser", "custom": name}
meta = nlp.get_pipe_meta("custom")
assert meta.factory == name
assert meta.default_config["arg"] == "hello"
config = nlp.config["components"]["custom"]
assert config["factory"] == name
assert config["arg"] == "world"

View File

@ -205,43 +205,51 @@ def load_vectors_into_model(
def load_model(
name: Union[str, Path],
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
) -> "Language":
"""Load a model from a package or data path.
name (str): Package name or model path.
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable.
component_cfg (Dict[str, dict]): Config overrides for pipeline components,
keyed by component names.
RETURNS (Language): The loaded nlp object.
"""
cfg = component_cfg
kwargs = {"vocab": vocab, "disable": disable, "component_cfg": component_cfg}
if isinstance(name, str): # name or string path
if name.startswith("blank:"): # shortcut for blank model
return get_lang_class(name.replace("blank:", ""))()
if is_package(name): # installed as package
return load_model_from_package(name, disable=disable, component_cfg=cfg)
return load_model_from_package(name, **kwargs)
if Path(name).exists(): # path to model data directory
return load_model_from_path(Path(name), disable=disable, component_cfg=cfg)
return load_model_from_path(Path(name), **kwargs)
elif hasattr(name, "exists"): # Path or Path-like to model data
return load_model_from_path(name, disable=disable, component_cfg=cfg)
return load_model_from_path(name, **kwargs)
raise IOError(Errors.E050.format(name=name))
def load_model_from_package(
name: str,
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
) -> "Language":
"""Load a model from an installed package."""
cls = importlib.import_module(name)
return cls.load(disable=disable, component_cfg=component_cfg)
return cls.load(vocab=vocab, disable=disable, component_cfg=component_cfg)
def load_model_from_path(
model_path: Union[str, Path],
*,
meta: Optional[Dict[str, Any]] = None,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
) -> "Language":
@ -257,12 +265,16 @@ def load_model_from_path(
config = Config().from_disk(config_path)
override_cfg = {"components": {p: dict_to_dot(c) for p, c in component_cfg.items()}}
overrides = dict_to_dot(override_cfg)
nlp, _ = load_model_from_config(config, disable=disable, overrides=overrides)
nlp, _ = load_model_from_config(
config, vocab=vocab, disable=disable, overrides=overrides
)
return nlp.from_disk(model_path, exclude=disable)
def load_model_from_config(
config: Union[Dict[str, Any], Config],
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
overrides: Dict[str, Any] = {},
auto_fill: bool = False,
@ -281,6 +293,7 @@ def load_model_from_config(
lang_cls = get_lang_class(nlp_config["lang"])
nlp = lang_cls.from_config(
config,
vocab=vocab,
disable=disable,
overrides=overrides,
auto_fill=auto_fill,
@ -291,6 +304,8 @@ def load_model_from_config(
def load_model_from_init_py(
init_file: Union[Path, str],
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
) -> "Language":
@ -308,7 +323,7 @@ def load_model_from_init_py(
if not model_path.exists():
raise IOError(Errors.E052.format(path=data_path))
return load_model_from_path(
data_path, meta, disable=disable, component_cfg=component_cfg
data_path, vocab=vocab, meta=meta, disable=disable, component_cfg=component_cfg
)