Also support passing list to Language.disable_pipes (#4521)

* Also support passing list to Language.disable_pipes

* Adjust internals
This commit is contained in:
Ines Montani 2019-10-25 16:19:08 +02:00 committed by Matthew Honnibal
parent 1185702993
commit d2da117114
7 changed files with 42 additions and 13 deletions

View File

@ -156,8 +156,7 @@ def train(
"`lang` argument ('{}') ".format(nlp.lang, lang),
exits=1,
)
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipeline]
nlp.disable_pipes(*other_pipes)
nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
for pipe in pipeline:
if pipe not in nlp.pipe_names:
if pipe == "parser":

View File

@ -448,6 +448,8 @@ class Language(object):
DOCS: https://spacy.io/api/language#disable_pipes
"""
if len(names) == 1 and isinstance(names[0], (list, tuple)):
names = names[0] # support list of names instead of spread
return DisabledPipes(self, *names)
def make_doc(self, text):

View File

@ -187,7 +187,7 @@ class EntityRuler(object):
]
except ValueError:
subsequent_pipes = []
with self.nlp.disable_pipes(*subsequent_pipes):
with self.nlp.disable_pipes(subsequent_pipes):
for entry in patterns:
label = entry["label"]
if "id" in entry:

View File

@ -105,6 +105,16 @@ def test_disable_pipes_context(nlp, name):
assert nlp.has_pipe(name)
def test_disable_pipes_list_arg(nlp):
for name in ["c1", "c2", "c3"]:
nlp.add_pipe(new_pipe, name=name)
assert nlp.has_pipe(name)
with nlp.disable_pipes(["c1", "c2"]):
assert not nlp.has_pipe("c1")
assert not nlp.has_pipe("c2")
assert nlp.has_pipe("c3")
@pytest.mark.parametrize("n_pipes", [100])
def test_add_lots_of_pipes(nlp, n_pipes):
for i in range(n_pipes):

View File

@ -34,8 +34,7 @@ def test_issue3611():
nlp.add_pipe(textcat, last=True)
# training the network
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"]
with nlp.disable_pipes(*other_pipes):
with nlp.disable_pipes([p for p in nlp.pipe_names if p != "textcat"]):
optimizer = nlp.begin_training()
for i in range(3):
losses = {}

View File

@ -34,8 +34,7 @@ def test_issue4030():
nlp.add_pipe(textcat, last=True)
# training the network
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"]
with nlp.disable_pipes(*other_pipes):
with nlp.disable_pipes([p for p in nlp.pipe_names if p != "textcat"]):
optimizer = nlp.begin_training()
for i in range(3):
losses = {}

View File

@ -323,18 +323,38 @@ you can use to undo your changes.
> #### Example
>
> ```python
> with nlp.disable_pipes('tagger', 'parser'):
> # New API as of v2.2.2
> with nlp.disable_pipes(["tagger", "parser"]):
> nlp.begin_training()
>
> with nlp.disable_pipes("tagger", "parser"):
> nlp.begin_training()
>
> disabled = nlp.disable_pipes('tagger', 'parser')
> disabled = nlp.disable_pipes("tagger", "parser")
> nlp.begin_training()
> disabled.restore()
> ```
| Name | Type | Description |
| ----------- | --------------- | ------------------------------------------------------------------------------------ |
| `*disabled` | unicode | Names of pipeline components to disable. |
| **RETURNS** | `DisabledPipes` | The disabled pipes that can be restored by calling the object's `.restore()` method. |
| Name | Type | Description |
| ----------------------------------------- | --------------- | ------------------------------------------------------------------------------------ |
| `disabled` <Tag variant="new">2.2.2</Tag> | list | Names of pipeline components to disable. |
| `*disabled` | unicode | Names of pipeline components to disable. |
| **RETURNS** | `DisabledPipes` | The disabled pipes that can be restored by calling the object's `.restore()` method. |
<Infobox title="Changed in v2.2.2" variant="warning">
As of spaCy v2.2.2, the `Language.disable_pipes` method can also take a list of
component names as its first argument (instead of a variable number of
arguments). This is especially useful if you're generating the component names
to disable programmatically. The new syntax will become the default in the
future.
```diff
- disabled = nlp.disable_pipes("tagger", "parser")
+ disabled = nlp.disable_pipes(["tagger", "parser"])
```
</Infobox>
## Language.to_disk {#to_disk tag="method" new="2"}