Allow loaded but disabled components

This commit is contained in:
Ines Montani 2020-08-28 15:20:14 +02:00
parent adc050cdc5
commit 3ce5be4b76
8 changed files with 259 additions and 70 deletions

View File

@ -28,17 +28,22 @@ if sys.maxunicode == 65535:
def load(
name: Union[str, Path],
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
) -> Language:
"""Load a spaCy model from an installed package or a local path.
name (str): Package name or model path.
disable (Iterable[str]): Names of pipeline components to disable.
disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object.
"""
return util.load_model(name, disable=disable, config=config)
return util.load_model(name, disable=disable, exclude=exclude, config=config)
def blank(name: str, **overrides) -> Language:

View File

@ -11,6 +11,7 @@ use_pytorch_for_gpu_memory = false
[nlp]
lang = null
pipeline = []
disabled = []
load_vocab_data = true
before_creation = null
after_creation = null

View File

@ -6,7 +6,7 @@ import itertools
import weakref
import functools
from contextlib import contextmanager
from copy import copy, deepcopy
from copy import deepcopy
from pathlib import Path
import warnings
from thinc.api import get_current_ops, Config, require_gpu, Optimizer
@ -159,7 +159,8 @@ class Language:
self.vocab: Vocab = vocab
if self.lang is None:
self.lang = self.vocab.lang
self.pipeline = []
self._pipeline = []
self._disabled = set()
self.max_length = max_length
self.resolved = {}
# Create the default tokenizer from the default config
@ -210,6 +211,7 @@ class Language:
# TODO: Adding this back to prevent breaking people's code etc., but
# we should consider removing it
self._meta["pipeline"] = self.pipe_names
self._meta["disabled"] = list(self._disabled)
return self._meta
@meta.setter
@ -232,13 +234,14 @@ class Language:
# we can populate the config again later
pipeline = {}
score_weights = []
for pipe_name in self.pipe_names:
for pipe_name in self._pipe_names:
pipe_meta = self.get_pipe_meta(pipe_name)
pipe_config = self.get_pipe_config(pipe_name)
pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
if pipe_meta.default_score_weights:
score_weights.append(pipe_meta.default_score_weights)
self._config["nlp"]["pipeline"] = self.pipe_names
self._config["nlp"]["pipeline"] = self._pipe_names
self._config["nlp"]["disabled"] = list(self._disabled)
self._config["components"] = pipeline
self._config["training"]["score_weights"] = combine_score_weights(score_weights)
if not srsly.is_json_serializable(self._config):
@ -257,9 +260,30 @@ class Language:
"""
return list(self.factories.keys())
@property
def _pipe_names(self) -> List[str]:
"""Get the names of the available pipeline components. Includes all
active and inactive pipeline components.
RETURNS (List[str]): List of component name strings, in order.
"""
# TODO: Should we make this available via a user-facing property? (The
# underscore distinction works well internally)
return [pipe_name for pipe_name, _ in self._pipeline]
@property
def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
"""The processing pipeline consisting of (name, component) tuples. The
components are called on the Doc in order as it passes through the
pipeline.
RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
"""
return [(name, p) for name, p in self._pipeline if name not in self._disabled]
@property
def pipe_names(self) -> List[str]:
"""Get names of available pipeline components.
"""Get names of available active pipeline components.
RETURNS (List[str]): List of component name strings, in order.
"""
@ -272,7 +296,7 @@ class Language:
RETURNS (Dict[str, str]): Factory names, keyed by component names.
"""
factories = {}
for pipe_name, pipe in self.pipeline:
for pipe_name, pipe in self._pipeline:
factories[pipe_name] = self.get_pipe_meta(pipe_name).factory
return factories
@ -284,7 +308,7 @@ class Language:
RETURNS (Dict[str, List[str]]): Labels keyed by component name.
"""
labels = {}
for name, pipe in self.pipeline:
for name, pipe in self._pipeline:
if hasattr(pipe, "labels"):
labels[name] = list(pipe.labels)
return labels
@ -512,10 +536,10 @@ class Language:
DOCS: https://spacy.io/api/language#get_pipe
"""
for pipe_name, component in self.pipeline:
for pipe_name, component in self._pipeline:
if pipe_name == name:
return component
raise KeyError(Errors.E001.format(name=name, opts=self.pipe_names))
raise KeyError(Errors.E001.format(name=name, opts=self._pipe_names))
def create_pipe(
self,
@ -660,8 +684,8 @@ class Language:
err = Errors.E966.format(component=bad_val, name=name)
raise ValueError(err)
name = name if name is not None else factory_name
if name in self.pipe_names:
raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names))
if name in self._pipe_names:
raise ValueError(Errors.E007.format(name=name, opts=self._pipe_names))
if source is not None:
# We're loading the component from a model. After loading the
# component, we know its real factory name
@ -686,7 +710,7 @@ class Language:
)
pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name)
self.pipeline.insert(pipe_index, (name, pipe_component))
self._pipeline.insert(pipe_index, (name, pipe_component))
return pipe_component
def _get_pipe_index(
@ -707,32 +731,34 @@ class Language:
"""
all_args = {"before": before, "after": after, "first": first, "last": last}
if sum(arg is not None for arg in [before, after, first, last]) >= 2:
raise ValueError(Errors.E006.format(args=all_args, opts=self.pipe_names))
raise ValueError(Errors.E006.format(args=all_args, opts=self._pipe_names))
if last or not any(value is not None for value in [first, before, after]):
return len(self.pipeline)
return len(self._pipeline)
elif first:
return 0
elif isinstance(before, str):
if before not in self.pipe_names:
raise ValueError(Errors.E001.format(name=before, opts=self.pipe_names))
return self.pipe_names.index(before)
if before not in self._pipe_names:
raise ValueError(Errors.E001.format(name=before, opts=self._pipe_names))
return self._pipe_names.index(before)
elif isinstance(after, str):
if after not in self.pipe_names:
raise ValueError(Errors.E001.format(name=after, opts=self.pipe_names))
return self.pipe_names.index(after) + 1
if after not in self._pipe_names:
raise ValueError(Errors.E001.format(name=after, opts=self._pipe_names))
return self._pipe_names.index(after) + 1
# We're only accepting indices referring to components that exist
# (can't just do isinstance here because bools are instance of int, too)
elif type(before) == int:
if before >= len(self.pipeline) or before < 0:
err = Errors.E959.format(dir="before", idx=before, opts=self.pipe_names)
if before >= len(self._pipeline) or before < 0:
err = Errors.E959.format(
dir="before", idx=before, opts=self._pipe_names
)
raise ValueError(err)
return before
elif type(after) == int:
if after >= len(self.pipeline) or after < 0:
err = Errors.E959.format(dir="after", idx=after, opts=self.pipe_names)
if after >= len(self._pipeline) or after < 0:
err = Errors.E959.format(dir="after", idx=after, opts=self._pipe_names)
raise ValueError(err)
return after + 1
raise ValueError(Errors.E006.format(args=all_args, opts=self.pipe_names))
raise ValueError(Errors.E006.format(args=all_args, opts=self._pipe_names))
def has_pipe(self, name: str) -> bool:
"""Check if a component name is present in the pipeline. Equivalent to
@ -773,7 +799,7 @@ class Language:
# to Language.pipeline to make sure the configs are handled correctly
pipe_index = self.pipe_names.index(name)
self.remove_pipe(name)
if not len(self.pipeline) or pipe_index == len(self.pipeline):
if not len(self._pipeline) or pipe_index == len(self._pipeline):
# we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name, config=config, validate=validate)
else:
@ -793,12 +819,12 @@ class Language:
DOCS: https://spacy.io/api/language#rename_pipe
"""
if old_name not in self.pipe_names:
raise ValueError(Errors.E001.format(name=old_name, opts=self.pipe_names))
if new_name in self.pipe_names:
raise ValueError(Errors.E007.format(name=new_name, opts=self.pipe_names))
i = self.pipe_names.index(old_name)
self.pipeline[i] = (new_name, self.pipeline[i][1])
if old_name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=old_name, opts=self._pipe_names))
if new_name in self._pipe_names:
raise ValueError(Errors.E007.format(name=new_name, opts=self._pipe_names))
i = self._pipe_names.index(old_name)
self._pipeline[i] = (new_name, self._pipeline[i][1])
self._pipe_meta[new_name] = self._pipe_meta.pop(old_name)
self._pipe_configs[new_name] = self._pipe_configs.pop(old_name)
@ -810,15 +836,41 @@ class Language:
DOCS: https://spacy.io/api/language#remove_pipe
"""
if name not in self.pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
removed = self.pipeline.pop(self.pipe_names.index(name))
if name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
removed = self._pipeline.pop(self._pipe_names.index(name))
# We're only removing the component itself from the metas/configs here
# because factory may be used for something else
self._pipe_meta.pop(name)
self._pipe_configs.pop(name)
# Make sure the name is also removed from the set of disabled components
if name in self._disabled:
self._disabled.remove(name)
return removed
def disable_pipe(self, name: str) -> None:
"""Disable a pipeline component. The component will still exist on
the nlp object, but it won't be run as part of the pipeline.
name (str): The name of the component to disable.
"""
if name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
# TODO: should we raise if pipe is already disabled?
self._disabled.add(name)
def enable_pipe(self, name: str) -> None:
"""Enable a previously disabled pipeline component so it's run as part
of the pipeline.
name (str): The name of the component to enable.
"""
if name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
# TODO: should we raise if pipe is already enabled?
if name in self._disabled:
self._disabled.remove(name)
def __call__(
self,
text: str,
@ -1366,6 +1418,7 @@ class Language:
*,
vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
auto_fill: bool = True,
validate: bool = True,
) -> "Language":
@ -1375,7 +1428,11 @@ class Language:
config (Dict[str, Any] / Config): The loaded config.
vocab (Vocab): A Vocab object. If True, a vocab is created.
disable (Iterable[str]): List of pipeline component names to disable.
disable (Iterable[str]): Names of pipeline components to disable.
Disabled pipes will be loaded but they won't be run unless you
explicitly enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude.
Excluded components won't be loaded.
auto_fill (bool): Automatically fill in missing values in config based
on defaults and function argument annotations.
validate (bool): Validate the component config and arguments against
@ -1448,7 +1505,7 @@ class Language:
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
pipe_cfg = util.copy_config(pipeline[pipe_name])
raw_config = Config(filled["components"][pipe_name])
if pipe_name not in disable:
if pipe_name not in exclude:
if "factory" not in pipe_cfg and "source" not in pipe_cfg:
err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
raise ValueError(err)
@ -1473,6 +1530,8 @@ class Language:
)
source_name = pipe_cfg.get("component", pipe_name)
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
disabled_pipes = [*config["nlp"]["disabled"], *disable]
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
nlp.config = filled if auto_fill else config
nlp.resolved = resolved
if after_pipeline_creation is not None:
@ -1502,9 +1561,10 @@ class Language:
)
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
serializers["config.cfg"] = lambda p: self.config.to_disk(p)
for name, proc in self.pipeline:
if not hasattr(proc, "name"):
continue
for name, proc in self._pipeline:
# TODO: why did we add this?
# if not hasattr(proc, "name"):
# continue
if name in exclude:
continue
if not hasattr(proc, "to_disk"):
@ -1548,7 +1608,7 @@ class Language:
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
p, exclude=["vocab"]
)
for name, proc in self.pipeline:
for name, proc in self._pipeline:
if name in exclude:
continue
if not hasattr(proc, "from_disk"):
@ -1577,7 +1637,7 @@ class Language:
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self.pipeline:
for name, proc in self._pipeline:
if name in exclude:
continue
if not hasattr(proc, "to_bytes"):
@ -1611,7 +1671,7 @@ class Language:
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
b, exclude=["vocab"]
)
for name, proc in self.pipeline:
for name, proc in self._pipeline:
if name in exclude:
continue
if not hasattr(proc, "from_bytes"):
@ -1647,14 +1707,10 @@ class DisabledPipes(list):
def __init__(self, nlp: Language, names: List[str]) -> None:
self.nlp = nlp
self.names = names
# Important! Not deep copy -- we just want the container (but we also
# want to support people providing arbitrarily typed nlp.pipeline
# objects.)
self.original_pipeline = copy(nlp.pipeline)
self.metas = {name: nlp.get_pipe_meta(name) for name in names}
self.configs = {name: nlp.get_pipe_config(name) for name in names}
for name in self.names:
self.nlp.disable_pipe(name)
list.__init__(self)
self.extend(nlp.remove_pipe(name) for name in names)
self.extend(self.names)
def __enter__(self):
return self
@ -1664,14 +1720,10 @@ class DisabledPipes(list):
def restore(self) -> None:
"""Restore the pipeline to its state when DisabledPipes was created."""
current, self.nlp.pipeline = self.nlp.pipeline, self.original_pipeline
unexpected = [name for name, pipe in current if not self.nlp.has_pipe(name)]
if unexpected:
# Don't change the pipeline if we're raising an error.
self.nlp.pipeline = current
raise ValueError(Errors.E008.format(names=unexpected))
self.nlp._pipe_meta.update(self.metas)
self.nlp._pipe_configs.update(self.configs)
for name in self.names:
self.nlp.enable_pipe(name)
# TODO: maybe add some more checks / catch errors that may occur if
# user removes a disabled pipe in the with block
self[:] = []

View File

@ -223,6 +223,7 @@ class ConfigSchemaNlp(BaseModel):
# fmt: off
lang: StrictStr = Field(..., title="The base language to use")
pipeline: List[StrictStr] = Field(..., title="The pipeline component names in order")
disabled: List[StrictStr] = Field(..., title="Pipeline components to disable by default")
tokenizer: Callable = Field(..., title="The tokenizer to use")
load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data")
before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization")

View File

@ -249,3 +249,66 @@ def test_add_pipe_before_after():
nlp.add_pipe("entity_ruler", before=True)
with pytest.raises(ValueError):
nlp.add_pipe("entity_ruler", first=False)
def test_disable_enable_pipes():
name = "test_disable_enable_pipes"
results = {}
def make_component(name):
results[name] = ""
def component(doc):
nonlocal results
results[name] = doc.text
return doc
return component
c1 = Language.component(f"{name}1", func=make_component(f"{name}1"))
c2 = Language.component(f"{name}2", func=make_component(f"{name}2"))
nlp = Language()
nlp.add_pipe(f"{name}1")
nlp.add_pipe(f"{name}2")
assert results[f"{name}1"] == ""
assert results[f"{name}2"] == ""
assert nlp.pipeline == [(f"{name}1", c1), (f"{name}2", c2)]
assert nlp.pipe_names == [f"{name}1", f"{name}2"]
nlp.disable_pipe(f"{name}1")
assert nlp._disabled == set([f"{name}1"])
assert nlp._pipe_names == [f"{name}1", f"{name}2"]
assert nlp.pipe_names == [f"{name}2"]
assert nlp.config["nlp"]["disabled"] == [f"{name}1"]
nlp("hello")
assert results[f"{name}1"] == "" # didn't run
assert results[f"{name}2"] == "hello" # ran
nlp.enable_pipe(f"{name}1")
assert nlp._disabled == set()
assert nlp.pipe_names == [f"{name}1", f"{name}2"]
assert nlp.config["nlp"]["disabled"] == []
nlp("world")
assert results[f"{name}1"] == "world"
assert results[f"{name}2"] == "world"
nlp.disable_pipe(f"{name}2")
nlp.remove_pipe(f"{name}2")
assert nlp._pipeline == [(f"{name}1", c1)]
assert nlp.pipeline == [(f"{name}1", c1)]
assert nlp._pipe_names == [f"{name}1"]
assert nlp.pipe_names == [f"{name}1"]
assert nlp._disabled == set()
assert nlp.config["nlp"]["disabled"] == []
nlp.rename_pipe(f"{name}1", name)
assert nlp._pipeline == [(name, c1)]
assert nlp._pipe_names == [name]
nlp("!")
assert results[f"{name}1"] == "!"
assert results[f"{name}2"] == "world"
with pytest.raises(ValueError):
nlp.disable_pipe(f"{name}2")
nlp.disable_pipe(name)
assert nlp._pipe_names == [name]
assert nlp.pipe_names == []
assert nlp.config["nlp"]["disabled"] == [name]
nlp("?")
assert results[f"{name}1"] == "!"

View File

@ -161,6 +161,7 @@ def test_issue4674():
assert kb2.get_size_entities() == 1
@pytest.mark.skip(reason="API change: disable just disables, new exclude arg")
def test_issue4707():
"""Tests that disabled component names are also excluded from nlp.from_disk
by default when loading a model.

View File

@ -6,6 +6,8 @@ from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
from spacy.pipeline.tagger import DEFAULT_TAGGER_MODEL
from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
from spacy.pipeline.senter import DEFAULT_SENTER_MODEL
from spacy.lang.en import English
import spacy
from ..util import make_tempdir
@ -173,3 +175,34 @@ def test_serialize_sentencerecognizer(en_vocab):
sr_b = sr.to_bytes()
sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)
assert sr.to_bytes() == sr_d.to_bytes()
def test_serialize_pipeline_disable_enable():
nlp = English()
nlp.add_pipe("ner")
nlp.add_pipe("tagger")
nlp.disable_pipe("tagger")
assert nlp.config["nlp"]["disabled"] == ["tagger"]
config = nlp.config.copy()
nlp2 = English.from_config(config)
assert nlp2.pipe_names == ["ner"]
assert nlp2._pipe_names == ["ner", "tagger"]
assert nlp2._disabled == set(["tagger"])
assert nlp2.config["nlp"]["disabled"] == ["tagger"]
with make_tempdir() as d:
nlp2.to_disk(d)
nlp3 = spacy.load(d)
assert nlp3.pipe_names == ["ner"]
assert nlp3._pipe_names == ["ner", "tagger"]
with make_tempdir() as d:
nlp3.to_disk(d)
nlp4 = spacy.load(d, disable=["ner"])
assert nlp4.pipe_names == []
assert nlp4._pipe_names == ["ner", "tagger"]
assert nlp4._disabled == set(["ner", "tagger"])
with make_tempdir() as d:
nlp.to_disk(d)
nlp5 = spacy.load(d, exclude=["tagger"])
assert nlp5.pipe_names == ["ner"]
assert nlp5._pipe_names == ["ner"]
assert nlp5._disabled == set()

View File

@ -216,6 +216,7 @@ def load_model(
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Load a model from a package or data path.
@ -228,7 +229,7 @@ def load_model(
keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object.
"""
kwargs = {"vocab": vocab, "disable": disable, "config": config}
kwargs = {"vocab": vocab, "disable": disable, "exclude": exclude, "config": config}
if isinstance(name, str): # name or string path
if name.startswith("blank:"): # shortcut for blank model
return get_lang_class(name.replace("blank:", ""))()
@ -248,6 +249,7 @@ def load_model_from_package(
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Load a model from an installed package.
@ -255,13 +257,17 @@ def load_model_from_package(
name (str): The package name.
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable.
disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object.
"""
cls = importlib.import_module(name)
return cls.load(vocab=vocab, disable=disable, config=config)
return cls.load(vocab=vocab, disable=disable, exclude=exclude, config=config)
def load_model_from_path(
@ -270,6 +276,7 @@ def load_model_from_path(
meta: Optional[Dict[str, Any]] = None,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Load a model from a data directory path. Creates Language class with
@ -279,7 +286,11 @@ def load_model_from_path(
meta (Dict[str, Any]): Optional model meta.
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable.
disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object.
@ -290,8 +301,10 @@ def load_model_from_path(
meta = get_model_meta(model_path)
config_path = model_path / "config.cfg"
config = load_config(config_path, overrides=dict_to_dot(config))
nlp, _ = load_model_from_config(config, vocab=vocab, disable=disable)
return nlp.from_disk(model_path, exclude=disable)
nlp, _ = load_model_from_config(
config, vocab=vocab, disable=disable, exclude=exclude
)
return nlp.from_disk(model_path, exclude=exclude)
def load_model_from_config(
@ -299,6 +312,7 @@ def load_model_from_config(
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
auto_fill: bool = False,
validate: bool = True,
) -> Tuple["Language", Config]:
@ -309,7 +323,11 @@ def load_model_from_config(
meta (Dict[str, Any]): Optional model meta.
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable.
disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
auto_fill (bool): Whether to auto-fill config with missing defaults.
validate (bool): Whether to show config validation errors.
RETURNS (Language): The loaded nlp object.
@ -323,7 +341,12 @@ def load_model_from_config(
# registry, including custom subclasses provided via entry points
lang_cls = get_lang_class(nlp_config["lang"])
nlp = lang_cls.from_config(
config, vocab=vocab, disable=disable, auto_fill=auto_fill, validate=validate,
config,
vocab=vocab,
disable=disable,
exclude=exclude,
auto_fill=auto_fill,
validate=validate,
)
return nlp, nlp.resolved
@ -333,6 +356,7 @@ def load_model_from_init_py(
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Helper function to use in the `load()` method of a model package's
@ -340,7 +364,11 @@ def load_model_from_init_py(
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable.
disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object.
@ -352,7 +380,12 @@ def load_model_from_init_py(
if not model_path.exists():
raise IOError(Errors.E052.format(path=data_path))
return load_model_from_path(
data_path, vocab=vocab, meta=meta, disable=disable, config=config
data_path,
vocab=vocab,
meta=meta,
disable=disable,
exclude=exclude,
config=config,
)