Use frozen list with custom errors

We don't want to break backwards compatibility too much, but we also want to provide the best possible UX.
This commit is contained in:
Ines Montani 2020-08-29 15:20:11 +02:00
parent 6520d1a1df
commit 34146750d4
13 changed files with 195 additions and 78 deletions

View File

@ -27,8 +27,8 @@ if sys.maxunicode == 65535:
def load( def load(
name: Union[str, Path], name: Union[str, Path],
disable: Iterable[str] = tuple(), disable: Iterable[str] = util.SimpleFrozenList(),
exclude: Iterable[str] = tuple(), exclude: Iterable[str] = util.SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
) -> Language: ) -> Language:
"""Load a spaCy model from an installed package or a local path. """Load a spaCy model from an installed package or a local path.

View File

@ -1,6 +1,6 @@
"""This module contains helpers and subcommands for integrating spaCy projects """This module contains helpers and subcommands for integrating spaCy projects
with Data Version Controk (DVC). https://dvc.org""" with Data Version Controk (DVC). https://dvc.org"""
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional, Iterable
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from wasabi import msg from wasabi import msg
@ -8,6 +8,7 @@ from wasabi import msg
from .._util import PROJECT_FILE, load_project_config, get_hash, project_cli from .._util import PROJECT_FILE, load_project_config, get_hash, project_cli
from .._util import Arg, Opt, NAME, COMMAND from .._util import Arg, Opt, NAME, COMMAND
from ...util import working_dir, split_command, join_command, run_command from ...util import working_dir, split_command, join_command, run_command
from ...util import SimpleFrozenList
DVC_CONFIG = "dvc.yaml" DVC_CONFIG = "dvc.yaml"
@ -130,7 +131,7 @@ def update_dvc_config(
def run_dvc_commands( def run_dvc_commands(
commands: List[str] = tuple(), flags: Dict[str, bool] = {}, commands: Iterable[str] = SimpleFrozenList(), flags: Dict[str, bool] = {},
) -> None: ) -> None:
"""Run a sequence of DVC commands in a subprocess, in order. """Run a sequence of DVC commands in a subprocess, in order.

View File

@ -1,10 +1,11 @@
from typing import Optional, List, Dict, Sequence, Any from typing import Optional, List, Dict, Sequence, Any, Iterable
from pathlib import Path from pathlib import Path
from wasabi import msg from wasabi import msg
import sys import sys
import srsly import srsly
from ...util import working_dir, run_command, split_command, is_cwd, join_command from ...util import working_dir, run_command, split_command, is_cwd, join_command
from ...util import SimpleFrozenList
from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash
from .._util import get_checksum, project_cli, Arg, Opt, COMMAND from .._util import get_checksum, project_cli, Arg, Opt, COMMAND
@ -115,7 +116,9 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
def run_commands( def run_commands(
commands: List[str] = tuple(), silent: bool = False, dry: bool = False, commands: Iterable[str] = SimpleFrozenList(),
silent: bool = False,
dry: bool = False,
) -> None: ) -> None:
"""Run a sequence of commands in a subprocess, in order. """Run a sequence of commands in a subprocess, in order.

View File

@ -472,6 +472,13 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
# E926: used when a read-only computed property of `nlp` is mutated in place.
# `{attr}` is filled with the property name via `.format(attr=...)`.
E926 = ("It looks like you're trying to modify nlp.{attr} directly. This "
        "doesn't work because it's an immutable computed property. If you "
        "need to modify the pipeline, use the built-in methods like "
        "nlp.add_pipe, nlp.remove_pipe, nlp.disable_pipe or nlp.enable_pipe "
        "instead.")
# E927: generic frozen-list message. The original text was missing the period
# after "frozen list", which made the two sentences run together.
E927 = ("Can't write to frozen list. Maybe you're trying to modify a computed "
        "property or default function argument?")
E928 = ("A 'KnowledgeBase' should be written to / read from a file, but the " E928 = ("A 'KnowledgeBase' should be written to / read from a file, but the "
"provided argument {loc} is an existing directory.") "provided argument {loc} is an existing directory.")
E929 = ("A 'KnowledgeBase' could not be read from {loc} - the path does " E929 = ("A 'KnowledgeBase' could not be read from {loc} - the path does "

View File

@ -20,7 +20,7 @@ from .vocab import Vocab, create_vocab
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .gold import Example, validate_examples from .gold import Example, validate_examples
from .scorer import Scorer from .scorer import Scorer
from .util import create_default_optimizer, registry from .util import create_default_optimizer, registry, SimpleFrozenList
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
@ -159,7 +159,7 @@ class Language:
self.vocab: Vocab = vocab self.vocab: Vocab = vocab
if self.lang is None: if self.lang is None:
self.lang = self.vocab.lang self.lang = self.vocab.lang
self.components = [] self._components = []
self._disabled = set() self._disabled = set()
self.max_length = max_length self.max_length = max_length
self.resolved = {} self.resolved = {}
@ -207,11 +207,11 @@ class Language:
"keys": self.vocab.vectors.n_keys, "keys": self.vocab.vectors.n_keys,
"name": self.vocab.vectors.name, "name": self.vocab.vectors.name,
} }
self._meta["labels"] = self.pipe_labels self._meta["labels"] = dict(self.pipe_labels)
# TODO: Adding this back to prevent breaking people's code etc., but # TODO: Adding this back to prevent breaking people's code etc., but
# we should consider removing it # we should consider removing it
self._meta["pipeline"] = self.pipe_names self._meta["pipeline"] = list(self.pipe_names)
self._meta["disabled"] = self.disabled self._meta["disabled"] = list(self.disabled)
return self._meta return self._meta
@meta.setter @meta.setter
@ -240,8 +240,8 @@ class Language:
pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config} pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
if pipe_meta.default_score_weights: if pipe_meta.default_score_weights:
score_weights.append(pipe_meta.default_score_weights) score_weights.append(pipe_meta.default_score_weights)
self._config["nlp"]["pipeline"] = self.component_names self._config["nlp"]["pipeline"] = list(self.component_names)
self._config["nlp"]["disabled"] = self.disabled self._config["nlp"]["disabled"] = list(self.disabled)
self._config["components"] = pipeline self._config["components"] = pipeline
self._config["training"]["score_weights"] = combine_score_weights(score_weights) self._config["training"]["score_weights"] = combine_score_weights(score_weights)
if not srsly.is_json_serializable(self._config): if not srsly.is_json_serializable(self._config):
@ -260,7 +260,8 @@ class Language:
""" """
# Make sure the disabled components are returned in the order they # Make sure the disabled components are returned in the order they
# appear in the pipeline (which isn't guaranteed by the set) # appear in the pipeline (which isn't guaranteed by the set)
return [name for name, _ in self.components if name in self._disabled] names = [name for name, _ in self._components if name in self._disabled]
return SimpleFrozenList(names, error=Errors.E926.format(attr="disabled"))
@property @property
def factory_names(self) -> List[str]: def factory_names(self) -> List[str]:
@ -268,7 +269,17 @@ class Language:
RETURNS (List[str]): The factory names. RETURNS (List[str]): The factory names.
""" """
return list(self.factories.keys()) names = list(self.factories.keys())
return SimpleFrozenList(names)
@property
def components(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
    """Get all (name, component) tuples in the pipeline, including the
    currently disabled components.

    RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The components,
        wrapped in a SimpleFrozenList built from self._components.
    """
    # The E926 message names the property the caller tried to modify.
    frozen_error = Errors.E926.format(attr="components")
    return SimpleFrozenList(self._components, error=frozen_error)
@property @property
def component_names(self) -> List[str]: def component_names(self) -> List[str]:
@ -277,7 +288,8 @@ class Language:
RETURNS (List[str]): List of component name strings, in order. RETURNS (List[str]): List of component name strings, in order.
""" """
return [pipe_name for pipe_name, _ in self.components] names = [pipe_name for pipe_name, _ in self._components]
return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names"))
@property @property
def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]: def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
@ -287,7 +299,8 @@ class Language:
RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline. RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
""" """
return [(name, p) for name, p in self.components if name not in self._disabled] pipes = [(n, p) for n, p in self._components if n not in self._disabled]
return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline"))
@property @property
def pipe_names(self) -> List[str]: def pipe_names(self) -> List[str]:
@ -295,7 +308,8 @@ class Language:
RETURNS (List[str]): List of component name strings, in order. RETURNS (List[str]): List of component name strings, in order.
""" """
return [pipe_name for pipe_name, _ in self.pipeline] names = [pipe_name for pipe_name, _ in self.pipeline]
return SimpleFrozenList(names, error=Errors.E926.format(attr="pipe_names"))
@property @property
def pipe_factories(self) -> Dict[str, str]: def pipe_factories(self) -> Dict[str, str]:
@ -304,9 +318,9 @@ class Language:
RETURNS (Dict[str, str]): Factory names, keyed by component names. RETURNS (Dict[str, str]): Factory names, keyed by component names.
""" """
factories = {} factories = {}
for pipe_name, pipe in self.components: for pipe_name, pipe in self._components:
factories[pipe_name] = self.get_pipe_meta(pipe_name).factory factories[pipe_name] = self.get_pipe_meta(pipe_name).factory
return factories return SimpleFrozenDict(factories)
@property @property
def pipe_labels(self) -> Dict[str, List[str]]: def pipe_labels(self) -> Dict[str, List[str]]:
@ -316,10 +330,10 @@ class Language:
RETURNS (Dict[str, List[str]]): Labels keyed by component name. RETURNS (Dict[str, List[str]]): Labels keyed by component name.
""" """
labels = {} labels = {}
for name, pipe in self.components: for name, pipe in self._components:
if hasattr(pipe, "labels"): if hasattr(pipe, "labels"):
labels[name] = list(pipe.labels) labels[name] = list(pipe.labels)
return labels return SimpleFrozenDict(labels)
@classmethod @classmethod
def has_factory(cls, name: str) -> bool: def has_factory(cls, name: str) -> bool:
@ -390,10 +404,10 @@ class Language:
name: str, name: str,
*, *,
default_config: Dict[str, Any] = SimpleFrozenDict(), default_config: Dict[str, Any] = SimpleFrozenDict(),
assigns: Iterable[str] = tuple(), assigns: Iterable[str] = SimpleFrozenList(),
requires: Iterable[str] = tuple(), requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False, retokenizes: bool = False,
scores: Iterable[str] = tuple(), scores: Iterable[str] = SimpleFrozenList(),
default_score_weights: Dict[str, float] = SimpleFrozenDict(), default_score_weights: Dict[str, float] = SimpleFrozenDict(),
func: Optional[Callable] = None, func: Optional[Callable] = None,
) -> Callable: ) -> Callable:
@ -471,8 +485,8 @@ class Language:
cls, cls,
name: Optional[str] = None, name: Optional[str] = None,
*, *,
assigns: Iterable[str] = tuple(), assigns: Iterable[str] = SimpleFrozenList(),
requires: Iterable[str] = tuple(), requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False, retokenizes: bool = False,
func: Optional[Callable[[Doc], Doc]] = None, func: Optional[Callable[[Doc], Doc]] = None,
) -> Callable: ) -> Callable:
@ -544,7 +558,7 @@ class Language:
DOCS: https://spacy.io/api/language#get_pipe DOCS: https://spacy.io/api/language#get_pipe
""" """
for pipe_name, component in self.components: for pipe_name, component in self._components:
if pipe_name == name: if pipe_name == name:
return component return component
raise KeyError(Errors.E001.format(name=name, opts=self.component_names)) raise KeyError(Errors.E001.format(name=name, opts=self.component_names))
@ -718,7 +732,7 @@ class Language:
) )
pipe_index = self._get_pipe_index(before, after, first, last) pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name) self._pipe_meta[name] = self.get_factory_meta(factory_name)
self.components.insert(pipe_index, (name, pipe_component)) self._components.insert(pipe_index, (name, pipe_component))
return pipe_component return pipe_component
def _get_pipe_index( def _get_pipe_index(
@ -743,7 +757,7 @@ class Language:
Errors.E006.format(args=all_args, opts=self.component_names) Errors.E006.format(args=all_args, opts=self.component_names)
) )
if last or not any(value is not None for value in [first, before, after]): if last or not any(value is not None for value in [first, before, after]):
return len(self.components) return len(self._components)
elif first: elif first:
return 0 return 0
elif isinstance(before, str): elif isinstance(before, str):
@ -761,14 +775,14 @@ class Language:
# We're only accepting indices referring to components that exist # We're only accepting indices referring to components that exist
# (can't just do isinstance here because bools are instance of int, too) # (can't just do isinstance here because bools are instance of int, too)
elif type(before) == int: elif type(before) == int:
if before >= len(self.components) or before < 0: if before >= len(self._components) or before < 0:
err = Errors.E959.format( err = Errors.E959.format(
dir="before", idx=before, opts=self.component_names dir="before", idx=before, opts=self.component_names
) )
raise ValueError(err) raise ValueError(err)
return before return before
elif type(after) == int: elif type(after) == int:
if after >= len(self.components) or after < 0: if after >= len(self._components) or after < 0:
err = Errors.E959.format( err = Errors.E959.format(
dir="after", idx=after, opts=self.component_names dir="after", idx=after, opts=self.component_names
) )
@ -815,7 +829,7 @@ class Language:
# to Language.pipeline to make sure the configs are handled correctly # to Language.pipeline to make sure the configs are handled correctly
pipe_index = self.pipe_names.index(name) pipe_index = self.pipe_names.index(name)
self.remove_pipe(name) self.remove_pipe(name)
if not len(self.components) or pipe_index == len(self.components): if not len(self._components) or pipe_index == len(self._components):
# we have no components to insert before/after, or we're replacing the last component # we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name, config=config, validate=validate) self.add_pipe(factory_name, name=name, config=config, validate=validate)
else: else:
@ -844,7 +858,7 @@ class Language:
Errors.E007.format(name=new_name, opts=self.component_names) Errors.E007.format(name=new_name, opts=self.component_names)
) )
i = self.component_names.index(old_name) i = self.component_names.index(old_name)
self.components[i] = (new_name, self.components[i][1]) self._components[i] = (new_name, self._components[i][1])
self._pipe_meta[new_name] = self._pipe_meta.pop(old_name) self._pipe_meta[new_name] = self._pipe_meta.pop(old_name)
self._pipe_configs[new_name] = self._pipe_configs.pop(old_name) self._pipe_configs[new_name] = self._pipe_configs.pop(old_name)
@ -858,7 +872,7 @@ class Language:
""" """
if name not in self.component_names: if name not in self.component_names:
raise ValueError(Errors.E001.format(name=name, opts=self.component_names)) raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
removed = self.components.pop(self.component_names.index(name)) removed = self._components.pop(self.component_names.index(name))
# We're only removing the component itself from the metas/configs here # We're only removing the component itself from the metas/configs here
# because factory may be used for something else # because factory may be used for something else
self._pipe_meta.pop(name) self._pipe_meta.pop(name)
@ -894,7 +908,7 @@ class Language:
self, self,
text: str, text: str,
*, *,
disable: Iterable[str] = tuple(), disable: Iterable[str] = SimpleFrozenList(),
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Doc: ) -> Doc:
"""Apply the pipeline to some text. The text can span multiple sentences, """Apply the pipeline to some text. The text can span multiple sentences,
@ -993,7 +1007,7 @@ class Language:
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None, losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = tuple(), exclude: Iterable[str] = SimpleFrozenList(),
): ):
"""Update the models in the pipeline. """Update the models in the pipeline.
@ -1047,7 +1061,7 @@ class Language:
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None, losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = tuple(), exclude: Iterable[str] = SimpleFrozenList(),
) -> Dict[str, float]: ) -> Dict[str, float]:
"""Make a "rehearsal" update to the models in the pipeline, to prevent """Make a "rehearsal" update to the models in the pipeline, to prevent
forgetting. Rehearsal updates run an initial copy of the model over some forgetting. Rehearsal updates run an initial copy of the model over some
@ -1276,7 +1290,7 @@ class Language:
*, *,
as_tuples: bool = False, as_tuples: bool = False,
batch_size: int = 1000, batch_size: int = 1000,
disable: Iterable[str] = tuple(), disable: Iterable[str] = SimpleFrozenList(),
cleanup: bool = False, cleanup: bool = False,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
n_process: int = 1, n_process: int = 1,
@ -1436,8 +1450,8 @@ class Language:
config: Union[Dict[str, Any], Config] = {}, config: Union[Dict[str, Any], Config] = {},
*, *,
vocab: Union[Vocab, bool] = True, vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = tuple(), exclude: Iterable[str] = SimpleFrozenList(),
auto_fill: bool = True, auto_fill: bool = True,
validate: bool = True, validate: bool = True,
) -> "Language": ) -> "Language":
@ -1562,7 +1576,7 @@ class Language:
return nlp return nlp
def to_disk( def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> None: ) -> None:
"""Save the current state to a directory. If a model is loaded, this """Save the current state to a directory. If a model is loaded, this
will include the model. will include the model.
@ -1580,7 +1594,7 @@ class Language:
) )
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta) serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
serializers["config.cfg"] = lambda p: self.config.to_disk(p) serializers["config.cfg"] = lambda p: self.config.to_disk(p)
for name, proc in self.components: for name, proc in self._components:
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "to_disk"): if not hasattr(proc, "to_disk"):
@ -1590,7 +1604,7 @@ class Language:
util.to_disk(path, serializers, exclude) util.to_disk(path, serializers, exclude)
def from_disk( def from_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> "Language": ) -> "Language":
"""Loads state from a directory. Modifies the object in place and """Loads state from a directory. Modifies the object in place and
returns it. If the saved `Language` object contains a model, the returns it. If the saved `Language` object contains a model, the
@ -1624,7 +1638,7 @@ class Language:
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk( deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
p, exclude=["vocab"] p, exclude=["vocab"]
) )
for name, proc in self.components: for name, proc in self._components:
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "from_disk"): if not hasattr(proc, "from_disk"):
@ -1640,7 +1654,7 @@ class Language:
self._link_components() self._link_components()
return self return self
def to_bytes(self, *, exclude: Iterable[str] = tuple()) -> bytes: def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
"""Serialize the current state to a binary string. """Serialize the current state to a binary string.
exclude (list): Names of components or serialization fields to exclude. exclude (list): Names of components or serialization fields to exclude.
@ -1653,7 +1667,7 @@ class Language:
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta) serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
serializers["config.cfg"] = lambda: self.config.to_bytes() serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self.components: for name, proc in self._components:
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "to_bytes"): if not hasattr(proc, "to_bytes"):
@ -1662,7 +1676,7 @@ class Language:
return util.to_bytes(serializers, exclude) return util.to_bytes(serializers, exclude)
def from_bytes( def from_bytes(
self, bytes_data: bytes, *, exclude: Iterable[str] = tuple() self, bytes_data: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
) -> "Language": ) -> "Language":
"""Load state from a binary string. """Load state from a binary string.
@ -1687,7 +1701,7 @@ class Language:
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes( deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
b, exclude=["vocab"] b, exclude=["vocab"]
) )
for name, proc in self.components: for name, proc in self._components:
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "from_bytes"): if not hasattr(proc, "from_bytes"):

View File

@ -12,6 +12,7 @@ from ..symbols import IDS, TAG, POS, MORPH, LEMMA
from ..tokens import Doc, Span from ..tokens import Doc, Span
from ..tokens._retokenize import normalize_token_attrs, set_token_attrs from ..tokens._retokenize import normalize_token_attrs, set_token_attrs
from ..vocab import Vocab from ..vocab import Vocab
from ..util import SimpleFrozenList
from .. import util from .. import util
@ -220,7 +221,7 @@ class AttributeRuler(Pipe):
results.update(Scorer.score_token_attr(examples, "lemma", **kwargs)) results.update(Scorer.score_token_attr(examples, "lemma", **kwargs))
return results return results
def to_bytes(self, exclude: Iterable[str] = tuple()) -> bytes: def to_bytes(self, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
"""Serialize the AttributeRuler to a bytestring. """Serialize the AttributeRuler to a bytestring.
exclude (Iterable[str]): String names of serialization fields to exclude. exclude (Iterable[str]): String names of serialization fields to exclude.
@ -236,7 +237,9 @@ class AttributeRuler(Pipe):
serialize["indices"] = lambda: srsly.msgpack_dumps(self.indices) serialize["indices"] = lambda: srsly.msgpack_dumps(self.indices)
return util.to_bytes(serialize, exclude) return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data: bytes, exclude: Iterable[str] = tuple()): def from_bytes(
self, bytes_data: bytes, exclude: Iterable[str] = SimpleFrozenList()
):
"""Load the AttributeRuler from a bytestring. """Load the AttributeRuler from a bytestring.
bytes_data (bytes): The data to load. bytes_data (bytes): The data to load.
@ -272,7 +275,9 @@ class AttributeRuler(Pipe):
return self return self
def to_disk(self, path: Union[Path, str], exclude: Iterable[str] = tuple()) -> None: def to_disk(
self, path: Union[Path, str], exclude: Iterable[str] = SimpleFrozenList()
) -> None:
"""Serialize the AttributeRuler to disk. """Serialize the AttributeRuler to disk.
path (Union[Path, str]): A path to a directory. path (Union[Path, str]): A path to a directory.
@ -289,7 +294,7 @@ class AttributeRuler(Pipe):
util.to_disk(path, serialize, exclude) util.to_disk(path, serialize, exclude)
def from_disk( def from_disk(
self, path: Union[Path, str], exclude: Iterable[str] = tuple() self, path: Union[Path, str], exclude: Iterable[str] = SimpleFrozenList()
) -> None: ) -> None:
"""Load the AttributeRuler from disk. """Load the AttributeRuler from disk.

View File

@ -13,6 +13,7 @@ from ..language import Language
from ..vocab import Vocab from ..vocab import Vocab
from ..gold import Example, validate_examples from ..gold import Example, validate_examples
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from ..util import SimpleFrozenList
from .. import util from .. import util
@ -404,7 +405,7 @@ class EntityLinker(Pipe):
token.ent_kb_id_ = kb_id token.ent_kb_id_ = kb_id
def to_disk( def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList(),
) -> None: ) -> None:
"""Serialize the pipe to disk. """Serialize the pipe to disk.
@ -421,7 +422,7 @@ class EntityLinker(Pipe):
util.to_disk(path, serialize, exclude) util.to_disk(path, serialize, exclude)
def from_disk( def from_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList(),
) -> "EntityLinker": ) -> "EntityLinker":
"""Load the pipe from disk. Modifies the object in place and returns it. """Load the pipe from disk. Modifies the object in place and returns it.

View File

@ -5,7 +5,7 @@ import srsly
from ..language import Language from ..language import Language
from ..errors import Errors from ..errors import Errors
from ..util import ensure_path, to_disk, from_disk from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
from ..tokens import Doc, Span from ..tokens import Doc, Span
from ..matcher import Matcher, PhraseMatcher from ..matcher import Matcher, PhraseMatcher
from ..scorer import Scorer from ..scorer import Scorer
@ -317,7 +317,7 @@ class EntityRuler:
return Scorer.score_spans(examples, "ents", **kwargs) return Scorer.score_spans(examples, "ents", **kwargs)
def from_bytes( def from_bytes(
self, patterns_bytes: bytes, *, exclude: Iterable[str] = tuple() self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
) -> "EntityRuler": ) -> "EntityRuler":
"""Load the entity ruler from a bytestring. """Load the entity ruler from a bytestring.
@ -341,7 +341,7 @@ class EntityRuler:
self.add_patterns(cfg) self.add_patterns(cfg)
return self return self
def to_bytes(self, *, exclude: Iterable[str] = tuple()) -> bytes: def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
"""Serialize the entity ruler patterns to a bytestring. """Serialize the entity ruler patterns to a bytestring.
RETURNS (bytes): The serialized patterns. RETURNS (bytes): The serialized patterns.
@ -357,7 +357,7 @@ class EntityRuler:
return srsly.msgpack_dumps(serial) return srsly.msgpack_dumps(serial)
def from_disk( def from_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> "EntityRuler": ) -> "EntityRuler":
"""Load the entity ruler from a file. Expects a file containing """Load the entity ruler from a file. Expects a file containing
newline-delimited JSON (JSONL) with one entry per line. newline-delimited JSON (JSONL) with one entry per line.
@ -394,7 +394,7 @@ class EntityRuler:
return self return self
def to_disk( def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple() self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> None: ) -> None:
"""Save the entity ruler patterns to a directory. The patterns will be """Save the entity ruler patterns to a directory. The patterns will be
saved as newline-delimited JSON (JSONL). saved as newline-delimited JSON (JSONL).

View File

@ -1,10 +1,10 @@
from typing import Optional, Iterable, Dict, Any, Callable, Tuple, TYPE_CHECKING from typing import Optional, Iterable, Dict, Any, Callable, TYPE_CHECKING
import numpy as np import numpy as np
from .gold import Example from .gold import Example
from .tokens import Token, Doc, Span from .tokens import Token, Doc, Span
from .errors import Errors from .errors import Errors
from .util import get_lang_class from .util import get_lang_class, SimpleFrozenList
from .morphology import Morphology from .morphology import Morphology
if TYPE_CHECKING: if TYPE_CHECKING:
@ -317,7 +317,7 @@ class Scorer:
attr: str, attr: str,
*, *,
getter: Callable[[Doc, str], Any] = getattr, getter: Callable[[Doc, str], Any] = getattr,
labels: Iterable[str] = tuple(), labels: Iterable[str] = SimpleFrozenList(),
multi_label: bool = True, multi_label: bool = True,
positive_label: Optional[str] = None, positive_label: Optional[str] = None,
threshold: Optional[float] = None, threshold: Optional[float] = None,
@ -447,7 +447,7 @@ class Scorer:
getter: Callable[[Token, str], Any] = getattr, getter: Callable[[Token, str], Any] = getattr,
head_attr: str = "head", head_attr: str = "head",
head_getter: Callable[[Token, str], Token] = getattr, head_getter: Callable[[Token, str], Token] = getattr,
ignore_labels: Tuple[str] = tuple(), ignore_labels: Iterable[str] = SimpleFrozenList(),
**cfg, **cfg,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Returns the UAS, LAS, and LAS per type scores for dependency """Returns the UAS, LAS, and LAS per type scores for dependency

View File

@ -1,5 +1,6 @@
import pytest import pytest
from spacy.language import Language from spacy.language import Language
from spacy.util import SimpleFrozenList
@pytest.fixture @pytest.fixture
@ -317,3 +318,31 @@ def test_disable_enable_pipes():
assert nlp.config["nlp"]["disabled"] == [name] assert nlp.config["nlp"]["disabled"] == [name]
nlp("?") nlp("?")
assert results[f"{name}1"] == "!" assert results[f"{name}1"] == "!"
def test_pipe_methods_frozen():
    """Check that the computed pipeline properties of `nlp` are still lists
    (for backwards compatibility) but are SimpleFrozenList instances, and
    that mutating them in place raises NotImplementedError."""
    nlp = Language()
    ner = nlp.add_pipe("ner")
    assert nlp.pipe_names == ["ner"]
    frozen_props = (
        nlp.pipeline,
        nlp.pipe_names,
        nlp.components,
        nlp.component_names,
        nlp.disabled,
        nlp.factory_names,
    )
    for frozen in frozen_props:
        assert isinstance(frozen, list)
        assert isinstance(frozen, SimpleFrozenList)
    # Each callable performs one in-place mutation on a freshly read property.
    mutations = (
        lambda: nlp.pipeline.append(("ner2", ner)),
        lambda: nlp.pipe_names.pop(),
        lambda: nlp.components.sort(),
        lambda: nlp.component_names.clear(),
    )
    for mutate in mutations:
        with pytest.raises(NotImplementedError):
            mutate()

View File

@ -3,10 +3,9 @@ import pytest
from .util import get_random_doc from .util import get_random_doc
from spacy import util from spacy import util
from spacy.util import dot_to_object from spacy.util import dot_to_object, SimpleFrozenList
from thinc.api import Config, Optimizer from thinc.api import Config, Optimizer
from spacy.gold.batchers import minibatch_by_words from spacy.gold.batchers import minibatch_by_words
from ..lang.en import English from ..lang.en import English
from ..lang.nl import Dutch from ..lang.nl import Dutch
from ..language import DEFAULT_CONFIG_PATH from ..language import DEFAULT_CONFIG_PATH
@ -106,3 +105,20 @@ def test_util_dot_section():
assert not dot_to_object(en_config, "nlp.load_vocab_data") assert not dot_to_object(en_config, "nlp.load_vocab_data")
assert dot_to_object(nl_config, "nlp.load_vocab_data") assert dot_to_object(nl_config, "nlp.load_vocab_data")
assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer) assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)
def test_simple_frozen_list():
    """Test that SimpleFrozenList behaves like a list for reads, rejects the
    mutating list methods, and reports the custom error message if one is
    provided."""
    t = SimpleFrozenList(["foo", "bar"])
    assert t == ["foo", "bar"]
    assert t.index("bar") == 1  # okay method
    with pytest.raises(NotImplementedError):
        t.append("baz")
    with pytest.raises(NotImplementedError):
        t.sort()
    with pytest.raises(NotImplementedError):
        t.extend(["baz"])
    with pytest.raises(NotImplementedError):
        t.pop()
    with pytest.raises(NotImplementedError):
        t.clear()
    with pytest.raises(NotImplementedError):
        t.insert(0, "baz")
    with pytest.raises(NotImplementedError):
        t.remove("foo")
    with pytest.raises(NotImplementedError):
        t.reverse()
    # None of the rejected calls may have changed the contents
    assert t == ["foo", "bar"]
    t = SimpleFrozenList(["foo", "bar"], error="Error!")
    # Check the message too, otherwise an ignored `error` argument would
    # still pass this test
    with pytest.raises(NotImplementedError, match="Error!"):
        t.append("baz")
    assert t == ["foo", "bar"]

View File

@ -10,7 +10,7 @@ from ..vocab import Vocab
from ..compat import copy_reg from ..compat import copy_reg
from ..attrs import SPACY, ORTH, intify_attr from ..attrs import SPACY, ORTH, intify_attr
from ..errors import Errors from ..errors import Errors
from ..util import ensure_path from ..util import ensure_path, SimpleFrozenList
# fmt: off # fmt: off
ALL_ATTRS = ("ORTH", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "LEMMA", "MORPH", "POS") ALL_ATTRS = ("ORTH", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "LEMMA", "MORPH", "POS")
@ -52,7 +52,7 @@ class DocBin:
self, self,
attrs: Iterable[str] = ALL_ATTRS, attrs: Iterable[str] = ALL_ATTRS,
store_user_data: bool = False, store_user_data: bool = False,
docs: Iterable[Doc] = tuple(), docs: Iterable[Doc] = SimpleFrozenList(),
) -> None: ) -> None:
"""Create a DocBin object to hold serialized annotations. """Create a DocBin object to hold serialized annotations.

View File

@ -120,6 +120,47 @@ class SimpleFrozenDict(dict):
raise NotImplementedError(self.error) raise NotImplementedError(self.error)
class SimpleFrozenList(list):
    """Wrapper class around a list that lets us raise custom errors if certain
    attributes/methods are accessed. Mostly used for properties like
    Language.pipeline that return an immutable list (and that we don't want to
    convert to a tuple to not break too much backwards compatibility). If a user
    accidentally calls nlp.pipeline.append(), we can raise a more helpful error.

    Only the named mutating list methods defined below are blocked. Item
    assignment/deletion and the in-place operators are not intercepted.
    """

    def __init__(self, *args, error: str = Errors.E927) -> None:
        """Initialize the frozen list.

        *args: Positional arguments passed on to the list constructor.
        error (str): The error message when user tries to mutate the list.
        """
        self.error = error
        super().__init__(*args)

    def __reduce__(self):
        """Describe how to rebuild this object for copy.copy, copy.deepcopy
        and pickle.

        The default reduction for list subclasses creates an empty instance
        and fills it via append()/extend(), which are blocked here and would
        raise. Rebuilding through the constructor from a plain list avoids
        that; the instance attributes (including the custom error message)
        are restored from the state dict afterwards.

        RETURNS (tuple): The (callable, args, state) reduction.
        """
        return (self.__class__, (list(self),), self.__dict__.copy())

    # The mutating list methods below accept any arguments so that every call
    # signature ends in the same custom error instead of a TypeError.

    def append(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def clear(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def extend(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def insert(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def pop(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def remove(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def reverse(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def sort(self, *args, **kwargs):
        raise NotImplementedError(self.error)
def lang_class_is_loaded(lang: str) -> bool: def lang_class_is_loaded(lang: str) -> bool:
"""Check whether a Language class is already loaded. Language classes are """Check whether a Language class is already loaded. Language classes are
loaded lazily, to avoid expensive setup code associated with the language loaded lazily, to avoid expensive setup code associated with the language
@ -215,8 +256,8 @@ def load_model(
name: Union[str, Path], name: Union[str, Path],
*, *,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = tuple(), exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Load a model from a package or data path. """Load a model from a package or data path.
@ -248,8 +289,8 @@ def load_model_from_package(
name: str, name: str,
*, *,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = tuple(), exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Load a model from an installed package. """Load a model from an installed package.
@ -275,8 +316,8 @@ def load_model_from_path(
*, *,
meta: Optional[Dict[str, Any]] = None, meta: Optional[Dict[str, Any]] = None,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = tuple(), exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Load a model from a data directory path. Creates Language class with """Load a model from a data directory path. Creates Language class with
@ -311,8 +352,8 @@ def load_model_from_config(
config: Union[Dict[str, Any], Config], config: Union[Dict[str, Any], Config],
*, *,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = tuple(), exclude: Iterable[str] = SimpleFrozenList(),
auto_fill: bool = False, auto_fill: bool = False,
validate: bool = True, validate: bool = True,
) -> Tuple["Language", Config]: ) -> Tuple["Language", Config]:
@ -355,8 +396,8 @@ def load_model_from_init_py(
init_file: Union[Path, str], init_file: Union[Path, str],
*, *,
vocab: Union["Vocab", bool] = True, vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = tuple(), exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Helper function to use in the `load()` method of a model package's """Helper function to use in the `load()` method of a model package's