mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-28 02:04:07 +03:00
Add Language.disable_pipes()
This commit is contained in:
parent
d9bb1e5de8
commit
e70f80f29e
|
@ -1,6 +1,7 @@
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
from __future__ import absolute_import, unicode_literals
|
from __future__ import absolute_import, unicode_literals
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
import copy
|
||||||
|
|
||||||
from thinc.neural import Model
|
from thinc.neural import Model
|
||||||
from thinc.neural.optimizers import Adam
|
from thinc.neural.optimizers import Adam
|
||||||
|
@ -329,6 +330,29 @@ class Language(object):
|
||||||
doc = proc(doc)
|
doc = proc(doc)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
|
def disable_pipes(self, *names):
|
||||||
|
'''Disable one or more pipeline components.
|
||||||
|
|
||||||
|
If used as a context manager, the pipeline will be restored to the initial
|
||||||
|
state at the end of the block. Otherwise, a DisabledPipes object is
|
||||||
|
returned, that has a `.restore()` method you can use to undo your
|
||||||
|
changes.
|
||||||
|
|
||||||
|
EXAMPLE:
|
||||||
|
|
||||||
|
>>> nlp.add_pipe('parser')
|
||||||
|
>>> nlp.add_pipe('tagger')
|
||||||
|
>>> with nlp.disable_pipes('parser', 'tagger'):
|
||||||
|
>>> assert not nlp.has_pipe('parser')
|
||||||
|
>>> assert nlp.has_pipe('parser')
|
||||||
|
>>> disabled = nlp.disable_pipes('parser')
|
||||||
|
>>> assert len(disabled) == 1
|
||||||
|
>>> assert not nlp.has_pipe('parser')
|
||||||
|
>>> disabled.restore()
|
||||||
|
>>> assert nlp.has_pipe('parser')
|
||||||
|
'''
|
||||||
|
return DisabledPipes(self, *names)
|
||||||
|
|
||||||
def make_doc(self, text):
|
def make_doc(self, text):
|
||||||
return self.tokenizer(text)
|
return self.tokenizer(text)
|
||||||
|
|
||||||
|
@ -655,6 +679,42 @@ class Language(object):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class DisabledPipes(list):
|
||||||
|
'''Manager for temporary pipeline disabling.'''
|
||||||
|
def __init__(self, nlp, *names):
|
||||||
|
self.nlp = nlp
|
||||||
|
self.names = names
|
||||||
|
# Important! Not deep copy -- we just want the container (but we also
|
||||||
|
# want to support people providing arbitrarily typed nlp.pipeline
|
||||||
|
# objects.)
|
||||||
|
self.original_pipeline = copy.copy(nlp.pipeline)
|
||||||
|
list.__init__(self)
|
||||||
|
self.extend(nlp.remove_pipe(name) for name in names)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
self.restore()
|
||||||
|
|
||||||
|
def restore(self):
|
||||||
|
'''Restore the pipeline to its state when DisabledPipes was created.'''
|
||||||
|
current, self.nlp.pipeline = self.nlp.pipeline, self.original_pipeline
|
||||||
|
unexpected = [name for name in current if not self.nlp.has_pipe(name)]
|
||||||
|
if unexpected:
|
||||||
|
# Don't change the pipeline if we're raising an error.
|
||||||
|
self.nlp.pipeline = current
|
||||||
|
msg = (
|
||||||
|
"Some current components would be lost when restoring "
|
||||||
|
"previous pipeline state. If you added components after "
|
||||||
|
"calling nlp.disable_pipes(), you should remove them "
|
||||||
|
"explicitly with nlp.remove_pipe() before the pipeline is "
|
||||||
|
"restore. Names of the new components: %s"
|
||||||
|
)
|
||||||
|
raise ValueError(msg % unexpected)
|
||||||
|
self[:] = []
|
||||||
|
|
||||||
|
|
||||||
def unpickle_language(vocab, meta, bytes_data):
|
def unpickle_language(vocab, meta, bytes_data):
|
||||||
lang = Language(vocab=vocab)
|
lang = Language(vocab=vocab)
|
||||||
lang.from_bytes(bytes_data)
|
lang.from_bytes(bytes_data)
|
||||||
|
|
|
@ -82,3 +82,21 @@ def test_remove_pipe(nlp, name):
|
||||||
assert not len(nlp.pipeline)
|
assert not len(nlp.pipeline)
|
||||||
assert removed_name == name
|
assert removed_name == name
|
||||||
assert removed_component == new_pipe
|
assert removed_component == new_pipe
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('name', ['my_component'])
|
||||||
|
def test_disable_pipes_method(nlp, name):
|
||||||
|
nlp.add_pipe(new_pipe, name=name)
|
||||||
|
assert nlp.has_pipe(name)
|
||||||
|
disabled = nlp.disable_pipes(name)
|
||||||
|
assert not nlp.has_pipe(name)
|
||||||
|
disabled.restore()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('name', ['my_component'])
|
||||||
|
def test_disable_pipes_context(nlp, name):
|
||||||
|
nlp.add_pipe(new_pipe, name=name)
|
||||||
|
assert nlp.has_pipe(name)
|
||||||
|
with nlp.disable_pipes(name):
|
||||||
|
assert not nlp.has_pipe(name)
|
||||||
|
assert nlp.has_pipe(name)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user