Use frozen list with custom errors

We don't want to break backwards compatibility too much but we also want to provide the best possible UX
This commit is contained in:
Ines Montani 2020-08-29 15:20:11 +02:00
parent 6520d1a1df
commit 34146750d4
13 changed files with 195 additions and 78 deletions

View File

@ -27,8 +27,8 @@ if sys.maxunicode == 65535:
def load(
name: Union[str, Path],
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
disable: Iterable[str] = util.SimpleFrozenList(),
exclude: Iterable[str] = util.SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
) -> Language:
"""Load a spaCy model from an installed package or a local path.

View File

@ -1,6 +1,6 @@
"""This module contains helpers and subcommands for integrating spaCy projects
with Data Version Controk (DVC). https://dvc.org"""
from typing import Dict, Any, List, Optional
from typing import Dict, Any, List, Optional, Iterable
import subprocess
from pathlib import Path
from wasabi import msg
@ -8,6 +8,7 @@ from wasabi import msg
from .._util import PROJECT_FILE, load_project_config, get_hash, project_cli
from .._util import Arg, Opt, NAME, COMMAND
from ...util import working_dir, split_command, join_command, run_command
from ...util import SimpleFrozenList
DVC_CONFIG = "dvc.yaml"
@ -130,7 +131,7 @@ def update_dvc_config(
def run_dvc_commands(
commands: List[str] = tuple(), flags: Dict[str, bool] = {},
commands: Iterable[str] = SimpleFrozenList(), flags: Dict[str, bool] = {},
) -> None:
"""Run a sequence of DVC commands in a subprocess, in order.

View File

@ -1,10 +1,11 @@
from typing import Optional, List, Dict, Sequence, Any
from typing import Optional, List, Dict, Sequence, Any, Iterable
from pathlib import Path
from wasabi import msg
import sys
import srsly
from ...util import working_dir, run_command, split_command, is_cwd, join_command
from ...util import SimpleFrozenList
from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash
from .._util import get_checksum, project_cli, Arg, Opt, COMMAND
@ -115,7 +116,9 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
def run_commands(
commands: List[str] = tuple(), silent: bool = False, dry: bool = False,
commands: Iterable[str] = SimpleFrozenList(),
silent: bool = False,
dry: bool = False,
) -> None:
"""Run a sequence of commands in a subprocess, in order.

View File

@ -472,6 +472,13 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master
E926 = ("It looks like you're trying to modify nlp.{attr} directly. This "
"doesn't work because it's an immutable computed property. If you "
"need to modify the pipeline, use the built-in methods like "
"nlp.add_pipe, nlp.remove_pipe, nlp.disable_pipe or nlp.enable_pipe "
"instead.")
E927 = ("Can't write to frozen list Maybe you're trying to modify a computed "
"property or default function argument?")
E928 = ("A 'KnowledgeBase' should be written to / read from a file, but the "
"provided argument {loc} is an existing directory.")
E929 = ("A 'KnowledgeBase' could not be read from {loc} - the path does "

View File

@ -20,7 +20,7 @@ from .vocab import Vocab, create_vocab
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .gold import Example, validate_examples
from .scorer import Scorer
from .util import create_default_optimizer, registry
from .util import create_default_optimizer, registry, SimpleFrozenList
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
@ -159,7 +159,7 @@ class Language:
self.vocab: Vocab = vocab
if self.lang is None:
self.lang = self.vocab.lang
self.components = []
self._components = []
self._disabled = set()
self.max_length = max_length
self.resolved = {}
@ -207,11 +207,11 @@ class Language:
"keys": self.vocab.vectors.n_keys,
"name": self.vocab.vectors.name,
}
self._meta["labels"] = self.pipe_labels
self._meta["labels"] = dict(self.pipe_labels)
# TODO: Adding this back to prevent breaking people's code etc., but
# we should consider removing it
self._meta["pipeline"] = self.pipe_names
self._meta["disabled"] = self.disabled
self._meta["pipeline"] = list(self.pipe_names)
self._meta["disabled"] = list(self.disabled)
return self._meta
@meta.setter
@ -240,8 +240,8 @@ class Language:
pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
if pipe_meta.default_score_weights:
score_weights.append(pipe_meta.default_score_weights)
self._config["nlp"]["pipeline"] = self.component_names
self._config["nlp"]["disabled"] = self.disabled
self._config["nlp"]["pipeline"] = list(self.component_names)
self._config["nlp"]["disabled"] = list(self.disabled)
self._config["components"] = pipeline
self._config["training"]["score_weights"] = combine_score_weights(score_weights)
if not srsly.is_json_serializable(self._config):
@ -260,7 +260,8 @@ class Language:
"""
# Make sure the disabled components are returned in the order they
# appear in the pipeline (which isn't guaranteed by the set)
return [name for name, _ in self.components if name in self._disabled]
names = [name for name, _ in self._components if name in self._disabled]
return SimpleFrozenList(names, error=Errors.E926.format(attr="disabled"))
@property
def factory_names(self) -> List[str]:
@ -268,7 +269,17 @@ class Language:
RETURNS (List[str]): The factory names.
"""
return list(self.factories.keys())
names = list(self.factories.keys())
return SimpleFrozenList(names)
@property
def components(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
"""Get all (name, component) tuples in the pipeline, including the
currently disabled components.
"""
return SimpleFrozenList(
self._components, error=Errors.E926.format(attr="components")
)
@property
def component_names(self) -> List[str]:
@ -277,7 +288,8 @@ class Language:
RETURNS (List[str]): List of component name strings, in order.
"""
return [pipe_name for pipe_name, _ in self.components]
names = [pipe_name for pipe_name, _ in self._components]
return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names"))
@property
def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
@ -287,7 +299,8 @@ class Language:
RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
"""
return [(name, p) for name, p in self.components if name not in self._disabled]
pipes = [(n, p) for n, p in self._components if n not in self._disabled]
return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline"))
@property
def pipe_names(self) -> List[str]:
@ -295,7 +308,8 @@ class Language:
RETURNS (List[str]): List of component name strings, in order.
"""
return [pipe_name for pipe_name, _ in self.pipeline]
names = [pipe_name for pipe_name, _ in self.pipeline]
return SimpleFrozenList(names, error=Errors.E926.format(attr="pipe_names"))
@property
def pipe_factories(self) -> Dict[str, str]:
@ -304,9 +318,9 @@ class Language:
RETURNS (Dict[str, str]): Factory names, keyed by component names.
"""
factories = {}
for pipe_name, pipe in self.components:
for pipe_name, pipe in self._components:
factories[pipe_name] = self.get_pipe_meta(pipe_name).factory
return factories
return SimpleFrozenDict(factories)
@property
def pipe_labels(self) -> Dict[str, List[str]]:
@ -316,10 +330,10 @@ class Language:
RETURNS (Dict[str, List[str]]): Labels keyed by component name.
"""
labels = {}
for name, pipe in self.components:
for name, pipe in self._components:
if hasattr(pipe, "labels"):
labels[name] = list(pipe.labels)
return labels
return SimpleFrozenDict(labels)
@classmethod
def has_factory(cls, name: str) -> bool:
@ -390,10 +404,10 @@ class Language:
name: str,
*,
default_config: Dict[str, Any] = SimpleFrozenDict(),
assigns: Iterable[str] = tuple(),
requires: Iterable[str] = tuple(),
assigns: Iterable[str] = SimpleFrozenList(),
requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False,
scores: Iterable[str] = tuple(),
scores: Iterable[str] = SimpleFrozenList(),
default_score_weights: Dict[str, float] = SimpleFrozenDict(),
func: Optional[Callable] = None,
) -> Callable:
@ -471,8 +485,8 @@ class Language:
cls,
name: Optional[str] = None,
*,
assigns: Iterable[str] = tuple(),
requires: Iterable[str] = tuple(),
assigns: Iterable[str] = SimpleFrozenList(),
requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False,
func: Optional[Callable[[Doc], Doc]] = None,
) -> Callable:
@ -544,7 +558,7 @@ class Language:
DOCS: https://spacy.io/api/language#get_pipe
"""
for pipe_name, component in self.components:
for pipe_name, component in self._components:
if pipe_name == name:
return component
raise KeyError(Errors.E001.format(name=name, opts=self.component_names))
@ -718,7 +732,7 @@ class Language:
)
pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name)
self.components.insert(pipe_index, (name, pipe_component))
self._components.insert(pipe_index, (name, pipe_component))
return pipe_component
def _get_pipe_index(
@ -743,7 +757,7 @@ class Language:
Errors.E006.format(args=all_args, opts=self.component_names)
)
if last or not any(value is not None for value in [first, before, after]):
return len(self.components)
return len(self._components)
elif first:
return 0
elif isinstance(before, str):
@ -761,14 +775,14 @@ class Language:
# We're only accepting indices referring to components that exist
# (can't just do isinstance here because bools are instance of int, too)
elif type(before) == int:
if before >= len(self.components) or before < 0:
if before >= len(self._components) or before < 0:
err = Errors.E959.format(
dir="before", idx=before, opts=self.component_names
)
raise ValueError(err)
return before
elif type(after) == int:
if after >= len(self.components) or after < 0:
if after >= len(self._components) or after < 0:
err = Errors.E959.format(
dir="after", idx=after, opts=self.component_names
)
@ -815,7 +829,7 @@ class Language:
# to Language.pipeline to make sure the configs are handled correctly
pipe_index = self.pipe_names.index(name)
self.remove_pipe(name)
if not len(self.components) or pipe_index == len(self.components):
if not len(self._components) or pipe_index == len(self._components):
# we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name, config=config, validate=validate)
else:
@ -844,7 +858,7 @@ class Language:
Errors.E007.format(name=new_name, opts=self.component_names)
)
i = self.component_names.index(old_name)
self.components[i] = (new_name, self.components[i][1])
self._components[i] = (new_name, self._components[i][1])
self._pipe_meta[new_name] = self._pipe_meta.pop(old_name)
self._pipe_configs[new_name] = self._pipe_configs.pop(old_name)
@ -858,7 +872,7 @@ class Language:
"""
if name not in self.component_names:
raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
removed = self.components.pop(self.component_names.index(name))
removed = self._components.pop(self.component_names.index(name))
# We're only removing the component itself from the metas/configs here
# because factory may be used for something else
self._pipe_meta.pop(name)
@ -894,7 +908,7 @@ class Language:
self,
text: str,
*,
disable: Iterable[str] = tuple(),
disable: Iterable[str] = SimpleFrozenList(),
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Doc:
"""Apply the pipeline to some text. The text can span multiple sentences,
@ -993,7 +1007,7 @@ class Language:
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = tuple(),
exclude: Iterable[str] = SimpleFrozenList(),
):
"""Update the models in the pipeline.
@ -1047,7 +1061,7 @@ class Language:
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = tuple(),
exclude: Iterable[str] = SimpleFrozenList(),
) -> Dict[str, float]:
"""Make a "rehearsal" update to the models in the pipeline, to prevent
forgetting. Rehearsal updates run an initial copy of the model over some
@ -1276,7 +1290,7 @@ class Language:
*,
as_tuples: bool = False,
batch_size: int = 1000,
disable: Iterable[str] = tuple(),
disable: Iterable[str] = SimpleFrozenList(),
cleanup: bool = False,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
n_process: int = 1,
@ -1436,8 +1450,8 @@ class Language:
config: Union[Dict[str, Any], Config] = {},
*,
vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
auto_fill: bool = True,
validate: bool = True,
) -> "Language":
@ -1562,7 +1576,7 @@ class Language:
return nlp
def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> None:
"""Save the current state to a directory. If a model is loaded, this
will include the model.
@ -1580,7 +1594,7 @@ class Language:
)
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
serializers["config.cfg"] = lambda p: self.config.to_disk(p)
for name, proc in self.components:
for name, proc in self._components:
if name in exclude:
continue
if not hasattr(proc, "to_disk"):
@ -1590,7 +1604,7 @@ class Language:
util.to_disk(path, serializers, exclude)
def from_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> "Language":
"""Loads state from a directory. Modifies the object in place and
returns it. If the saved `Language` object contains a model, the
@ -1624,7 +1638,7 @@ class Language:
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
p, exclude=["vocab"]
)
for name, proc in self.components:
for name, proc in self._components:
if name in exclude:
continue
if not hasattr(proc, "from_disk"):
@ -1640,7 +1654,7 @@ class Language:
self._link_components()
return self
def to_bytes(self, *, exclude: Iterable[str] = tuple()) -> bytes:
def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
"""Serialize the current state to a binary string.
exclude (list): Names of components or serialization fields to exclude.
@ -1653,7 +1667,7 @@ class Language:
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self.components:
for name, proc in self._components:
if name in exclude:
continue
if not hasattr(proc, "to_bytes"):
@ -1662,7 +1676,7 @@ class Language:
return util.to_bytes(serializers, exclude)
def from_bytes(
self, bytes_data: bytes, *, exclude: Iterable[str] = tuple()
self, bytes_data: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
) -> "Language":
"""Load state from a binary string.
@ -1687,7 +1701,7 @@ class Language:
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
b, exclude=["vocab"]
)
for name, proc in self.components:
for name, proc in self._components:
if name in exclude:
continue
if not hasattr(proc, "from_bytes"):

View File

@ -12,6 +12,7 @@ from ..symbols import IDS, TAG, POS, MORPH, LEMMA
from ..tokens import Doc, Span
from ..tokens._retokenize import normalize_token_attrs, set_token_attrs
from ..vocab import Vocab
from ..util import SimpleFrozenList
from .. import util
@ -220,7 +221,7 @@ class AttributeRuler(Pipe):
results.update(Scorer.score_token_attr(examples, "lemma", **kwargs))
return results
def to_bytes(self, exclude: Iterable[str] = tuple()) -> bytes:
def to_bytes(self, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
"""Serialize the AttributeRuler to a bytestring.
exclude (Iterable[str]): String names of serialization fields to exclude.
@ -236,7 +237,9 @@ class AttributeRuler(Pipe):
serialize["indices"] = lambda: srsly.msgpack_dumps(self.indices)
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data: bytes, exclude: Iterable[str] = tuple()):
def from_bytes(
self, bytes_data: bytes, exclude: Iterable[str] = SimpleFrozenList()
):
"""Load the AttributeRuler from a bytestring.
bytes_data (bytes): The data to load.
@ -272,7 +275,9 @@ class AttributeRuler(Pipe):
return self
def to_disk(self, path: Union[Path, str], exclude: Iterable[str] = tuple()) -> None:
def to_disk(
self, path: Union[Path, str], exclude: Iterable[str] = SimpleFrozenList()
) -> None:
"""Serialize the AttributeRuler to disk.
path (Union[Path, str]): A path to a directory.
@ -289,7 +294,7 @@ class AttributeRuler(Pipe):
util.to_disk(path, serialize, exclude)
def from_disk(
self, path: Union[Path, str], exclude: Iterable[str] = tuple()
self, path: Union[Path, str], exclude: Iterable[str] = SimpleFrozenList()
) -> None:
"""Load the AttributeRuler from disk.

View File

@ -13,6 +13,7 @@ from ..language import Language
from ..vocab import Vocab
from ..gold import Example, validate_examples
from ..errors import Errors, Warnings
from ..util import SimpleFrozenList
from .. import util
@ -404,7 +405,7 @@ class EntityLinker(Pipe):
token.ent_kb_id_ = kb_id
def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList(),
) -> None:
"""Serialize the pipe to disk.
@ -421,7 +422,7 @@ class EntityLinker(Pipe):
util.to_disk(path, serialize, exclude)
def from_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList(),
) -> "EntityLinker":
"""Load the pipe from disk. Modifies the object in place and returns it.

View File

@ -5,7 +5,7 @@ import srsly
from ..language import Language
from ..errors import Errors
from ..util import ensure_path, to_disk, from_disk
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
from ..tokens import Doc, Span
from ..matcher import Matcher, PhraseMatcher
from ..scorer import Scorer
@ -317,7 +317,7 @@ class EntityRuler:
return Scorer.score_spans(examples, "ents", **kwargs)
def from_bytes(
self, patterns_bytes: bytes, *, exclude: Iterable[str] = tuple()
self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
) -> "EntityRuler":
"""Load the entity ruler from a bytestring.
@ -341,7 +341,7 @@ class EntityRuler:
self.add_patterns(cfg)
return self
def to_bytes(self, *, exclude: Iterable[str] = tuple()) -> bytes:
def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
"""Serialize the entity ruler patterns to a bytestring.
RETURNS (bytes): The serialized patterns.
@ -357,7 +357,7 @@ class EntityRuler:
return srsly.msgpack_dumps(serial)
def from_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> "EntityRuler":
"""Load the entity ruler from a file. Expects a file containing
newline-delimited JSON (JSONL) with one entry per line.
@ -394,7 +394,7 @@ class EntityRuler:
return self
def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> None:
"""Save the entity ruler patterns to a directory. The patterns will be
saved as newline-delimited JSON (JSONL).

View File

@ -1,10 +1,10 @@
from typing import Optional, Iterable, Dict, Any, Callable, Tuple, TYPE_CHECKING
from typing import Optional, Iterable, Dict, Any, Callable, TYPE_CHECKING
import numpy as np
from .gold import Example
from .tokens import Token, Doc, Span
from .errors import Errors
from .util import get_lang_class
from .util import get_lang_class, SimpleFrozenList
from .morphology import Morphology
if TYPE_CHECKING:
@ -317,7 +317,7 @@ class Scorer:
attr: str,
*,
getter: Callable[[Doc, str], Any] = getattr,
labels: Iterable[str] = tuple(),
labels: Iterable[str] = SimpleFrozenList(),
multi_label: bool = True,
positive_label: Optional[str] = None,
threshold: Optional[float] = None,
@ -447,7 +447,7 @@ class Scorer:
getter: Callable[[Token, str], Any] = getattr,
head_attr: str = "head",
head_getter: Callable[[Token, str], Token] = getattr,
ignore_labels: Tuple[str] = tuple(),
ignore_labels: Iterable[str] = SimpleFrozenList(),
**cfg,
) -> Dict[str, Any]:
"""Returns the UAS, LAS, and LAS per type scores for dependency

View File

@ -1,5 +1,6 @@
import pytest
from spacy.language import Language
from spacy.util import SimpleFrozenList
@pytest.fixture
@ -317,3 +318,31 @@ def test_disable_enable_pipes():
assert nlp.config["nlp"]["disabled"] == [name]
nlp("?")
assert results[f"{name}1"] == "!"
def test_pipe_methods_frozen():
"""Test that spaCy raises custom error messages if "frozen" properties are
accessed. We still want to use a list here to not break backwards
compatibility, but users should see an error if they're trying to append
to nlp.pipeline etc."""
nlp = Language()
ner = nlp.add_pipe("ner")
assert nlp.pipe_names == ["ner"]
for prop in [
nlp.pipeline,
nlp.pipe_names,
nlp.components,
nlp.component_names,
nlp.disabled,
nlp.factory_names,
]:
assert isinstance(prop, list)
assert isinstance(prop, SimpleFrozenList)
with pytest.raises(NotImplementedError):
nlp.pipeline.append(("ner2", ner))
with pytest.raises(NotImplementedError):
nlp.pipe_names.pop()
with pytest.raises(NotImplementedError):
nlp.components.sort()
with pytest.raises(NotImplementedError):
nlp.component_names.clear()

View File

@ -3,10 +3,9 @@ import pytest
from .util import get_random_doc
from spacy import util
from spacy.util import dot_to_object
from spacy.util import dot_to_object, SimpleFrozenList
from thinc.api import Config, Optimizer
from spacy.gold.batchers import minibatch_by_words
from ..lang.en import English
from ..lang.nl import Dutch
from ..language import DEFAULT_CONFIG_PATH
@ -106,3 +105,20 @@ def test_util_dot_section():
assert not dot_to_object(en_config, "nlp.load_vocab_data")
assert dot_to_object(nl_config, "nlp.load_vocab_data")
assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)
def test_simple_frozen_list():
t = SimpleFrozenList(["foo", "bar"])
assert t == ["foo", "bar"]
assert t.index("bar") == 1 # okay method
with pytest.raises(NotImplementedError):
t.append("baz")
with pytest.raises(NotImplementedError):
t.sort()
with pytest.raises(NotImplementedError):
t.extend(["baz"])
with pytest.raises(NotImplementedError):
t.pop()
t = SimpleFrozenList(["foo", "bar"], error="Error!")
with pytest.raises(NotImplementedError):
t.append("baz")

View File

@ -10,7 +10,7 @@ from ..vocab import Vocab
from ..compat import copy_reg
from ..attrs import SPACY, ORTH, intify_attr
from ..errors import Errors
from ..util import ensure_path
from ..util import ensure_path, SimpleFrozenList
# fmt: off
ALL_ATTRS = ("ORTH", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "LEMMA", "MORPH", "POS")
@ -52,7 +52,7 @@ class DocBin:
self,
attrs: Iterable[str] = ALL_ATTRS,
store_user_data: bool = False,
docs: Iterable[Doc] = tuple(),
docs: Iterable[Doc] = SimpleFrozenList(),
) -> None:
"""Create a DocBin object to hold serialized annotations.

View File

@ -120,6 +120,47 @@ class SimpleFrozenDict(dict):
raise NotImplementedError(self.error)
class SimpleFrozenList(list):
"""Wrapper class around a list that lets us raise custom errors if certain
attributes/methods are accessed. Mostly used for properties like
Language.pipeline that return an immutable list (and that we don't want to
convert to a tuple to not break too much backwards compatibility). If a user
accidentally calls nlp.pipeline.append(), we can raise a more helpful error.
"""
def __init__(self, *args, error: str = Errors.E927) -> None:
"""Initialize the frozen list.
error (str): The error message when user tries to mutate the list.
"""
self.error = error
super().__init__(*args)
def append(self, *args, **kwargs):
raise NotImplementedError(self.error)
def clear(self, *args, **kwargs):
raise NotImplementedError(self.error)
def extend(self, *args, **kwargs):
raise NotImplementedError(self.error)
def insert(self, *args, **kwargs):
raise NotImplementedError(self.error)
def pop(self, *args, **kwargs):
raise NotImplementedError(self.error)
def remove(self, *args, **kwargs):
raise NotImplementedError(self.error)
def reverse(self, *args, **kwargs):
raise NotImplementedError(self.error)
def sort(self, *args, **kwargs):
raise NotImplementedError(self.error)
def lang_class_is_loaded(lang: str) -> bool:
"""Check whether a Language class is already loaded. Language classes are
loaded lazily, to avoid expensive setup code associated with the language
@ -215,8 +256,8 @@ def load_model(
name: Union[str, Path],
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Load a model from a package or data path.
@ -248,8 +289,8 @@ def load_model_from_package(
name: str,
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Load a model from an installed package.
@ -275,8 +316,8 @@ def load_model_from_path(
*,
meta: Optional[Dict[str, Any]] = None,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Load a model from a data directory path. Creates Language class with
@ -311,8 +352,8 @@ def load_model_from_config(
config: Union[Dict[str, Any], Config],
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
auto_fill: bool = False,
validate: bool = True,
) -> Tuple["Language", Config]:
@ -355,8 +396,8 @@ def load_model_from_init_py(
init_file: Union[Path, str],
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(),
exclude: Iterable[str] = tuple(),
disable: Iterable[str] = SimpleFrozenList(),
exclude: Iterable[str] = SimpleFrozenList(),
config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
"""Helper function to use in the `load()` method of a model package's