Add Language.disable_pipes()

This commit is contained in:
Matthew Honnibal 2017-10-25 13:46:41 +02:00
parent d9bb1e5de8
commit e70f80f29e
2 changed files with 78 additions and 0 deletions

View File

@ -1,6 +1,7 @@
# coding: utf8 # coding: utf8
from __future__ import absolute_import, unicode_literals from __future__ import absolute_import, unicode_literals
from contextlib import contextmanager from contextlib import contextmanager
import copy
from thinc.neural import Model from thinc.neural import Model
from thinc.neural.optimizers import Adam from thinc.neural.optimizers import Adam
@ -329,6 +330,29 @@ class Language(object):
doc = proc(doc) doc = proc(doc)
return doc return doc
def disable_pipes(self, *names):
'''Disable one or more pipeline components.
If used as a context manager, the pipeline will be restored to the initial
state at the end of the block. Otherwise, a DisabledPipes object is
returned, that has a `.restore()` method you can use to undo your
changes.
EXAMPLE:
>>> nlp.add_pipe('parser')
>>> nlp.add_pipe('tagger')
>>> with nlp.disable_pipes('parser', 'tagger'):
>>> assert not nlp.has_pipe('parser')
>>> assert nlp.has_pipe('parser')
>>> disabled = nlp.disable_pipes('parser')
>>> assert len(disabled) == 1
>>> assert not nlp.has_pipe('parser')
>>> disabled.restore()
>>> assert nlp.has_pipe('parser')
'''
return DisabledPipes(self, *names)
def make_doc(self, text): def make_doc(self, text):
return self.tokenizer(text) return self.tokenizer(text)
@ -655,6 +679,42 @@ class Language(object):
return self return self
class DisabledPipes(list):
'''Manager for temporary pipeline disabling.'''
def __init__(self, nlp, *names):
self.nlp = nlp
self.names = names
# Important! Not deep copy -- we just want the container (but we also
# want to support people providing arbitrarily typed nlp.pipeline
# objects.)
self.original_pipeline = copy.copy(nlp.pipeline)
list.__init__(self)
self.extend(nlp.remove_pipe(name) for name in names)
def __enter__(self):
pass
def __exit__(self, *args):
self.restore()
def restore(self):
'''Restore the pipeline to its state when DisabledPipes was created.'''
current, self.nlp.pipeline = self.nlp.pipeline, self.original_pipeline
unexpected = [name for name in current if not self.nlp.has_pipe(name)]
if unexpected:
# Don't change the pipeline if we're raising an error.
self.nlp.pipeline = current
msg = (
"Some current components would be lost when restoring "
"previous pipeline state. If you added components after "
"calling nlp.disable_pipes(), you should remove them "
"explicitly with nlp.remove_pipe() before the pipeline is "
"restore. Names of the new components: %s"
)
raise ValueError(msg % unexpected)
self[:] = []
def unpickle_language(vocab, meta, bytes_data): def unpickle_language(vocab, meta, bytes_data):
lang = Language(vocab=vocab) lang = Language(vocab=vocab)
lang.from_bytes(bytes_data) lang.from_bytes(bytes_data)

View File

@ -82,3 +82,21 @@ def test_remove_pipe(nlp, name):
assert not len(nlp.pipeline) assert not len(nlp.pipeline)
assert removed_name == name assert removed_name == name
assert removed_component == new_pipe assert removed_component == new_pipe
@pytest.mark.parametrize('name', ['my_component'])
def test_disable_pipes_method(nlp, name):
nlp.add_pipe(new_pipe, name=name)
assert nlp.has_pipe(name)
disabled = nlp.disable_pipes(name)
assert not nlp.has_pipe(name)
disabled.restore()
@pytest.mark.parametrize('name', ['my_component'])
def test_disable_pipes_context(nlp, name):
nlp.add_pipe(new_pipe, name=name)
assert nlp.has_pipe(name)
with nlp.disable_pipes(name):
assert not nlp.has_pipe(name)
assert nlp.has_pipe(name)