diff --git a/examples/experiments/onto-joint/defaults.cfg b/examples/experiments/onto-joint/defaults.cfg index f456e3fbe..3ab3ddaba 100644 --- a/examples/experiments/onto-joint/defaults.cfg +++ b/examples/experiments/onto-joint/defaults.cfg @@ -19,6 +19,7 @@ max_epochs = 0 patience = 10000 eval_frequency = 200 score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2} +frozen_components = [] [training.train_corpus] @readers = "spacy.Corpus.v1" diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index 0130e60bb..93ec9f31e 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -6,7 +6,7 @@ import hashlib import typer from typer.main import get_command from contextlib import contextmanager -from thinc.config import ConfigValidationError +from thinc.config import Config, ConfigValidationError from configparser import InterpolationError import sys @@ -217,3 +217,15 @@ def import_code(code_path: Optional[Union[Path, str]]) -> None: import_file("python_code", code_path) except Exception as e: msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1) + + +def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]: + """RETURNS (List[str]): All sourced components in the original config, + e.g. {"source": "en_core_web_sm"}. If the config contains a key + "factory", we assume it refers to a component factory. + """ + return [ + name + for name, cfg in config.get("components", {}).items() + if "factory" not in cfg and "source" in cfg + ] diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py index cf3822a59..1fd9fd813 100644 --- a/spacy/cli/debug_data.py +++ b/spacy/cli/debug_data.py @@ -8,7 +8,7 @@ import typer from thinc.api import Config from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides -from ._util import import_code, debug_cli +from ._util import import_code, debug_cli, get_sourced_components from ..gold import Corpus, Example from ..pipeline._parser_internals import nonproj from ..language import Language @@ -138,9 +138,10 @@ def debug_data( with show_validation_error(config_path): cfg = Config().from_disk(config_path) nlp, config = util.load_model_from_config(cfg, overrides=config_overrides) - # TODO: handle base model - lang = config["nlp"]["lang"] - base_model = config["training"]["base_model"] + # Use original config here, not resolved version + sourced_components = get_sourced_components(cfg) + frozen_components = config["training"]["frozen_components"] + resume_components = [p for p in sourced_components if p not in frozen_components] pipeline = nlp.pipe_names factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names] tag_map_path = util.ensure_path(config["training"]["tag_map"]) @@ -187,13 +188,15 @@ def debug_data( train_texts = gold_train_data["texts"] dev_texts = gold_dev_data["texts"] + frozen_components = config["training"]["frozen_components"] msg.divider("Training stats") + msg.text(f"Language: {config['nlp']['lang']}") msg.text(f"Training pipeline: {', '.join(pipeline)}") - if base_model: - msg.text(f"Starting with base model '{base_model}'") - else: - msg.text(f"Starting with blank model '{lang}'") + if resume_components: + msg.text(f"Components from other models: {', '.join(resume_components)}") + if frozen_components: + msg.text(f"Frozen components: {', '.join(frozen_components)}") msg.text(f"{len(train_dataset)} training docs") msg.text(f"{len(dev_dataset)} evaluation docs") @@ -204,7 +207,9 @@ def debug_data( msg.warn(f"{overlap} training examples also in evaluation data") else: msg.good("No overlap between training and evaluation data") - if not base_model and len(train_dataset) < BLANK_MODEL_THRESHOLD: + # TODO: make this feedback more fine-grained and report on updated + # components vs. blank components + if not resume_components and len(train_dataset) < BLANK_MODEL_THRESHOLD: text = ( f"Low number of examples to train from a blank model ({len(train_dataset)})" ) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index ca2bd04ab..9b071ed55 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -11,7 +11,7 @@ import random import typer from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error -from ._util import import_code +from ._util import import_code, get_sourced_components from ..language import Language from .. import util from ..gold.example import Example @@ -78,6 +78,8 @@ def train( config = Config().from_disk(config_path) if config.get("training", {}).get("seed") is not None: fix_random_seed(config["training"]["seed"]) + # Use original config here before it's resolved to functions + sourced_components = get_sourced_components(config) with show_validation_error(config_path): nlp, config = util.load_model_from_config(config, overrides=config_overrides) if config["training"]["vectors"] is not None: @@ -92,11 +94,16 @@ def train( train_corpus = T_cfg["train_corpus"] dev_corpus = T_cfg["dev_corpus"] batcher = T_cfg["batcher"] - if resume_training: - msg.info("Resuming training") - nlp.resume_training() - else: - msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}") + # Components that shouldn't be updated during training + frozen_components = T_cfg["frozen_components"] + # Sourced components that require resume_training + resume_components = [p for p in sourced_components if p not in frozen_components] + msg.info(f"Pipeline: {nlp.pipe_names}") + if resume_components: + with nlp.select_pipes(enable=resume_components): + msg.info(f"Resuming training for: {resume_components}") + nlp.resume_training() + with nlp.select_pipes(disable=[*frozen_components, *resume_components]): nlp.begin_training(lambda: train_corpus(nlp)) if tag_map: @@ -136,7 +143,7 @@ def train( patience=T_cfg["patience"], max_steps=T_cfg["max_steps"], eval_frequency=T_cfg["eval_frequency"], - raw_text=None + raw_text=None, ) msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}") print_row = setup_printer(T_cfg, nlp) @@ -192,9 +199,7 @@ def create_train_batches(iterator, batcher, max_epochs: int): def create_evaluation_callback( - nlp: Language, - dev_corpus: Callable, - weights: Dict[str, float], + nlp: Language, dev_corpus: Callable, weights: Dict[str, float], ) -> Callable[[], Tuple[float, Dict[str, float]]]: def evaluate() -> Tuple[float, Dict[str, float]]: dev_examples = list(dev_corpus(nlp)) @@ -223,6 +228,7 @@ def train_while_improving( patience: int, max_steps: int, raw_text: List[Dict[str, str]], + exclude: List[str], ): """Train until an evaluation stops improving. Works as a generator, with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`, @@ -268,8 +274,6 @@ def train_while_improving( dropouts = dropout results = [] losses = {} - to_enable = [name for name, proc in nlp.pipeline if hasattr(proc, "model")] - if raw_text: random.shuffle(raw_text) raw_examples = [ @@ -279,17 +283,19 @@ def train_while_improving( for step, (epoch, batch) in enumerate(train_data): dropout = next(dropouts) - with nlp.select_pipes(enable=to_enable): - for subbatch in subdivide_batch(batch, accumulate_gradient): - nlp.update(subbatch, drop=dropout, losses=losses, sgd=False) - if raw_text: - # If raw text is available, perform 'rehearsal' updates, - # which use unlabelled data to reduce overfitting. - raw_batch = list(next(raw_batches)) - nlp.rehearse(raw_batch, sgd=optimizer, losses=losses) - for name, proc in nlp.pipeline: - if hasattr(proc, "model"): - proc.model.finish_update(optimizer) + for subbatch in subdivide_batch(batch, accumulate_gradient): + nlp.update( + subbatch, drop=dropout, losses=losses, sgd=False, exclude=exclude + ) + if raw_text: + # If raw text is available, perform 'rehearsal' updates, + # which use unlabelled data to reduce overfitting. + raw_batch = list(next(raw_batches)) + nlp.rehearse(raw_batch, sgd=optimizer, losses=losses, exclude=exclude) + # TODO: refactor this so we don't have to run it separately in here + for name, proc in nlp.pipeline: + if name not in exclude and hasattr(proc, "model"): + proc.model.finish_update(optimizer) optimizer.step_schedules() if not (step % eval_frequency): if optimizer.averages: @@ -418,10 +424,7 @@ def load_from_paths( return raw_text, tag_map, morph_rules, weights_data -def verify_cli_args( - config_path: Path, - output_path: Optional[Path] = None, -) -> None: +def verify_cli_args(config_path: Path, output_path: Optional[Path] = None,) -> None: # Make sure all files and paths exists if they are needed if not config_path or not config_path.exists(): msg.fail("Config file not found", config_path, exits=1) diff --git a/spacy/default_config.cfg b/spacy/default_config.cfg index 1c56810e3..f35be605c 100644 --- a/spacy/default_config.cfg +++ b/spacy/default_config.cfg @@ -37,6 +37,8 @@ max_steps = 20000 eval_frequency = 200 # Control how scores are printed and checkpoints are evaluated. score_weights = {} +# Names of pipeline components that shouldn't be updated during training +frozen_components = [] [training.train_corpus] @readers = "spacy.Corpus.v1" diff --git a/spacy/errors.py b/spacy/errors.py index 4ae43a497..973843bb7 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -482,6 +482,10 @@ class Errors: E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") # TODO: fix numbering after merging develop into master + E944 = ("Can't copy pipeline component '{name}' from source model '{model}': " + "not found in pipeline. Available components: {opts}") + E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded " + "nlp object, but got: {source}") E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to " "call kb.initialize()?") E947 = ("Matcher.add received invalid 'greedy' argument: expected " @@ -571,11 +575,13 @@ class Errors: "into {values}, but found {value}.") E983 = ("Invalid key for '{dict}': {key}. Available keys: " "{keys}") - E984 = ("Invalid component config for '{name}': no 'factory' key " - "specifying the registered function used to initialize the " - "component. For example, factory = \"ner\" will use the 'ner' " - "factory and all other settings in the block will be passed " - "to it as arguments.\n\n{config}") + E984 = ("Invalid component config for '{name}': component block needs either " + "a key 'factory' specifying the registered function used to " + "initialize the component, or a key 'source' key specifying a " + "spaCy model to copy the component from. For example, factory = " + "\"ner\" will use the 'ner' factory and all other settings in the " + "block will be passed to it as arguments. Alternatively, source = " + "\"en_core_web_sm\" will copy the component from that model.\n\n{config}") E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}") E986 = ("Could not create any training batches: check your input. " "Perhaps discard_oversize should be set to False ?") diff --git a/spacy/language.py b/spacy/language.py index d1b180cef..4196e25a4 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -620,6 +620,32 @@ class Language: self._pipe_configs[name] = filled return resolved[factory_name] + def create_pipe_from_source( + self, source_name: str, source: "Language", *, name: str, + ) -> Tuple[Callable[[Doc], Doc], str]: + """Create a pipeline component by copying it from an existing model. + + source_name (str): Name of the component in the source pipeline. + source (Language): The source nlp object to copy from. + name (str): Optional alternative name to use in current pipeline. + RETURNS (Tuple[Callable, str]): The component and its factory name. + """ + # TODO: handle errors and mismatches (vectors etc.) + if not isinstance(source, self.__class__): + raise ValueError(Errors.E945.format(name=source_name, source=type(source))) + if not source.has_pipe(source_name): + raise KeyError( + Errors.E944.format( + name=source_name, + model=f"{source.meta['lang']}_{source.meta['name']}", + opts=", ".join(source.pipe_names), + ) + ) + pipe = source.get_pipe(source_name) + pipe_config = util.copy_config(source.config["components"][source_name]) + self._pipe_configs[name] = pipe_config + return pipe, pipe_config["factory"] + def add_pipe( self, factory_name: str, @@ -629,6 +655,7 @@ class Language: after: Optional[Union[str, int]] = None, first: Optional[bool] = None, last: Optional[bool] = None, + source: Optional["Language"] = None, config: Optional[Dict[str, Any]] = SimpleFrozenDict(), overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(), validate: bool = True, @@ -648,6 +675,8 @@ class Language: component directly after. first (bool): If True, insert component first in the pipeline. last (bool): If True, insert component last in the pipeline. + source (Language): Optional loaded nlp object to copy the pipeline + component from. config (Optional[Dict[str, Any]]): Config parameters to use for this component. Will be merged with default config, if available. overrides (Optional[Dict[str, Any]]): Config overrides, typically @@ -662,24 +691,31 @@ class Language: bad_val = repr(factory_name) err = Errors.E966.format(component=bad_val, name=name) raise ValueError(err) - if not self.has_factory(factory_name): - err = Errors.E002.format( - name=factory_name, - opts=", ".join(self.factory_names), - method="add_pipe", - lang=util.get_object_name(self), - lang_code=self.lang, - ) name = name if name is not None else factory_name if name in self.pipe_names: raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names)) - pipe_component = self.create_pipe( - factory_name, - name=name, - config=config, - overrides=overrides, - validate=validate, - ) + if source is not None: + # We're loading the component from a model. After loading the + # component, we know its real factory name + pipe_component, factory_name = self.create_pipe_from_source( + factory_name, source, name=name + ) + else: + if not self.has_factory(factory_name): + err = Errors.E002.format( + name=factory_name, + opts=", ".join(self.factory_names), + method="add_pipe", + lang=util.get_object_name(self), + lang_code=self.lang, + ) + pipe_component = self.create_pipe( + factory_name, + name=name, + config=config, + overrides=overrides, + validate=validate, + ) pipe_index = self._get_pipe_index(before, after, first, last) self._pipe_meta[name] = self.get_factory_meta(factory_name) self.pipeline.insert(pipe_index, (name, pipe_component)) @@ -911,6 +947,7 @@ class Language: sgd: Optional[Optimizer] = None, losses: Optional[Dict[str, float]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, + exclude: Iterable[str] = tuple(), ): """Update the models in the pipeline. @@ -921,6 +958,7 @@ class Language: losses (Dict[str, float]): Dictionary to update with the loss, keyed by component. component_cfg (Dict[str, Dict]): Config parameters for specific pipeline components, keyed by component name. + exclude (Iterable[str]): Names of components that shouldn't be updated. RETURNS (Dict[str, float]): The updated losses dictionary DOCS: https://spacy.io/api/language#update @@ -953,12 +991,12 @@ class Language: component_cfg[name].setdefault("drop", drop) component_cfg[name].setdefault("set_annotations", False) for name, proc in self.pipeline: - if not hasattr(proc, "update"): + if name in exclude or not hasattr(proc, "update"): continue proc.update(examples, sgd=None, losses=losses, **component_cfg[name]) if sgd not in (None, False): for name, proc in self.pipeline: - if hasattr(proc, "model"): + if name not in exclude and hasattr(proc, "model"): proc.model.finish_update(sgd) return losses @@ -969,6 +1007,7 @@ class Language: sgd: Optional[Optimizer] = None, losses: Optional[Dict[str, float]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, + exclude: Iterable[str] = tuple(), ) -> Dict[str, float]: """Make a "rehearsal" update to the models in the pipeline, to prevent forgetting. Rehearsal updates run an initial copy of the model over some @@ -980,6 +1019,7 @@ class Language: sgd (Optional[Optimizer]): An optimizer. component_cfg (Dict[str, Dict]): Config parameters for specific pipeline components, keyed by component name. + exclude (Iterable[str]): Names of components that shouldn't be updated. RETURNS (dict): Results from the update. EXAMPLE: @@ -1023,7 +1063,7 @@ class Language: get_grads.b1 = sgd.b1 get_grads.b2 = sgd.b2 for name, proc in pipes: - if not hasattr(proc, "rehearse"): + if name in exclude or not hasattr(proc, "rehearse"): continue grads = {} proc.rehearse( @@ -1074,7 +1114,7 @@ class Language: return self._optimizer def resume_training( - self, *, sgd: Optional[Optimizer] = None, device: int = -1 + self, *, sgd: Optional[Optimizer] = None, device: int = -1, ) -> Optimizer: """Continue training a pretrained model. @@ -1373,6 +1413,7 @@ class Language: cls, config: Union[Dict[str, Any], Config] = {}, *, + vocab: Union[Vocab, bool] = True, disable: Iterable[str] = tuple(), overrides: Dict[str, Any] = {}, auto_fill: bool = True, @@ -1383,6 +1424,7 @@ class Language: the default config of the given language is used. config (Dict[str, Any] / Config): The loaded config. + vocab (Vocab): A Vocab object. If True, a vocab is created. disable (Iterable[str]): List of pipeline component names to disable. auto_fill (bool): Automatically fill in missing values in config based on defaults and function argument annotations. @@ -1422,32 +1464,48 @@ class Language: create_tokenizer = resolved["nlp"]["tokenizer"] create_lemmatizer = resolved["nlp"]["lemmatizer"] nlp = cls( - create_tokenizer=create_tokenizer, create_lemmatizer=create_lemmatizer, + vocab=vocab, + create_tokenizer=create_tokenizer, + create_lemmatizer=create_lemmatizer, ) # Note that we don't load vectors here, instead they get loaded explicitly # inside stuff like the spacy train function. If we loaded them here, # then we would load them twice at runtime: once when we make from config, # and then again when we load from disk. pipeline = config.get("components", {}) + # If components are loaded from a source (existing models), we cache + # them here so they're only loaded once + source_nlps = {} for pipe_name in config["nlp"]["pipeline"]: if pipe_name not in pipeline: opts = ", ".join(pipeline.keys()) raise ValueError(Errors.E956.format(name=pipe_name, opts=opts)) pipe_cfg = util.copy_config(pipeline[pipe_name]) if pipe_name not in disable: - if "factory" not in pipe_cfg: + if "factory" not in pipe_cfg and "source" not in pipe_cfg: err = Errors.E984.format(name=pipe_name, config=pipe_cfg) raise ValueError(err) - factory = pipe_cfg.pop("factory") - # The pipe name (key in the config) here is the unique name of the - # component, not necessarily the factory - nlp.add_pipe( - factory, - name=pipe_name, - config=pipe_cfg, - overrides=pipe_overrides, - validate=validate, - ) + if "factory" in pipe_cfg: + factory = pipe_cfg.pop("factory") + # The pipe name (key in the config) here is the unique name + # of the component, not necessarily the factory + nlp.add_pipe( + factory, + name=pipe_name, + config=pipe_cfg, + overrides=pipe_overrides, + validate=validate, + ) + else: + model = pipe_cfg["source"] + if model not in source_nlps: + # We only need the components here and we need to init + # model with the same vocab as the current nlp object + source_nlps[model] = util.load_model( + model, vocab=nlp.vocab, disable=["vocab", "tokenizer"] + ) + source_name = pipe_cfg.get("component", pipe_name) + nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name) nlp.config = filled if auto_fill else config nlp.resolved = resolved return nlp diff --git a/spacy/schemas.py b/spacy/schemas.py index 413daed7f..eea2d3dc3 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -202,6 +202,7 @@ class ConfigSchemaTraining(BaseModel): init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights") raw_text: Optional[StrictStr] = Field(default=None, title="Raw text") optimizer: Optimizer = Field(..., title="The optimizer to use") + frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training") # fmt: on class Config: diff --git a/spacy/tests/pipeline/test_pipe_factories.py b/spacy/tests/pipeline/test_pipe_factories.py index 64c6c2d6f..9948f6bcd 100644 --- a/spacy/tests/pipeline/test_pipe_factories.py +++ b/spacy/tests/pipeline/test_pipe_factories.py @@ -8,6 +8,8 @@ from thinc.api import Model, Linear from thinc.config import ConfigValidationError from pydantic import StrictInt, StrictStr +from ..util import make_tempdir + def test_pipe_function_component(): name = "test_component" @@ -374,3 +376,65 @@ def test_language_factories_scores(): cfg = nlp.config["training"] expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05} assert cfg["score_weights"] == expected_weights + + +def test_pipe_factories_from_source(): + """Test adding components from a source model.""" + source_nlp = English() + source_nlp.add_pipe("tagger", name="my_tagger") + nlp = English() + with pytest.raises(ValueError): + nlp.add_pipe("my_tagger", source="en_core_web_sm") + nlp.add_pipe("my_tagger", source=source_nlp) + assert "my_tagger" in nlp.pipe_names + with pytest.raises(KeyError): + nlp.add_pipe("custom", source=source_nlp) + + +def test_pipe_factories_from_source_custom(): + """Test adding components from a source model with custom components.""" + name = "test_pipe_factories_from_source_custom" + + @Language.factory(name, default_config={"arg": "hello"}) + def test_factory(nlp, name, arg: str): + return lambda doc: doc + + source_nlp = English() + source_nlp.add_pipe("tagger") + source_nlp.add_pipe(name, config={"arg": "world"}) + nlp = English() + nlp.add_pipe(name, source=source_nlp) + assert name in nlp.pipe_names + assert nlp.get_pipe_meta(name).default_config["arg"] == "hello" + config = nlp.config["components"][name] + assert config["factory"] == name + assert config["arg"] == "world" + + +def test_pipe_factories_from_source_config(): + name = "test_pipe_factories_from_source_config" + + @Language.factory(name, default_config={"arg": "hello"}) + def test_factory(nlp, name, arg: str): + return lambda doc: doc + + source_nlp = English() + source_nlp.add_pipe("tagger") + source_nlp.add_pipe(name, name="yolo", config={"arg": "world"}) + dest_nlp_cfg = {"lang": "en", "pipeline": ["parser", "custom"]} + with make_tempdir() as tempdir: + source_nlp.to_disk(tempdir) + dest_components_cfg = { + "parser": {"factory": "parser"}, + "custom": {"source": str(tempdir), "component": "yolo"}, + } + dest_config = {"nlp": dest_nlp_cfg, "components": dest_components_cfg} + nlp = English.from_config(dest_config) + assert nlp.pipe_names == ["parser", "custom"] + assert nlp.pipe_factories == {"parser": "parser", "custom": name} + meta = nlp.get_pipe_meta("custom") + assert meta.factory == name + assert meta.default_config["arg"] == "hello" + config = nlp.config["components"]["custom"] + assert config["factory"] == name + assert config["arg"] == "world" diff --git a/spacy/util.py b/spacy/util.py index d9e67440f..4e84b6a6b 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -205,43 +205,51 @@ def load_vectors_into_model( def load_model( name: Union[str, Path], + *, + vocab: Union["Vocab", bool] = True, disable: Iterable[str] = tuple(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), ) -> "Language": """Load a model from a package or data path. name (str): Package name or model path. + vocab (Vocab / True): Optional vocab to pass in on initialization. If True, + a new Vocab object will be created. disable (Iterable[str]): Names of pipeline components to disable. component_cfg (Dict[str, dict]): Config overrides for pipeline components, keyed by component names. RETURNS (Language): The loaded nlp object. """ - cfg = component_cfg + kwargs = {"vocab": vocab, "disable": disable, "component_cfg": component_cfg} if isinstance(name, str): # name or string path if name.startswith("blank:"): # shortcut for blank model return get_lang_class(name.replace("blank:", ""))() if is_package(name): # installed as package - return load_model_from_package(name, disable=disable, component_cfg=cfg) + return load_model_from_package(name, **kwargs) if Path(name).exists(): # path to model data directory - return load_model_from_path(Path(name), disable=disable, component_cfg=cfg) + return load_model_from_path(Path(name), **kwargs) elif hasattr(name, "exists"): # Path or Path-like to model data - return load_model_from_path(name, disable=disable, component_cfg=cfg) + return load_model_from_path(name, **kwargs) raise IOError(Errors.E050.format(name=name)) def load_model_from_package( name: str, + *, + vocab: Union["Vocab", bool] = True, disable: Iterable[str] = tuple(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), ) -> "Language": """Load a model from an installed package.""" cls = importlib.import_module(name) - return cls.load(disable=disable, component_cfg=component_cfg) + return cls.load(vocab=vocab, disable=disable, component_cfg=component_cfg) def load_model_from_path( model_path: Union[str, Path], + *, meta: Optional[Dict[str, Any]] = None, + vocab: Union["Vocab", bool] = True, disable: Iterable[str] = tuple(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), ) -> "Language": @@ -257,12 +265,16 @@ def load_model_from_path( config = Config().from_disk(config_path) override_cfg = {"components": {p: dict_to_dot(c) for p, c in component_cfg.items()}} overrides = dict_to_dot(override_cfg) - nlp, _ = load_model_from_config(config, disable=disable, overrides=overrides) + nlp, _ = load_model_from_config( + config, vocab=vocab, disable=disable, overrides=overrides + ) return nlp.from_disk(model_path, exclude=disable) def load_model_from_config( config: Union[Dict[str, Any], Config], + *, + vocab: Union["Vocab", bool] = True, disable: Iterable[str] = tuple(), overrides: Dict[str, Any] = {}, auto_fill: bool = False, @@ -281,6 +293,7 @@ def load_model_from_config( lang_cls = get_lang_class(nlp_config["lang"]) nlp = lang_cls.from_config( config, + vocab=vocab, disable=disable, overrides=overrides, auto_fill=auto_fill, @@ -291,6 +304,8 @@ def load_model_from_config( def load_model_from_init_py( init_file: Union[Path, str], + *, + vocab: Union["Vocab", bool] = True, disable: Iterable[str] = tuple(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), ) -> "Language": @@ -308,7 +323,7 @@ def load_model_from_init_py( if not model_path.exists(): raise IOError(Errors.E052.format(path=data_path)) return load_model_from_path( - data_path, meta, disable=disable, component_cfg=component_cfg + data_path, vocab=vocab, meta=meta, disable=disable, component_cfg=component_cfg )