Rename user-facing API

This commit is contained in:
Ines Montani 2020-08-28 21:04:02 +02:00
parent 6a999c9303
commit 0687d7148e
3 changed files with 82 additions and 71 deletions

View File

@ -159,8 +159,8 @@ class Language:
self.vocab: Vocab = vocab self.vocab: Vocab = vocab
if self.lang is None: if self.lang is None:
self.lang = self.vocab.lang self.lang = self.vocab.lang
self._pipeline = [] self.components = []
self._disabled = set() self.disabled = set()
self.max_length = max_length self.max_length = max_length
self.resolved = {} self.resolved = {}
# Create the default tokenizer from the default config # Create the default tokenizer from the default config
@ -211,7 +211,7 @@ class Language:
# TODO: Adding this back to prevent breaking people's code etc., but # TODO: Adding this back to prevent breaking people's code etc., but
# we should consider removing it # we should consider removing it
self._meta["pipeline"] = self.pipe_names self._meta["pipeline"] = self.pipe_names
self._meta["disabled"] = list(self._disabled) self._meta["disabled"] = list(self.disabled)
return self._meta return self._meta
@meta.setter @meta.setter
@ -234,14 +234,14 @@ class Language:
# we can populate the config again later # we can populate the config again later
pipeline = {} pipeline = {}
score_weights = [] score_weights = []
for pipe_name in self._pipe_names: for pipe_name in self.component_names:
pipe_meta = self.get_pipe_meta(pipe_name) pipe_meta = self.get_pipe_meta(pipe_name)
pipe_config = self.get_pipe_config(pipe_name) pipe_config = self.get_pipe_config(pipe_name)
pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config} pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
if pipe_meta.default_score_weights: if pipe_meta.default_score_weights:
score_weights.append(pipe_meta.default_score_weights) score_weights.append(pipe_meta.default_score_weights)
self._config["nlp"]["pipeline"] = self._pipe_names self._config["nlp"]["pipeline"] = self.component_names
self._config["nlp"]["disabled"] = list(self._disabled) self._config["nlp"]["disabled"] = list(self.disabled)
self._config["components"] = pipeline self._config["components"] = pipeline
self._config["training"]["score_weights"] = combine_score_weights(score_weights) self._config["training"]["score_weights"] = combine_score_weights(score_weights)
if not srsly.is_json_serializable(self._config): if not srsly.is_json_serializable(self._config):
@ -261,15 +261,13 @@ class Language:
return list(self.factories.keys()) return list(self.factories.keys())
@property @property
def _pipe_names(self) -> List[str]: def component_names(self) -> List[str]:
"""Get the names of the available pipeline components. Includes all """Get the names of the available pipeline components. Includes all
active and inactive pipeline components. active and inactive pipeline components.
RETURNS (List[str]): List of component name strings, in order. RETURNS (List[str]): List of component name strings, in order.
""" """
# TODO: Should we make this available via a user-facing property? (The return [pipe_name for pipe_name, _ in self.components]
# underscore distinction works well internally)
return [pipe_name for pipe_name, _ in self._pipeline]
@property @property
def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]: def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
@ -279,7 +277,7 @@ class Language:
RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline. RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
""" """
return [(name, p) for name, p in self._pipeline if name not in self._disabled] return [(name, p) for name, p in self.components if name not in self.disabled]
@property @property
def pipe_names(self) -> List[str]: def pipe_names(self) -> List[str]:
@ -296,7 +294,7 @@ class Language:
RETURNS (Dict[str, str]): Factory names, keyed by component names. RETURNS (Dict[str, str]): Factory names, keyed by component names.
""" """
factories = {} factories = {}
for pipe_name, pipe in self._pipeline: for pipe_name, pipe in self.components:
factories[pipe_name] = self.get_pipe_meta(pipe_name).factory factories[pipe_name] = self.get_pipe_meta(pipe_name).factory
return factories return factories
@ -308,7 +306,7 @@ class Language:
RETURNS (Dict[str, List[str]]): Labels keyed by component name. RETURNS (Dict[str, List[str]]): Labels keyed by component name.
""" """
labels = {} labels = {}
for name, pipe in self._pipeline: for name, pipe in self.components:
if hasattr(pipe, "labels"): if hasattr(pipe, "labels"):
labels[name] = list(pipe.labels) labels[name] = list(pipe.labels)
return labels return labels
@ -536,10 +534,10 @@ class Language:
DOCS: https://spacy.io/api/language#get_pipe DOCS: https://spacy.io/api/language#get_pipe
""" """
for pipe_name, component in self._pipeline: for pipe_name, component in self.components:
if pipe_name == name: if pipe_name == name:
return component return component
raise KeyError(Errors.E001.format(name=name, opts=self._pipe_names)) raise KeyError(Errors.E001.format(name=name, opts=self.component_names))
def create_pipe( def create_pipe(
self, self,
@ -684,8 +682,8 @@ class Language:
err = Errors.E966.format(component=bad_val, name=name) err = Errors.E966.format(component=bad_val, name=name)
raise ValueError(err) raise ValueError(err)
name = name if name is not None else factory_name name = name if name is not None else factory_name
if name in self._pipe_names: if name in self.component_names:
raise ValueError(Errors.E007.format(name=name, opts=self._pipe_names)) raise ValueError(Errors.E007.format(name=name, opts=self.component_names))
if source is not None: if source is not None:
# We're loading the component from a model. After loading the # We're loading the component from a model. After loading the
# component, we know its real factory name # component, we know its real factory name
@ -710,7 +708,7 @@ class Language:
) )
pipe_index = self._get_pipe_index(before, after, first, last) pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name) self._pipe_meta[name] = self.get_factory_meta(factory_name)
self._pipeline.insert(pipe_index, (name, pipe_component)) self.components.insert(pipe_index, (name, pipe_component))
return pipe_component return pipe_component
def _get_pipe_index( def _get_pipe_index(
@ -731,34 +729,42 @@ class Language:
""" """
all_args = {"before": before, "after": after, "first": first, "last": last} all_args = {"before": before, "after": after, "first": first, "last": last}
if sum(arg is not None for arg in [before, after, first, last]) >= 2: if sum(arg is not None for arg in [before, after, first, last]) >= 2:
raise ValueError(Errors.E006.format(args=all_args, opts=self._pipe_names)) raise ValueError(
Errors.E006.format(args=all_args, opts=self.component_names)
)
if last or not any(value is not None for value in [first, before, after]): if last or not any(value is not None for value in [first, before, after]):
return len(self._pipeline) return len(self.components)
elif first: elif first:
return 0 return 0
elif isinstance(before, str): elif isinstance(before, str):
if before not in self._pipe_names: if before not in self.component_names:
raise ValueError(Errors.E001.format(name=before, opts=self._pipe_names)) raise ValueError(
return self._pipe_names.index(before) Errors.E001.format(name=before, opts=self.component_names)
)
return self.component_names.index(before)
elif isinstance(after, str): elif isinstance(after, str):
if after not in self._pipe_names: if after not in self.component_names:
raise ValueError(Errors.E001.format(name=after, opts=self._pipe_names)) raise ValueError(
return self._pipe_names.index(after) + 1 Errors.E001.format(name=after, opts=self.component_names)
)
return self.component_names.index(after) + 1
# We're only accepting indices referring to components that exist # We're only accepting indices referring to components that exist
# (can't just do isinstance here because bools are instance of int, too) # (can't just do isinstance here because bools are instance of int, too)
elif type(before) == int: elif type(before) == int:
if before >= len(self._pipeline) or before < 0: if before >= len(self.components) or before < 0:
err = Errors.E959.format( err = Errors.E959.format(
dir="before", idx=before, opts=self._pipe_names dir="before", idx=before, opts=self.component_names
) )
raise ValueError(err) raise ValueError(err)
return before return before
elif type(after) == int: elif type(after) == int:
if after >= len(self._pipeline) or after < 0: if after >= len(self.components) or after < 0:
err = Errors.E959.format(dir="after", idx=after, opts=self._pipe_names) err = Errors.E959.format(
dir="after", idx=after, opts=self.component_names
)
raise ValueError(err) raise ValueError(err)
return after + 1 return after + 1
raise ValueError(Errors.E006.format(args=all_args, opts=self._pipe_names)) raise ValueError(Errors.E006.format(args=all_args, opts=self.component_names))
def has_pipe(self, name: str) -> bool: def has_pipe(self, name: str) -> bool:
"""Check if a component name is present in the pipeline. Equivalent to """Check if a component name is present in the pipeline. Equivalent to
@ -799,7 +805,7 @@ class Language:
# to Language.pipeline to make sure the configs are handled correctly # to Language.pipeline to make sure the configs are handled correctly
pipe_index = self.pipe_names.index(name) pipe_index = self.pipe_names.index(name)
self.remove_pipe(name) self.remove_pipe(name)
if not len(self._pipeline) or pipe_index == len(self._pipeline): if not len(self.components) or pipe_index == len(self.components):
# we have no components to insert before/after, or we're replacing the last component # we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name, config=config, validate=validate) self.add_pipe(factory_name, name=name, config=config, validate=validate)
else: else:
@ -819,12 +825,16 @@ class Language:
DOCS: https://spacy.io/api/language#rename_pipe DOCS: https://spacy.io/api/language#rename_pipe
""" """
if old_name not in self._pipe_names: if old_name not in self.component_names:
raise ValueError(Errors.E001.format(name=old_name, opts=self._pipe_names)) raise ValueError(
if new_name in self._pipe_names: Errors.E001.format(name=old_name, opts=self.component_names)
raise ValueError(Errors.E007.format(name=new_name, opts=self._pipe_names)) )
i = self._pipe_names.index(old_name) if new_name in self.component_names:
self._pipeline[i] = (new_name, self._pipeline[i][1]) raise ValueError(
Errors.E007.format(name=new_name, opts=self.component_names)
)
i = self.component_names.index(old_name)
self.components[i] = (new_name, self.components[i][1])
self._pipe_meta[new_name] = self._pipe_meta.pop(old_name) self._pipe_meta[new_name] = self._pipe_meta.pop(old_name)
self._pipe_configs[new_name] = self._pipe_configs.pop(old_name) self._pipe_configs[new_name] = self._pipe_configs.pop(old_name)
@ -836,16 +846,16 @@ class Language:
DOCS: https://spacy.io/api/language#remove_pipe DOCS: https://spacy.io/api/language#remove_pipe
""" """
if name not in self._pipe_names: if name not in self.component_names:
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names)) raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
removed = self._pipeline.pop(self._pipe_names.index(name)) removed = self.components.pop(self.component_names.index(name))
# We're only removing the component itself from the metas/configs here # We're only removing the component itself from the metas/configs here
# because factory may be used for something else # because factory may be used for something else
self._pipe_meta.pop(name) self._pipe_meta.pop(name)
self._pipe_configs.pop(name) self._pipe_configs.pop(name)
# Make sure the name is also removed from the set of disabled components # Make sure the name is also removed from the set of disabled components
if name in self._disabled: if name in self.disabled:
self._disabled.remove(name) self.disabled.remove(name)
return removed return removed
def disable_pipe(self, name: str) -> None: def disable_pipe(self, name: str) -> None:
@ -855,9 +865,9 @@ class Language:
name (str): The name of the component to disable. name (str): The name of the component to disable.
""" """
if name not in self._pipe_names: if name not in self.component_names:
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names)) raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
self._disabled.add(name) self.disabled.add(name)
def enable_pipe(self, name: str) -> None: def enable_pipe(self, name: str) -> None:
"""Enable a previously disabled pipeline component so it's run as part """Enable a previously disabled pipeline component so it's run as part
@ -865,10 +875,10 @@ class Language:
name (str): The name of the component to enable. name (str): The name of the component to enable.
""" """
if name not in self._pipe_names: if name not in self.component_names:
raise ValueError(Errors.E001.format(name=name, opts=self._pipe_names)) raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
if name in self._disabled: if name in self.disabled:
self._disabled.remove(name) self.disabled.remove(name)
def __call__( def __call__(
self, self,
@ -1530,7 +1540,7 @@ class Language:
source_name = pipe_cfg.get("component", pipe_name) source_name = pipe_cfg.get("component", pipe_name)
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name) nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
disabled_pipes = [*config["nlp"]["disabled"], *disable] disabled_pipes = [*config["nlp"]["disabled"], *disable]
nlp._disabled = set(p for p in disabled_pipes if p not in exclude) nlp.disabled = set(p for p in disabled_pipes if p not in exclude)
nlp.config = filled if auto_fill else config nlp.config = filled if auto_fill else config
nlp.resolved = resolved nlp.resolved = resolved
if after_pipeline_creation is not None: if after_pipeline_creation is not None:
@ -1560,6 +1570,7 @@ class Language:
) )
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta) serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
serializers["config.cfg"] = lambda p: self.config.to_disk(p) serializers["config.cfg"] = lambda p: self.config.to_disk(p)
for name, proc in self.components:
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "to_disk"): if not hasattr(proc, "to_disk"):
@ -1603,7 +1614,7 @@ class Language:
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk( deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
p, exclude=["vocab"] p, exclude=["vocab"]
) )
for name, proc in self._pipeline: for name, proc in self.components:
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "from_disk"): if not hasattr(proc, "from_disk"):
@ -1632,7 +1643,7 @@ class Language:
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta) serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
serializers["config.cfg"] = lambda: self.config.to_bytes() serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self._pipeline: for name, proc in self.components:
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "to_bytes"): if not hasattr(proc, "to_bytes"):
@ -1666,7 +1677,7 @@ class Language:
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes( deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
b, exclude=["vocab"] b, exclude=["vocab"]
) )
for name, proc in self._pipeline: for name, proc in self.components:
if name in exclude: if name in exclude:
continue continue
if not hasattr(proc, "from_bytes"): if not hasattr(proc, "from_bytes"):
@ -1716,7 +1727,7 @@ class DisabledPipes(list):
def restore(self) -> None: def restore(self) -> None:
"""Restore the pipeline to its state when DisabledPipes was created.""" """Restore the pipeline to its state when DisabledPipes was created."""
for name in self.names: for name in self.names:
if name not in self.nlp._pipe_names: if name not in self.nlp.component_names:
raise ValueError(Errors.E008.format(name=name)) raise ValueError(Errors.E008.format(name=name))
self.nlp.enable_pipe(name) self.nlp.enable_pipe(name)
self[:] = [] self[:] = []

View File

@ -281,15 +281,15 @@ def test_disable_enable_pipes():
assert nlp.pipeline == [(f"{name}1", c1), (f"{name}2", c2)] assert nlp.pipeline == [(f"{name}1", c1), (f"{name}2", c2)]
assert nlp.pipe_names == [f"{name}1", f"{name}2"] assert nlp.pipe_names == [f"{name}1", f"{name}2"]
nlp.disable_pipe(f"{name}1") nlp.disable_pipe(f"{name}1")
assert nlp._disabled == set([f"{name}1"]) assert nlp.disabled == set([f"{name}1"])
assert nlp._pipe_names == [f"{name}1", f"{name}2"] assert nlp.component_names == [f"{name}1", f"{name}2"]
assert nlp.pipe_names == [f"{name}2"] assert nlp.pipe_names == [f"{name}2"]
assert nlp.config["nlp"]["disabled"] == [f"{name}1"] assert nlp.config["nlp"]["disabled"] == [f"{name}1"]
nlp("hello") nlp("hello")
assert results[f"{name}1"] == "" # didn't run assert results[f"{name}1"] == "" # didn't run
assert results[f"{name}2"] == "hello" # ran assert results[f"{name}2"] == "hello" # ran
nlp.enable_pipe(f"{name}1") nlp.enable_pipe(f"{name}1")
assert nlp._disabled == set() assert nlp.disabled == set()
assert nlp.pipe_names == [f"{name}1", f"{name}2"] assert nlp.pipe_names == [f"{name}1", f"{name}2"]
assert nlp.config["nlp"]["disabled"] == [] assert nlp.config["nlp"]["disabled"] == []
nlp("world") nlp("world")
@ -297,22 +297,22 @@ def test_disable_enable_pipes():
assert results[f"{name}2"] == "world" assert results[f"{name}2"] == "world"
nlp.disable_pipe(f"{name}2") nlp.disable_pipe(f"{name}2")
nlp.remove_pipe(f"{name}2") nlp.remove_pipe(f"{name}2")
assert nlp._pipeline == [(f"{name}1", c1)] assert nlp.components == [(f"{name}1", c1)]
assert nlp.pipeline == [(f"{name}1", c1)] assert nlp.pipeline == [(f"{name}1", c1)]
assert nlp._pipe_names == [f"{name}1"] assert nlp.component_names == [f"{name}1"]
assert nlp.pipe_names == [f"{name}1"] assert nlp.pipe_names == [f"{name}1"]
assert nlp._disabled == set() assert nlp.disabled == set()
assert nlp.config["nlp"]["disabled"] == [] assert nlp.config["nlp"]["disabled"] == []
nlp.rename_pipe(f"{name}1", name) nlp.rename_pipe(f"{name}1", name)
assert nlp._pipeline == [(name, c1)] assert nlp.components == [(name, c1)]
assert nlp._pipe_names == [name] assert nlp.component_names == [name]
nlp("!") nlp("!")
assert results[f"{name}1"] == "!" assert results[f"{name}1"] == "!"
assert results[f"{name}2"] == "world" assert results[f"{name}2"] == "world"
with pytest.raises(ValueError): with pytest.raises(ValueError):
nlp.disable_pipe(f"{name}2") nlp.disable_pipe(f"{name}2")
nlp.disable_pipe(name) nlp.disable_pipe(name)
assert nlp._pipe_names == [name] assert nlp.component_names == [name]
assert nlp.pipe_names == [] assert nlp.pipe_names == []
assert nlp.config["nlp"]["disabled"] == [name] assert nlp.config["nlp"]["disabled"] == [name]
nlp("?") nlp("?")

View File

@ -186,23 +186,23 @@ def test_serialize_pipeline_disable_enable():
config = nlp.config.copy() config = nlp.config.copy()
nlp2 = English.from_config(config) nlp2 = English.from_config(config)
assert nlp2.pipe_names == ["ner"] assert nlp2.pipe_names == ["ner"]
assert nlp2._pipe_names == ["ner", "tagger"] assert nlp2.component_names == ["ner", "tagger"]
assert nlp2._disabled == set(["tagger"]) assert nlp2.disabled == set(["tagger"])
assert nlp2.config["nlp"]["disabled"] == ["tagger"] assert nlp2.config["nlp"]["disabled"] == ["tagger"]
with make_tempdir() as d: with make_tempdir() as d:
nlp2.to_disk(d) nlp2.to_disk(d)
nlp3 = spacy.load(d) nlp3 = spacy.load(d)
assert nlp3.pipe_names == ["ner"] assert nlp3.pipe_names == ["ner"]
assert nlp3._pipe_names == ["ner", "tagger"] assert nlp3.component_names == ["ner", "tagger"]
with make_tempdir() as d: with make_tempdir() as d:
nlp3.to_disk(d) nlp3.to_disk(d)
nlp4 = spacy.load(d, disable=["ner"]) nlp4 = spacy.load(d, disable=["ner"])
assert nlp4.pipe_names == [] assert nlp4.pipe_names == []
assert nlp4._pipe_names == ["ner", "tagger"] assert nlp4.component_names == ["ner", "tagger"]
assert nlp4._disabled == set(["ner", "tagger"]) assert nlp4.disabled == set(["ner", "tagger"])
with make_tempdir() as d: with make_tempdir() as d:
nlp.to_disk(d) nlp.to_disk(d)
nlp5 = spacy.load(d, exclude=["tagger"]) nlp5 = spacy.load(d, exclude=["tagger"])
assert nlp5.pipe_names == ["ner"] assert nlp5.pipe_names == ["ner"]
assert nlp5._pipe_names == ["ner"] assert nlp5.component_names == ["ner"]
assert nlp5._disabled == set() assert nlp5.disabled == set()