Make user-facing Language.disabled return list

More consistent with all the other properties
This commit is contained in:
Ines Montani 2020-08-29 12:08:33 +02:00
parent 0687d7148e
commit 15d73f4dc3
3 changed files with 22 additions and 14 deletions

View File

@ -160,7 +160,7 @@ class Language:
if self.lang is None:
self.lang = self.vocab.lang
self.components = []
self.disabled = set()
self._disabled = set()
self.max_length = max_length
self.resolved = {}
# Create the default tokenizer from the default config
@ -211,7 +211,7 @@ class Language:
# TODO: Adding this back to prevent breaking people's code etc., but
# we should consider removing it
self._meta["pipeline"] = self.pipe_names
self._meta["disabled"] = list(self.disabled)
self._meta["disabled"] = self.disabled
return self._meta
@meta.setter
@ -241,7 +241,7 @@ class Language:
if pipe_meta.default_score_weights:
score_weights.append(pipe_meta.default_score_weights)
self._config["nlp"]["pipeline"] = self.component_names
self._config["nlp"]["disabled"] = list(self.disabled)
self._config["nlp"]["disabled"] = self.disabled
self._config["components"] = pipeline
self._config["training"]["score_weights"] = combine_score_weights(score_weights)
if not srsly.is_json_serializable(self._config):
@ -252,6 +252,14 @@ class Language:
def config(self, value: Config) -> None:
self._config = value
@property
def disabled(self) -> List[str]:
"""Get the names of all disabled components.
RETURNS (List[str]): The disabled components.
"""
return list(self._disabled)
@property
def factory_names(self) -> List[str]:
"""Get names of all available factories.
@ -277,7 +285,7 @@ class Language:
RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
"""
return [(name, p) for name, p in self.components if name not in self.disabled]
return [(name, p) for name, p in self.components if name not in self._disabled]
@property
def pipe_names(self) -> List[str]:
@ -855,7 +863,7 @@ class Language:
self._pipe_configs.pop(name)
# Make sure the name is also removed from the set of disabled components
if name in self.disabled:
self.disabled.remove(name)
self._disabled.remove(name)
return removed
def disable_pipe(self, name: str) -> None:
@ -867,7 +875,7 @@ class Language:
"""
if name not in self.component_names:
raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
self.disabled.add(name)
self._disabled.add(name)
def enable_pipe(self, name: str) -> None:
"""Enable a previously disabled pipeline component so it's run as part
@ -878,7 +886,7 @@ class Language:
if name not in self.component_names:
raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
if name in self.disabled:
self.disabled.remove(name)
self._disabled.remove(name)
def __call__(
self,
@ -1540,7 +1548,7 @@ class Language:
source_name = pipe_cfg.get("component", pipe_name)
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
disabled_pipes = [*config["nlp"]["disabled"], *disable]
nlp.disabled = set(p for p in disabled_pipes if p not in exclude)
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
nlp.config = filled if auto_fill else config
nlp.resolved = resolved
if after_pipeline_creation is not None:

View File

@ -281,7 +281,7 @@ def test_disable_enable_pipes():
assert nlp.pipeline == [(f"{name}1", c1), (f"{name}2", c2)]
assert nlp.pipe_names == [f"{name}1", f"{name}2"]
nlp.disable_pipe(f"{name}1")
assert nlp.disabled == set([f"{name}1"])
assert nlp.disabled == [f"{name}1"]
assert nlp.component_names == [f"{name}1", f"{name}2"]
assert nlp.pipe_names == [f"{name}2"]
assert nlp.config["nlp"]["disabled"] == [f"{name}1"]
@ -289,7 +289,7 @@ def test_disable_enable_pipes():
assert results[f"{name}1"] == "" # didn't run
assert results[f"{name}2"] == "hello" # ran
nlp.enable_pipe(f"{name}1")
assert nlp.disabled == set()
assert nlp.disabled == []
assert nlp.pipe_names == [f"{name}1", f"{name}2"]
assert nlp.config["nlp"]["disabled"] == []
nlp("world")
@ -301,7 +301,7 @@ def test_disable_enable_pipes():
assert nlp.pipeline == [(f"{name}1", c1)]
assert nlp.component_names == [f"{name}1"]
assert nlp.pipe_names == [f"{name}1"]
assert nlp.disabled == set()
assert nlp.disabled == []
assert nlp.config["nlp"]["disabled"] == []
nlp.rename_pipe(f"{name}1", name)
assert nlp.components == [(name, c1)]

View File

@ -187,7 +187,7 @@ def test_serialize_pipeline_disable_enable():
nlp2 = English.from_config(config)
assert nlp2.pipe_names == ["ner"]
assert nlp2.component_names == ["ner", "tagger"]
assert nlp2.disabled == set(["tagger"])
assert nlp2.disabled == ["tagger"]
assert nlp2.config["nlp"]["disabled"] == ["tagger"]
with make_tempdir() as d:
nlp2.to_disk(d)
@ -199,10 +199,10 @@ def test_serialize_pipeline_disable_enable():
nlp4 = spacy.load(d, disable=["ner"])
assert nlp4.pipe_names == []
assert nlp4.component_names == ["ner", "tagger"]
assert nlp4.disabled == set(["ner", "tagger"])
assert nlp4.disabled == ["ner", "tagger"]
with make_tempdir() as d:
nlp.to_disk(d)
nlp5 = spacy.load(d, exclude=["tagger"])
assert nlp5.pipe_names == ["ner"]
assert nlp5.component_names == ["ner"]
assert nlp5.disabled == set()
assert nlp5.disabled == []