Support disable keyword in Language.__init__

This commit is contained in:
Matthew Honnibal 2017-06-05 13:13:07 +02:00
parent b4cdd05466
commit 2479cde446

View File

@ -85,11 +85,13 @@ class BaseDefaults(object):
return NeuralEntityRecognizer(nlp.vocab, **cfg)
@classmethod
def create_pipeline(cls, nlp=None):
def create_pipeline(cls, nlp=None, disable=tuple()):
meta = nlp.meta if nlp is not None else {}
# Resolve strings, like "cnn", "lstm", etc
pipeline = []
for entry in cls.pipeline:
if entry in disable or getattr(entry, 'name', entry) in disable:
continue
factory = cls.Defaults.factories[entry]
pipeline.append(factory(nlp, **meta.get(entry, {})))
return pipeline
@ -141,7 +143,8 @@ class Language(object):
Defaults = BaseDefaults
lang = None
def __init__(self, vocab=True, make_doc=True, pipeline=None, meta={}, **kwargs):
def __init__(self, vocab=True, make_doc=True, pipeline=None, meta={},
disable=tuple(), **kwargs):
"""Initialise a Language object.
vocab (Vocab): A `Vocab` object. If `True`, a vocab is created via
@ -151,12 +154,14 @@ class Language(object):
pipeline (list): A list of annotation processes or IDs of annotation,
processes, e.g. a `Tagger` object, or `'tagger'`. IDs are looked
up in `Language.Defaults.factories`.
disable (list): A list of component names to exclude from the pipeline.
The disable list has priority over the pipeline list -- if the same
string occurs in both, the component is not loaded.
meta (dict): Custom meta data for the Language class. Is written to by
models to add model meta data.
RETURNS (Language): The newly constructed object.
"""
self.meta = dict(meta)
if vocab is True:
factory = self.Defaults.create_vocab
vocab = factory(self, **meta.get('vocab', {}))
@ -166,9 +171,13 @@ class Language(object):
make_doc = factory(self, **meta.get('tokenizer', {}))
self.tokenizer = make_doc
if pipeline is True:
self.pipeline = self.Defaults.create_pipeline(self)
self.pipeline = self.Defaults.create_pipeline(self, disable)
elif pipeline:
self.pipeline = list(pipeline)
# Careful not to do getattr(p, 'name', None) here
# If we had disable=[None], we'd disable everything!
self.pipeline = [p for p in pipeline
if p not in disable
and getattr(p, 'name', p) not in disable]
# Resolve strings, like "cnn", "lstm", etc
for i, entry in enumerate(self.pipeline):
if entry in self.Defaults.factories: