Rename user-facing API

This commit is contained in:
Ines Montani 2020-08-28 21:04:02 +02:00
parent 6a999c9303
commit 0687d7148e
3 changed files with 82 additions and 71 deletions

View File

@ -159,8 +159,8 @@ class Language:
self.vocab: Vocab = vocab
if self.lang is None:
self.lang = self.vocab.lang
self._pipeline = []
self._disabled = set()
self.components = []
self.disabled = set()
self.max_length = max_length
self.resolved = {}
# Create the default tokenizer from the default config
@ -211,7 +211,7 @@ class Language:
# TODO: Adding this back to prevent breaking people's code etc., but
# we should consider removing it
self._meta["pipeline"] = self.pipe_names
self._meta["disabled"] = list(self._disabled)
self._meta["disabled"] = list(self.disabled)
return self._meta
@meta.setter
@ -234,14 +234,14 @@ class Language:
# we can populate the config again later
pipeline = {}
score_weights = []
for pipe_name in self._pipe_names:
for pipe_name in self.component_names:
pipe_meta = self.get_pipe_meta(pipe_name)
pipe_config = self.get_pipe_config(pipe_name)
pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
if pipe_meta.default_score_weights:
score_weights.append(pipe_meta.default_score_weights)
self._config["nlp"]["pipeline"] = self._pipe_names
self._config["nlp"]["disabled"] = list(self._disabled)
self._config["nlp"]["pipeline"] = self.component_names
self._config["nlp"]["disabled"] = list(self.disabled)
self._config["components"] = pipeline
self._config["training"]["score_weights"] = combine_score_weights(score_weights)
if not srsly.is_json_serializable(self._config):
@ -261,15 +261,13 @@ class Language:
return list(self.factories.keys())
@property
def _pipe_names(self) -> List[str]:
def component_names(self) -> List[str]:
"""Get the names of the available pipeline components. Includes all
active and inactive pipeline components.
RETURNS (List[str]): List of component name strings, in order.
"""
# TODO: Should we make this available via a user-facing property? (The
# underscore distinction works well internally)
return [pipe_name for pipe_name, _ in self._pipeline]
return [pipe_name for pipe_name, _ in self.components]
@property
def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
@ -279,7 +277,7 @@ class Language:
RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
"""
return [(name, p) for name, p in self._pipeline if name not in self._disabled]
return [(name, p) for name, p in self.components if name not in self.disabled]
@property
def pipe_names(self) -> List[str]:
@ -296,7 +294,7 @@ class Language:
RETURNS (Dict[str, str]): Factory names, keyed by component names.
"""
factories = {}
for pipe_name, pipe in self._pipeline:
for pipe_name, pipe in self.components:
factories[pipe_name] = self.get_pipe_meta(pipe_name).factory
return factories
@ -308,7 +306,7 @@ class Language:
RETURNS (Dict[str, List[str]]): Labels keyed by component name.
"""
labels = {}
for name, pipe in self._pipeline:
for name, pipe in self.components:
if hasattr(pipe, "labels"):
labels[name] = list(pipe.labels)
return labels
@ -536,10 +534,10 @@ class Language:
DOCS: https://spacy.io/api/language#get_pipe
"""
for pipe_name, component in self._pipeline:
for pipe_name, component in self.components:
if pipe_name == name:
return component
raise KeyError(Errors.E001.format(name=name, opts=self._pipe_names))
raise KeyError(Errors.E001.format(name=name, opts=self.component_names))
def create_pipe(
self,
@ -684,8 +682,8 @@ class Language:
err = Errors.E966.format(component=bad_val, name=name)
raise ValueError(err)
name = name if name is not None else factory_name
if name in self._pipe_names:
raise ValueError(Errors.E007.format(name=name, opts=self._pipe_names))
if name in self.component_names:
raise ValueError(Errors.E007.format(name=name, opts=self.component_names))
if source is not None:
# We're loading the component from a model. After loading the
# component, we know its real factory name
@ -710,7 +708,7 @@ class Language:
)
pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name)
self._pipeline.insert(pipe_index, (name, pipe_component))
self.components.insert(pipe_index, (name, pipe_component))
return pipe_component
def _get_pipe_index(
@ -731,34 +729,42 @@ class Language:
"""
all_args = {"before": before, "after": after, "first": first, "last": last}
if sum(arg is not None for arg in [before, after, first, last]) >= 2:
raise ValueError(Errors.E006.format(args=all_args, opts=self._pipe_names))
raise ValueError(
Errors.E006.format(args=all_args, opts=self.component_names)
)
if last or not any(value is not None for value in [first, before, after]):
return len(self._pipeline)
return len(self.components)
elif first:
return 0
elif isinstance(before, str):
if before not in self._pipe_names:
raise ValueError(Errors.E001.format(name=before, opts=self._pipe_names))
return self._pipe_names.index(before)
if before not in self.component_names:
raise ValueError(
Errors.E001.format(name=before, opts=self.component_names)
)
return self.component_names.index(before)
elif isinstance(after, str):
if after not in self._pipe_names:
raise ValueError(Errors.E001.format(name=after, opts=self._pipe_names))
return self._pipe_names.index(after) + 1
if after not in self.component_names:
raise ValueError(
Errors.E001.format(name=after, opts=self.component_names)
)
return self.component_names.index(after) + 1
# We're only accepting indices referring to components that exist
# (can't just do isinstance here because bools are instance of int, too)
elif type(before) == int:
if before >= len(self._pipeline) or before < 0:
if before >= len(self.components) or before < 0:
err = Errors.E959.format(
dir="before", idx=before, opts=self._pipe_names
dir="before", idx=before, opts=self.component_names
)
raise ValueError(err)
return before
elif type(after) == int:
if after >= len(self._pipeline) or after < 0:
err = Errors.E959.format(dir="after", idx=after, opts=self._pipe_names)
if after >= len(self.components) or after < 0:
err = Errors.E959.format(
dir="after", idx=after, opts=self.component_names
)
raise ValueError(err)
return after + 1
raise ValueError(Errors.E006.format(args=all_args, opts=self._pipe_names))
raise ValueError(Errors.E006.format(args=all_args, opts=self.component_names))
def has_pipe(self, name: str) -> bool:
"""Check if a component name is present in the pipeline. Equivalent to
@ -799,7 +805,7 @@ class Language:
# to Language.pipeline to make sure the configs are handled correctly
pipe_index = self.pipe_names.index(name)
self.remove_pipe(name)
if not len(self._pipeline) or pipe_index == len(self._pipeline):
if not len(self.components) or pipe_index == len(self.components):
# we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name, config=config, validate=validate)
else:
@ -819,12 +825,16 @@ class Language:
DOCS: https://spacy.io/api/language#rename_pipe
"""
if old_name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=old_name, opts=self._pipe_names))
if new_name in self._pipe_names:
raise ValueError(Errors.E007.format(name=new_name, opts=self._pipe_names))
i = self._pipe_names.index(old_name)
self._pipeline[i] = (new_name, self._pipeline[i][1])
if old_name not in self.component_names:
raise ValueError(
Errors.E001.format(name=old_name, opts=self.component_names)
)
if new_name in self.component_names:
raise ValueError(
Errors.E007.format(name=new_name, opts=self.component_names)
)
i = self.component_names.index(old_name)
self.components[i] = (new_name, self.components[i][1])
self._pipe_meta[new_name] = self._pipe_meta.pop(old_name)
self._pipe_configs[new_name] = self._pipe_configs.pop(old_name)
@ -836,16 +846,16 @@ class Language:
DOCS: https://spacy.io/api/language#remove_pipe
"""
if name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
removed = self._pipeline.pop(self._pipe_names.index(name))
if name not in self.component_names:
raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
removed = self.components.pop(self.component_names.index(name))
# We're only removing the component itself from the metas/configs here
# because factory may be used for something else
self._pipe_meta.pop(name)
self._pipe_configs.pop(name)
# Make sure the name is also removed from the set of disabled components
if name in self._disabled:
self._disabled.remove(name)
if name in self.disabled:
self.disabled.remove(name)
return removed
def disable_pipe(self, name: str) -> None:
@ -855,9 +865,9 @@ class Language:
name (str): The name of the component to disable.
"""
if name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
self._disabled.add(name)
if name not in self.component_names:
raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
self.disabled.add(name)
def enable_pipe(self, name: str) -> None:
"""Enable a previously disabled pipeline component so it's run as part
@ -865,10 +875,10 @@ class Language:
name (str): The name of the component to enable.
"""
if name not in self._pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names))
if name in self._disabled:
self._disabled.remove(name)
if name not in self.component_names:
raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
if name in self.disabled:
self.disabled.remove(name)
def __call__(
self,
@ -1530,7 +1540,7 @@ class Language:
source_name = pipe_cfg.get("component", pipe_name)
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
disabled_pipes = [*config["nlp"]["disabled"], *disable]
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
nlp.disabled = set(p for p in disabled_pipes if p not in exclude)
nlp.config = filled if auto_fill else config
nlp.resolved = resolved
if after_pipeline_creation is not None:
@ -1560,6 +1570,7 @@ class Language:
)
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
serializers["config.cfg"] = lambda p: self.config.to_disk(p)
for name, proc in self.components:
if name in exclude:
continue
if not hasattr(proc, "to_disk"):
@ -1603,7 +1614,7 @@ class Language:
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
p, exclude=["vocab"]
)
for name, proc in self._pipeline:
for name, proc in self.components:
if name in exclude:
continue
if not hasattr(proc, "from_disk"):
@ -1632,7 +1643,7 @@ class Language:
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self._pipeline:
for name, proc in self.components:
if name in exclude:
continue
if not hasattr(proc, "to_bytes"):
@ -1666,7 +1677,7 @@ class Language:
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
b, exclude=["vocab"]
)
for name, proc in self._pipeline:
for name, proc in self.components:
if name in exclude:
continue
if not hasattr(proc, "from_bytes"):
@ -1716,7 +1727,7 @@ class DisabledPipes(list):
def restore(self) -> None:
"""Restore the pipeline to its state when DisabledPipes was created."""
for name in self.names:
if name not in self.nlp._pipe_names:
if name not in self.nlp.component_names:
raise ValueError(Errors.E008.format(name=name))
self.nlp.enable_pipe(name)
self[:] = []

View File

@ -281,15 +281,15 @@ def test_disable_enable_pipes():
assert nlp.pipeline == [(f"{name}1", c1), (f"{name}2", c2)]
assert nlp.pipe_names == [f"{name}1", f"{name}2"]
nlp.disable_pipe(f"{name}1")
assert nlp._disabled == set([f"{name}1"])
assert nlp._pipe_names == [f"{name}1", f"{name}2"]
assert nlp.disabled == set([f"{name}1"])
assert nlp.component_names == [f"{name}1", f"{name}2"]
assert nlp.pipe_names == [f"{name}2"]
assert nlp.config["nlp"]["disabled"] == [f"{name}1"]
nlp("hello")
assert results[f"{name}1"] == "" # didn't run
assert results[f"{name}2"] == "hello" # ran
nlp.enable_pipe(f"{name}1")
assert nlp._disabled == set()
assert nlp.disabled == set()
assert nlp.pipe_names == [f"{name}1", f"{name}2"]
assert nlp.config["nlp"]["disabled"] == []
nlp("world")
@ -297,22 +297,22 @@ def test_disable_enable_pipes():
assert results[f"{name}2"] == "world"
nlp.disable_pipe(f"{name}2")
nlp.remove_pipe(f"{name}2")
assert nlp._pipeline == [(f"{name}1", c1)]
assert nlp.components == [(f"{name}1", c1)]
assert nlp.pipeline == [(f"{name}1", c1)]
assert nlp._pipe_names == [f"{name}1"]
assert nlp.component_names == [f"{name}1"]
assert nlp.pipe_names == [f"{name}1"]
assert nlp._disabled == set()
assert nlp.disabled == set()
assert nlp.config["nlp"]["disabled"] == []
nlp.rename_pipe(f"{name}1", name)
assert nlp._pipeline == [(name, c1)]
assert nlp._pipe_names == [name]
assert nlp.components == [(name, c1)]
assert nlp.component_names == [name]
nlp("!")
assert results[f"{name}1"] == "!"
assert results[f"{name}2"] == "world"
with pytest.raises(ValueError):
nlp.disable_pipe(f"{name}2")
nlp.disable_pipe(name)
assert nlp._pipe_names == [name]
assert nlp.component_names == [name]
assert nlp.pipe_names == []
assert nlp.config["nlp"]["disabled"] == [name]
nlp("?")

View File

@ -186,23 +186,23 @@ def test_serialize_pipeline_disable_enable():
config = nlp.config.copy()
nlp2 = English.from_config(config)
assert nlp2.pipe_names == ["ner"]
assert nlp2._pipe_names == ["ner", "tagger"]
assert nlp2._disabled == set(["tagger"])
assert nlp2.component_names == ["ner", "tagger"]
assert nlp2.disabled == set(["tagger"])
assert nlp2.config["nlp"]["disabled"] == ["tagger"]
with make_tempdir() as d:
nlp2.to_disk(d)
nlp3 = spacy.load(d)
assert nlp3.pipe_names == ["ner"]
assert nlp3._pipe_names == ["ner", "tagger"]
assert nlp3.component_names == ["ner", "tagger"]
with make_tempdir() as d:
nlp3.to_disk(d)
nlp4 = spacy.load(d, disable=["ner"])
assert nlp4.pipe_names == []
assert nlp4._pipe_names == ["ner", "tagger"]
assert nlp4._disabled == set(["ner", "tagger"])
assert nlp4.component_names == ["ner", "tagger"]
assert nlp4.disabled == set(["ner", "tagger"])
with make_tempdir() as d:
nlp.to_disk(d)
nlp5 = spacy.load(d, exclude=["tagger"])
assert nlp5.pipe_names == ["ner"]
assert nlp5._pipe_names == ["ner"]
assert nlp5._disabled == set()
assert nlp5.component_names == ["ner"]
assert nlp5.disabled == set()