Allow loaded but disabled components

This commit is contained in:
Ines Montani 2020-08-28 15:20:14 +02:00
parent adc050cdc5
commit 3ce5be4b76
8 changed files with 259 additions and 70 deletions

View File

@ -28,17 +28,22 @@ if sys.maxunicode == 65535:
def load( def load(
name: Union[str, Path], name: Union[str, Path],
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
) -> Language: ) -> Language:
"""Load a spaCy model from an installed package or a local path. """Load a spaCy model from an installed package or a local path.
name (str): Package name or model path. name (str): Package name or model path.
disable (Iterable[str]): Names of pipeline components to disable. disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation. keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object. RETURNS (Language): The loaded nlp object.
""" """
return util.load_model(name, disable=disable, config=config) return util.load_model(name, disable=disable, exclude=exclude, config=config)
def blank(name: str, **overrides) -> Language: def blank(name: str, **overrides) -> Language:

View File

@ -11,6 +11,7 @@ use_pytorch_for_gpu_memory = false
[nlp] [nlp]
lang = null lang = null
pipeline = [] pipeline = []
disabled = []
load_vocab_data = true load_vocab_data = true
before_creation = null before_creation = null
after_creation = null after_creation = null

View File

@ -6,7 +6,7 @@ import itertools
import weakref import weakref
import functools import functools
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy, deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
import warnings import warnings
from thinc.api import get_current_ops, Config, require_gpu, Optimizer from thinc.api import get_current_ops, Config, require_gpu, Optimizer
@ -159,7 +159,8 @@ class Language:
self.vocab: Vocab = vocab self.vocab: Vocab = vocab
if self.lang is None: if self.lang is None:
self.lang = self.vocab.lang self.lang = self.vocab.lang
self.pipeline = [] self._pipeline = []
self._disabled = set()
self.max_length = max_length self.max_length = max_length
self.resolved = {} self.resolved = {}
# Create the default tokenizer from the default config # Create the default tokenizer from the default config
@ -210,6 +211,7 @@ class Language:
# TODO: Adding this back to prevent breaking people's code etc., but # TODO: Adding this back to prevent breaking people's code etc., but
# we should consider removing it # we should consider removing it
self._meta["pipeline"] = self.pipe_names self._meta["pipeline"] = self.pipe_names
self._meta["disabled"] = list(self._disabled)
return self._meta return self._meta
@meta.setter @meta.setter
@ -232,13 +234,14 @@ class Language:
# we can populate the config again later # we can populate the config again later
pipeline = {} pipeline = {}
score_weights = [] score_weights = []
for pipe_name in self.pipe_names: for pipe_name in self._pipe_names:
pipe_meta = self.get_pipe_meta(pipe_name) pipe_meta = self.get_pipe_meta(pipe_name)
pipe_config = self.get_pipe_config(pipe_name) pipe_config = self.get_pipe_config(pipe_name)
pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config} pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
if pipe_meta.default_score_weights: if pipe_meta.default_score_weights:
score_weights.append(pipe_meta.default_score_weights) score_weights.append(pipe_meta.default_score_weights)
self._config["nlp"]["pipeline"] = self.pipe_names self._config["nlp"]["pipeline"] = self._pipe_names
self._config["nlp"]["disabled"] = list(self._disabled)
self._config["components"] = pipeline self._config["components"] = pipeline
self._config["training"]["score_weights"] = combine_score_weights(score_weights) self._config["training"]["score_weights"] = combine_score_weights(score_weights)
if not srsly.is_json_serializable(self._config): if not srsly.is_json_serializable(self._config):
@ -257,9 +260,30 @@ class Language:
""" """
return list(self.factories.keys()) return list(self.factories.keys())
@property
def _pipe_names(self) -> List[str]:
"""Get the names of the available pipeline components. Includes all
active and inactive pipeline components.
RETURNS (List[str]): List of component name strings, in order.
"""
# TODO: Should we make this available via a user-facing property? (The
# underscore distinction works well internally)
return [pipe_name for pipe_name, _ in self._pipeline]
@property
def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
"""The processing pipeline consisting of (name, component) tuples. The
components are called on the Doc in order as it passes through the
pipeline.
RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
"""
return [(name, p) for name, p in self._pipeline if name not in self._disabled]
@property @property
def pipe_names(self) -> List[str]: def pipe_names(self) -> List[str]:
"""Get names of available pipeline components. """Get names of available active pipeline components.
RETURNS (List[str]): List of component name strings, in order. RETURNS (List[str]): List of component name strings, in order.
""" """
@ -272,7 +296,7 @@ class Language:
RETURNS (Dict[str, str]): Factory names, keyed by component names. RETURNS (Dict[str, str]): Factory names, keyed by component names.
""" """
factories = {} factories = {}
for pipe_name, pipe in self.pipeline: for pipe_name, pipe in self._pipeline:
factories[pipe_name] = self.get_pipe_meta(pipe_name).factory factories[pipe_name] = self.get_pipe_meta(pipe_name).factory
return factories return factories
@ -284,7 +308,7 @@ class Language:
RETURNS (Dict[str, List[str]]): Labels keyed by component name. RETURNS (Dict[str, List[str]]): Labels keyed by component name.
""" """
labels = {} labels = {}
for name, pipe in self.pipeline: for name, pipe in self._pipeline:
if hasattr(pipe, "labels"): if hasattr(pipe, "labels"):
labels[name] = list(pipe.labels) labels[name] = list(pipe.labels)
return labels return labels
@ -512,10 +536,10 @@ class Language:
DOCS: https://spacy.io/api/language#get_pipe DOCS: https://spacy.io/api/language#get_pipe
""" """
for pipe_name, component in self.pipeline: for pipe_name, component in self._pipeline:
if pipe_name == name: if pipe_name == name:
return component return component
raise KeyError(Errors.E001.format(name=name, opts=self.pipe_names)) raise KeyError(Errors.E001.format(name=name, opts=self._pipe_names))
def create_pipe( def create_pipe(
self, self,
@ -660,8 +684,8 @@ class Language:
err = Errors.E966.format(component=bad_val, name=name) err = Errors.E966.format(component=bad_val, name=name)
raise ValueError(err) raise ValueError(err)
name = name if name is not None else factory_name name = name if name is not None else factory_name
if name in self.pipe_names: if name in self._pipe_names:
raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names)) raise ValueError(Errors.E007.format(name=name, opts=self._pipe_names))
if source is not None: if source is not None:
# We're loading the component from a model. After loading the # We're loading the component from a model. After loading the
# component, we know its real factory name # component, we know its real factory name
@ -686,7 +710,7 @@ class Language:
) )
pipe_index = self._get_pipe_index(before, after, first, last) pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name) self._pipe_meta[name] = self.get_factory_meta(factory_name)
self.pipeline.insert(pipe_index, (name, pipe_component)) self._pipeline.insert(pipe_index, (name, pipe_component))
return pipe_component return pipe_component
def _get_pipe_index( def _get_pipe_index(
@ -707,32 +731,34 @@ class Language:
""" """
all_args = {"before": before, "after": after, "first": first, "last": last} all_args = {"before": before, "after": after, "first": first, "last": last}
if sum(arg is not None for arg in [before, after, first, last]) >= 2: if sum(arg is not None for arg in [before, after, first, last]) >= 2:
raise ValueError(Errors.E006.format(args=all_args, opts=self.pipe_names)) raise ValueError(Errors.E006.format(args=all_args, opts=self._pipe_names))
if last or not any(value is not None for value in [first, before, after]): if last or not any(value is not None for value in [first, before, after]):
return len(self.pipeline) return len(self._pipeline)
elif first: elif first:
return 0 return 0
elif isinstance(before, str): elif isinstance(before, str):
if before not in self.pipe_names: if before not in self._pipe_names:
raise ValueError(Errors.E001.format(name=before, opts=self.pipe_names)) raise ValueError(Errors.E001.format(name=before, opts=self._pipe_names))
return self.pipe_names.index(before) return self._pipe_names.index(before)
elif isinstance(after, str): elif isinstance(after, str):
if after not in self.pipe_names: if after not in self._pipe_names:
raise ValueError(Errors.E001.format(name=after, opts=self.pipe_names)) raise ValueError(Errors.E001.format(name=after, opts=self._pipe_names))
return self.pipe_names.index(after) + 1 return self._pipe_names.index(after) + 1
# We're only accepting indices referring to components that exist # We're only accepting indices referring to components that exist
# (can't just do isinstance here because bools are instance of int, too) # (can't just do isinstance here because bools are instance of int, too)
elif type(before) == int: elif type(before) == int:
if before >= len(self.pipeline) or before < 0: if before >= len(self._pipeline) or before < 0:
err = Errors.E959.format(dir="before", idx=before, opts=self.pipe_names) err = Errors.E959.format(
dir="before", idx=before, opts=self._pipe_names
)
raise ValueError(err) raise ValueError(err)
return before return before
elif type(after) == int: elif type(after) == int:
if after >= len(self.pipeline) or after < 0: if after >= len(self._pipeline) or after < 0:
err = Errors.E959.format(dir="after", idx=after, opts=self.pipe_names) err = Errors.E959.format(dir="after", idx=after, opts=self._pipe_names)
raise ValueError(err) raise ValueError(err)
return after + 1 return after + 1
raise ValueError(Errors.E006.format(args=all_args, opts=self.pipe_names)) raise ValueError(Errors.E006.format(args=all_args, opts=self._pipe_names))
def has_pipe(self, name: str) -> bool: def has_pipe(self, name: str) -> bool:
"""Check if a component name is present in the pipeline. Equivalent to """Check if a component name is present in the pipeline. Equivalent to
@ -773,7 +799,7 @@ class Language:
# to Language.pipeline to make sure the configs are handled correctly # to Language.pipeline to make sure the configs are handled correctly
pipe_index = self.pipe_names.index(name) pipe_index = self.pipe_names.index(name)
self.remove_pipe(name) self.remove_pipe(name)
if not len(self.pipeline) or pipe_index == len(self.pipeline): if not len(self._pipeline) or pipe_index == len(self._pipeline):
# we have no components to insert before/after, or we're replacing the last component # we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name, config=config, validate=validate) self.add_pipe(factory_name, name=name, config=config, validate=validate)
else: else:
@ -793,12 +819,12 @@ class Language:
DOCS: https://spacy.io/api/language#rename_pipe DOCS: https://spacy.io/api/language#rename_pipe
""" """
if old_name not in self.pipe_names: if old_name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=old_name, opts=self.pipe_names)) raise ValueError(Errors.E001.format(name=old_name, opts=self._pipe_names))
if new_name in self.pipe_names: if new_name in self._pipe_names:
raise ValueError(Errors.E007.format(name=new_name, opts=self.pipe_names)) raise ValueError(Errors.E007.format(name=new_name, opts=self._pipe_names))
i = self.pipe_names.index(old_name) i = self._pipe_names.index(old_name)
self.pipeline[i] = (new_name, self.pipeline[i][1]) self._pipeline[i] = (new_name, self._pipeline[i][1])
self._pipe_meta[new_name] = self._pipe_meta.pop(old_name) self._pipe_meta[new_name] = self._pipe_meta.pop(old_name)
self._pipe_configs[new_name] = self._pipe_configs.pop(old_name) self._pipe_configs[new_name] = self._pipe_configs.pop(old_name)
@ -810,15 +836,41 @@ class Language:
DOCS: https://spacy.io/api/language#remove_pipe DOCS: https://spacy.io/api/language#remove_pipe
""" """
if name not in self.pipe_names: if name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names)) raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
removed = self.pipeline.pop(self.pipe_names.index(name)) removed = self._pipeline.pop(self._pipe_names.index(name))
# We're only removing the component itself from the metas/configs here # We're only removing the component itself from the metas/configs here
# because factory may be used for something else # because factory may be used for something else
self._pipe_meta.pop(name) self._pipe_meta.pop(name)
self._pipe_configs.pop(name) self._pipe_configs.pop(name)
# Make sure the name is also removed from the set of disabled components
if name in self._disabled:
self._disabled.remove(name)
return removed return removed
def disable_pipe(self, name: str) -> None:
"""Disable a pipeline component. The component will still exist on
the nlp object, but it won't be run as part of the pipeline.
name (str): The name of the component to disable.
"""
if name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
# TODO: should we raise if pipe is already disabled?
self._disabled.add(name)
def enable_pipe(self, name: str) -> None:
"""Enable a previously disabled pipeline component so it's run as part
of the pipeline.
name (str): The name of the component to enable.
"""
if name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
# TODO: should we raise if pipe is already enabled?
if name in self._disabled:
self._disabled.remove(name)
def __call__( def __call__(
self, self,
text: str, text: str,
@ -1366,6 +1418,7 @@ class Language:
*, *,
vocab: Union[Vocab, bool] = True, vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
auto_fill: bool = True, auto_fill: bool = True,
validate: bool = True, validate: bool = True,
) -> "Language": ) -> "Language":
@ -1375,7 +1428,11 @@ class Language:
config (Dict[str, Any] / Config): The loaded config. config (Dict[str, Any] / Config): The loaded config.
vocab (Vocab): A Vocab object. If True, a vocab is created. vocab (Vocab): A Vocab object. If True, a vocab is created.
disable (Iterable[str]): List of pipeline component names to disable. disable (Iterable[str]): Names of pipeline components to disable.
Disabled pipes will be loaded but they won't be run unless you
explicitly enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude.
Excluded components won't be loaded.
auto_fill (bool): Automatically fill in missing values in config based auto_fill (bool): Automatically fill in missing values in config based
on defaults and function argument annotations. on defaults and function argument annotations.
validate (bool): Validate the component config and arguments against validate (bool): Validate the component config and arguments against
@ -1448,7 +1505,7 @@ class Language:
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts)) raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
pipe_cfg = util.copy_config(pipeline[pipe_name]) pipe_cfg = util.copy_config(pipeline[pipe_name])
raw_config = Config(filled["components"][pipe_name]) raw_config = Config(filled["components"][pipe_name])
if pipe_name not in disable: if pipe_name not in exclude:
if "factory" not in pipe_cfg and "source" not in pipe_cfg: if "factory" not in pipe_cfg and "source" not in pipe_cfg:
err = Errors.E984.format(name=pipe_name, config=pipe_cfg) err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
raise ValueError(err) raise ValueError(err)
@ -1473,6 +1530,8 @@ class Language:
) )
source_name = pipe_cfg.get("component", pipe_name) source_name = pipe_cfg.get("component", pipe_name)
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name) nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
disabled_pipes = [*config["nlp"]["disabled"], *disable]
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
nlp.config = filled if auto_fill else config nlp.config = filled if auto_fill else config
nlp.resolved = resolved nlp.resolved = resolved
if after_pipeline_creation is not None: if after_pipeline_creation is not None:
@ -1502,9 +1561,10 @@ class Language:
) )
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta) serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
serializers["config.cfg"] = lambda p: self.config.to_disk(p) serializers["config.cfg"] = lambda p: self.config.to_disk(p)
for name, proc in self.pipeline: for name, proc in self._pipeline:
if not hasattr(proc, "name"): # TODO: why did we add this?
continue # if not hasattr(proc, "name"):
# continue
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "to_disk"): if not hasattr(proc, "to_disk"):
@ -1548,7 +1608,7 @@ class Language:
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk( deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
p, exclude=["vocab"] p, exclude=["vocab"]
) )
for name, proc in self.pipeline: for name, proc in self._pipeline:
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "from_disk"): if not hasattr(proc, "from_disk"):
@ -1577,7 +1637,7 @@ class Language:
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta) serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
serializers["config.cfg"] = lambda: self.config.to_bytes() serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self.pipeline: for name, proc in self._pipeline:
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "to_bytes"): if not hasattr(proc, "to_bytes"):
@ -1611,7 +1671,7 @@ class Language:
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes( deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
b, exclude=["vocab"] b, exclude=["vocab"]
) )
for name, proc in self.pipeline: for name, proc in self._pipeline:
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "from_bytes"): if not hasattr(proc, "from_bytes"):
@ -1647,14 +1707,10 @@ class DisabledPipes(list):
def __init__(self, nlp: Language, names: List[str]) -> None: def __init__(self, nlp: Language, names: List[str]) -> None:
self.nlp = nlp self.nlp = nlp
self.names = names self.names = names
# Important! Not deep copy -- we just want the container (but we also for name in self.names:
# want to support people providing arbitrarily typed nlp.pipeline self.nlp.disable_pipe(name)
# objects.)
self.original_pipeline = copy(nlp.pipeline)
self.metas = {name: nlp.get_pipe_meta(name) for name in names}
self.configs = {name: nlp.get_pipe_config(name) for name in names}
list.__init__(self) list.__init__(self)
self.extend(nlp.remove_pipe(name) for name in names) self.extend(self.names)
def __enter__(self): def __enter__(self):
return self return self
@ -1664,14 +1720,10 @@ class DisabledPipes(list):
def restore(self) -> None: def restore(self) -> None:
"""Restore the pipeline to its state when DisabledPipes was created.""" """Restore the pipeline to its state when DisabledPipes was created."""
current, self.nlp.pipeline = self.nlp.pipeline, self.original_pipeline for name in self.names:
unexpected = [name for name, pipe in current if not self.nlp.has_pipe(name)] self.nlp.enable_pipe(name)
if unexpected: # TODO: maybe add some more checks / catch errors that may occur if
# Don't change the pipeline if we're raising an error. # user removes a disabled pipe in the with block
self.nlp.pipeline = current
raise ValueError(Errors.E008.format(names=unexpected))
self.nlp._pipe_meta.update(self.metas)
self.nlp._pipe_configs.update(self.configs)
self[:] = [] self[:] = []

View File

@ -223,6 +223,7 @@ class ConfigSchemaNlp(BaseModel):
# fmt: off # fmt: off
lang: StrictStr = Field(..., title="The base language to use") lang: StrictStr = Field(..., title="The base language to use")
pipeline: List[StrictStr] = Field(..., title="The pipeline component names in order") pipeline: List[StrictStr] = Field(..., title="The pipeline component names in order")
disabled: List[StrictStr] = Field(..., title="Pipeline components to disable by default")
tokenizer: Callable = Field(..., title="The tokenizer to use") tokenizer: Callable = Field(..., title="The tokenizer to use")
load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data") load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data")
before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization") before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization")

View File

@ -249,3 +249,66 @@ def test_add_pipe_before_after():
nlp.add_pipe("entity_ruler", before=True) nlp.add_pipe("entity_ruler", before=True)
with pytest.raises(ValueError): with pytest.raises(ValueError):
nlp.add_pipe("entity_ruler", first=False) nlp.add_pipe("entity_ruler", first=False)
def test_disable_enable_pipes():
name = "test_disable_enable_pipes"
results = {}
def make_component(name):
results[name] = ""
def component(doc):
nonlocal results
results[name] = doc.text
return doc
return component
c1 = Language.component(f"{name}1", func=make_component(f"{name}1"))
c2 = Language.component(f"{name}2", func=make_component(f"{name}2"))
nlp = Language()
nlp.add_pipe(f"{name}1")
nlp.add_pipe(f"{name}2")
assert results[f"{name}1"] == ""
assert results[f"{name}2"] == ""
assert nlp.pipeline == [(f"{name}1", c1), (f"{name}2", c2)]
assert nlp.pipe_names == [f"{name}1", f"{name}2"]
nlp.disable_pipe(f"{name}1")
assert nlp._disabled == set([f"{name}1"])
assert nlp._pipe_names == [f"{name}1", f"{name}2"]
assert nlp.pipe_names == [f"{name}2"]
assert nlp.config["nlp"]["disabled"] == [f"{name}1"]
nlp("hello")
assert results[f"{name}1"] == "" # didn't run
assert results[f"{name}2"] == "hello" # ran
nlp.enable_pipe(f"{name}1")
assert nlp._disabled == set()
assert nlp.pipe_names == [f"{name}1", f"{name}2"]
assert nlp.config["nlp"]["disabled"] == []
nlp("world")
assert results[f"{name}1"] == "world"
assert results[f"{name}2"] == "world"
nlp.disable_pipe(f"{name}2")
nlp.remove_pipe(f"{name}2")
assert nlp._pipeline == [(f"{name}1", c1)]
assert nlp.pipeline == [(f"{name}1", c1)]
assert nlp._pipe_names == [f"{name}1"]
assert nlp.pipe_names == [f"{name}1"]
assert nlp._disabled == set()
assert nlp.config["nlp"]["disabled"] == []
nlp.rename_pipe(f"{name}1", name)
assert nlp._pipeline == [(name, c1)]
assert nlp._pipe_names == [name]
nlp("!")
assert results[f"{name}1"] == "!"
assert results[f"{name}2"] == "world"
with pytest.raises(ValueError):
nlp.disable_pipe(f"{name}2")
nlp.disable_pipe(name)
assert nlp._pipe_names == [name]
assert nlp.pipe_names == []
assert nlp.config["nlp"]["disabled"] == [name]
nlp("?")
assert results[f"{name}1"] == "!"

View File

@ -161,6 +161,7 @@ def test_issue4674():
assert kb2.get_size_entities() == 1 assert kb2.get_size_entities() == 1
@pytest.mark.skip(reason="API change: disable just disables, new exclude arg")
def test_issue4707(): def test_issue4707():
"""Tests that disabled component names are also excluded from nlp.from_disk """Tests that disabled component names are also excluded from nlp.from_disk
by default when loading a model. by default when loading a model.

View File

@ -6,6 +6,8 @@ from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
from spacy.pipeline.tagger import DEFAULT_TAGGER_MODEL from spacy.pipeline.tagger import DEFAULT_TAGGER_MODEL
from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
from spacy.pipeline.senter import DEFAULT_SENTER_MODEL from spacy.pipeline.senter import DEFAULT_SENTER_MODEL
from spacy.lang.en import English
import spacy
from ..util import make_tempdir from ..util import make_tempdir
@ -173,3 +175,34 @@ def test_serialize_sentencerecognizer(en_vocab):
sr_b = sr.to_bytes() sr_b = sr.to_bytes()
sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b) sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)
assert sr.to_bytes() == sr_d.to_bytes() assert sr.to_bytes() == sr_d.to_bytes()
def test_serialize_pipeline_disable_enable():
nlp = English()
nlp.add_pipe("ner")
nlp.add_pipe("tagger")
nlp.disable_pipe("tagger")
assert nlp.config["nlp"]["disabled"] == ["tagger"]
config = nlp.config.copy()
nlp2 = English.from_config(config)
assert nlp2.pipe_names == ["ner"]
assert nlp2._pipe_names == ["ner", "tagger"]
assert nlp2._disabled == set(["tagger"])
assert nlp2.config["nlp"]["disabled"] == ["tagger"]
with make_tempdir() as d:
nlp2.to_disk(d)
nlp3 = spacy.load(d)
assert nlp3.pipe_names == ["ner"]
assert nlp3._pipe_names == ["ner", "tagger"]
with make_tempdir() as d:
nlp3.to_disk(d)
nlp4 = spacy.load(d, disable=["ner"])
assert nlp4.pipe_names == []
assert nlp4._pipe_names == ["ner", "tagger"]
assert nlp4._disabled == set(["ner", "tagger"])
with make_tempdir() as d:
nlp.to_disk(d)
nlp5 = spacy.load(d, exclude=["tagger"])
assert nlp5.pipe_names == ["ner"]
assert nlp5._pipe_names == ["ner"]
assert nlp5._disabled == set()

View File

@ -216,6 +216,7 @@ def load_model(
*, *,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Load a model from a package or data path. """Load a model from a package or data path.
@ -228,7 +229,7 @@ def load_model(
keyed by section values in dot notation. keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object. RETURNS (Language): The loaded nlp object.
""" """
kwargs = {"vocab": vocab, "disable": disable, "config": config} kwargs = {"vocab": vocab, "disable": disable, "exclude": exclude, "config": config}
if isinstance(name, str): # name or string path if isinstance(name, str): # name or string path
if name.startswith("blank:"): # shortcut for blank model if name.startswith("blank:"): # shortcut for blank model
return get_lang_class(name.replace("blank:", ""))() return get_lang_class(name.replace("blank:", ""))()
@ -248,6 +249,7 @@ def load_model_from_package(
*, *,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Load a model from an installed package. """Load a model from an installed package.
@ -255,13 +257,17 @@ def load_model_from_package(
name (str): The package name. name (str): The package name.
vocab (Vocab / True): Optional vocab to pass in on initialization. If True, vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created. a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable. disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation. keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object. RETURNS (Language): The loaded nlp object.
""" """
cls = importlib.import_module(name) cls = importlib.import_module(name)
return cls.load(vocab=vocab, disable=disable, config=config) return cls.load(vocab=vocab, disable=disable, exclude=exclude, config=config)
def load_model_from_path( def load_model_from_path(
@ -270,6 +276,7 @@ def load_model_from_path(
meta: Optional[Dict[str, Any]] = None, meta: Optional[Dict[str, Any]] = None,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Load a model from a data directory path. Creates Language class with """Load a model from a data directory path. Creates Language class with
@ -279,7 +286,11 @@ def load_model_from_path(
meta (Dict[str, Any]): Optional model meta. meta (Dict[str, Any]): Optional model meta.
vocab (Vocab / True): Optional vocab to pass in on initialization. If True, vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created. a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable. disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation. keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object. RETURNS (Language): The loaded nlp object.
@ -290,8 +301,10 @@ def load_model_from_path(
meta = get_model_meta(model_path) meta = get_model_meta(model_path)
config_path = model_path / "config.cfg" config_path = model_path / "config.cfg"
config = load_config(config_path, overrides=dict_to_dot(config)) config = load_config(config_path, overrides=dict_to_dot(config))
nlp, _ = load_model_from_config(config, vocab=vocab, disable=disable) nlp, _ = load_model_from_config(
return nlp.from_disk(model_path, exclude=disable) config, vocab=vocab, disable=disable, exclude=exclude
)
return nlp.from_disk(model_path, exclude=exclude)
def load_model_from_config( def load_model_from_config(
@ -299,6 +312,7 @@ def load_model_from_config(
*, *,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
auto_fill: bool = False, auto_fill: bool = False,
validate: bool = True, validate: bool = True,
) -> Tuple["Language", Config]: ) -> Tuple["Language", Config]:
@ -309,7 +323,11 @@ def load_model_from_config(
meta (Dict[str, Any]): Optional model meta. meta (Dict[str, Any]): Optional model meta.
vocab (Vocab / True): Optional vocab to pass in on initialization. If True, vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created. a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable. disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
auto_fill (bool): Whether to auto-fill config with missing defaults. auto_fill (bool): Whether to auto-fill config with missing defaults.
validate (bool): Whether to show config validation errors. validate (bool): Whether to show config validation errors.
RETURNS (Language): The loaded nlp object. RETURNS (Language): The loaded nlp object.
@ -323,7 +341,12 @@ def load_model_from_config(
# registry, including custom subclasses provided via entry points # registry, including custom subclasses provided via entry points
lang_cls = get_lang_class(nlp_config["lang"]) lang_cls = get_lang_class(nlp_config["lang"])
nlp = lang_cls.from_config( nlp = lang_cls.from_config(
config, vocab=vocab, disable=disable, auto_fill=auto_fill, validate=validate, config,
vocab=vocab,
disable=disable,
exclude=exclude,
auto_fill=auto_fill,
validate=validate,
) )
return nlp, nlp.resolved return nlp, nlp.resolved
@ -333,6 +356,7 @@ def load_model_from_init_py(
*, *,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Helper function to use in the `load()` method of a model package's """Helper function to use in the `load()` method of a model package's
@ -340,7 +364,11 @@ def load_model_from_init_py(
vocab (Vocab / True): Optional vocab to pass in on initialization. If True, vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created. a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable. disable (Iterable[str]): Names of pipeline components to disable. Disabled
pipes will be loaded but they won't be run unless you explicitly
enable them by calling nlp.enable_pipe.
exclude (Iterable[str]): Names of pipeline components to exclude. Excluded
components won't be loaded.
config (Dict[str, Any] / Config): Config overrides as nested dict or dict config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by section values in dot notation. keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object. RETURNS (Language): The loaded nlp object.
@ -352,7 +380,12 @@ def load_model_from_init_py(
if not model_path.exists(): if not model_path.exists():
raise IOError(Errors.E052.format(path=data_path)) raise IOError(Errors.E052.format(path=data_path))
return load_model_from_path( return load_model_from_path(
data_path, vocab=vocab, meta=meta, disable=disable, config=config data_path,
vocab=vocab,
meta=meta,
disable=disable,
exclude=exclude,
config=config,
) )