Also support passing list to Language.disable_pipes (#4521)

* Also support passing list to Language.disable_pipes

* Adjust internals
This commit is contained in:
Ines Montani 2019-10-25 16:19:08 +02:00 committed by Matthew Honnibal
parent 1185702993
commit d2da117114
7 changed files with 42 additions and 13 deletions

View File

@ -156,8 +156,7 @@ def train(
"`lang` argument ('{}') ".format(nlp.lang, lang), "`lang` argument ('{}') ".format(nlp.lang, lang),
exits=1, exits=1,
) )
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipeline] nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
nlp.disable_pipes(*other_pipes)
for pipe in pipeline: for pipe in pipeline:
if pipe not in nlp.pipe_names: if pipe not in nlp.pipe_names:
if pipe == "parser": if pipe == "parser":

View File

@ -448,6 +448,8 @@ class Language(object):
DOCS: https://spacy.io/api/language#disable_pipes DOCS: https://spacy.io/api/language#disable_pipes
""" """
if len(names) == 1 and isinstance(names[0], (list, tuple)):
names = names[0] # support list of names instead of spread
return DisabledPipes(self, *names) return DisabledPipes(self, *names)
def make_doc(self, text): def make_doc(self, text):

View File

@ -187,7 +187,7 @@ class EntityRuler(object):
] ]
except ValueError: except ValueError:
subsequent_pipes = [] subsequent_pipes = []
with self.nlp.disable_pipes(*subsequent_pipes): with self.nlp.disable_pipes(subsequent_pipes):
for entry in patterns: for entry in patterns:
label = entry["label"] label = entry["label"]
if "id" in entry: if "id" in entry:

View File

@ -105,6 +105,16 @@ def test_disable_pipes_context(nlp, name):
assert nlp.has_pipe(name) assert nlp.has_pipe(name)
def test_disable_pipes_list_arg(nlp):
for name in ["c1", "c2", "c3"]:
nlp.add_pipe(new_pipe, name=name)
assert nlp.has_pipe(name)
with nlp.disable_pipes(["c1", "c2"]):
assert not nlp.has_pipe("c1")
assert not nlp.has_pipe("c2")
assert nlp.has_pipe("c3")
@pytest.mark.parametrize("n_pipes", [100]) @pytest.mark.parametrize("n_pipes", [100])
def test_add_lots_of_pipes(nlp, n_pipes): def test_add_lots_of_pipes(nlp, n_pipes):
for i in range(n_pipes): for i in range(n_pipes):

View File

@ -34,8 +34,7 @@ def test_issue3611():
nlp.add_pipe(textcat, last=True) nlp.add_pipe(textcat, last=True)
# training the network # training the network
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"] with nlp.disable_pipes([p for p in nlp.pipe_names if p != "textcat"]):
with nlp.disable_pipes(*other_pipes):
optimizer = nlp.begin_training() optimizer = nlp.begin_training()
for i in range(3): for i in range(3):
losses = {} losses = {}

View File

@ -34,8 +34,7 @@ def test_issue4030():
nlp.add_pipe(textcat, last=True) nlp.add_pipe(textcat, last=True)
# training the network # training the network
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"] with nlp.disable_pipes([p for p in nlp.pipe_names if p != "textcat"]):
with nlp.disable_pipes(*other_pipes):
optimizer = nlp.begin_training() optimizer = nlp.begin_training()
for i in range(3): for i in range(3):
losses = {} losses = {}

View File

@ -323,18 +323,38 @@ you can use to undo your changes.
> #### Example > #### Example
> >
> ```python > ```python
> with nlp.disable_pipes('tagger', 'parser'): > # New API as of v2.2.2
> with nlp.disable_pipes(["tagger", "parser"]):
> nlp.begin_training()
>
> with nlp.disable_pipes("tagger", "parser"):
> nlp.begin_training() > nlp.begin_training()
> >
> disabled = nlp.disable_pipes('tagger', 'parser') > disabled = nlp.disable_pipes("tagger", "parser")
> nlp.begin_training() > nlp.begin_training()
> disabled.restore() > disabled.restore()
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| ----------- | --------------- | ------------------------------------------------------------------------------------ | | ----------------------------------------- | --------------- | ------------------------------------------------------------------------------------ |
| `*disabled` | unicode | Names of pipeline components to disable. | | `disabled` <Tag variant="new">2.2.2</Tag> | list | Names of pipeline components to disable. |
| **RETURNS** | `DisabledPipes` | The disabled pipes that can be restored by calling the object's `.restore()` method. | | `*disabled` | unicode | Names of pipeline components to disable. |
| **RETURNS** | `DisabledPipes` | The disabled pipes that can be restored by calling the object's `.restore()` method. |
<Infobox title="Changed in v2.2.2" variant="warning">
As of spaCy v2.2.2, the `Language.disable_pipes` method can also take a list of
component names as its first argument (instead of a variable number of
arguments). This is especially useful if you're generating the component names
to disable programmatically. The new syntax will become the default in the
future.
```diff
- disabled = nlp.disable_pipes("tagger", "parser")
+ disabled = nlp.disable_pipes(["tagger", "parser"])
```
</Infobox>
## Language.to_disk {#to_disk tag="method" new="2"} ## Language.to_disk {#to_disk tag="method" new="2"}