mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
Support disable keyword in Language.__init__
This commit is contained in:
parent
b4cdd05466
commit
2479cde446
|
@ -85,11 +85,13 @@ class BaseDefaults(object):
|
|||
return NeuralEntityRecognizer(nlp.vocab, **cfg)
|
||||
|
||||
@classmethod
|
||||
def create_pipeline(cls, nlp=None):
|
||||
def create_pipeline(cls, nlp=None, disable=tuple()):
|
||||
meta = nlp.meta if nlp is not None else {}
|
||||
# Resolve strings, like "cnn", "lstm", etc
|
||||
pipeline = []
|
||||
for entry in cls.pipeline:
|
||||
if entry in disable or getattr(entry, 'name', entry) in disable:
|
||||
continue
|
||||
factory = cls.Defaults.factories[entry]
|
||||
pipeline.append(factory(nlp, **meta.get(entry, {})))
|
||||
return pipeline
|
||||
|
@ -141,7 +143,8 @@ class Language(object):
|
|||
Defaults = BaseDefaults
|
||||
lang = None
|
||||
|
||||
def __init__(self, vocab=True, make_doc=True, pipeline=None, meta={}, **kwargs):
|
||||
def __init__(self, vocab=True, make_doc=True, pipeline=None, meta={},
|
||||
disable=tuple(), **kwargs):
|
||||
"""Initialise a Language object.
|
||||
|
||||
vocab (Vocab): A `Vocab` object. If `True`, a vocab is created via
|
||||
|
@ -151,12 +154,14 @@ class Language(object):
|
|||
pipeline (list): A list of annotation processes or IDs of annotation,
|
||||
processes, e.g. a `Tagger` object, or `'tagger'`. IDs are looked
|
||||
up in `Language.Defaults.factories`.
|
||||
disable (list): A list of component names to exclude from the pipeline.
|
||||
The disable list has priority over the pipeline list -- if the same
|
||||
string occurs in both, the component is not loaded.
|
||||
meta (dict): Custom meta data for the Language class. Is written to by
|
||||
models to add model meta data.
|
||||
RETURNS (Language): The newly constructed object.
|
||||
"""
|
||||
self.meta = dict(meta)
|
||||
|
||||
if vocab is True:
|
||||
factory = self.Defaults.create_vocab
|
||||
vocab = factory(self, **meta.get('vocab', {}))
|
||||
|
@ -166,9 +171,13 @@ class Language(object):
|
|||
make_doc = factory(self, **meta.get('tokenizer', {}))
|
||||
self.tokenizer = make_doc
|
||||
if pipeline is True:
|
||||
self.pipeline = self.Defaults.create_pipeline(self)
|
||||
self.pipeline = self.Defaults.create_pipeline(self, disable)
|
||||
elif pipeline:
|
||||
self.pipeline = list(pipeline)
|
||||
# Careful not to do getattr(p, 'name', None) here
|
||||
# If we had disable=[None], we'd disable everything!
|
||||
self.pipeline = [p for p in pipeline
|
||||
if p not in disable
|
||||
and getattr(p, 'name', p) not in disable]
|
||||
# Resolve strings, like "cnn", "lstm", etc
|
||||
for i, entry in enumerate(self.pipeline):
|
||||
if entry in self.Defaults.factories:
|
||||
|
|
Loading…
Reference in New Issue
Block a user