Implement new Language methods and pipeline API
This commit is contained in:
parent 3468d535ad
commit 212c8f0711

@@ -70,59 +70,7 @@ class BaseDefaults(object):
                          prefix_search=prefix_search, suffix_search=suffix_search,
                          infix_finditer=infix_finditer, token_match=token_match)
 
-    @classmethod
-    def create_tagger(cls, nlp=None, **cfg):
-        if nlp is None:
-            return NeuralTagger(cls.create_vocab(nlp), **cfg)
-        else:
-            return NeuralTagger(nlp.vocab, **cfg)
-
-    @classmethod
-    def create_parser(cls, nlp=None, **cfg):
-        if nlp is None:
-            return NeuralDependencyParser(cls.create_vocab(nlp), **cfg)
-        else:
-            return NeuralDependencyParser(nlp.vocab, **cfg)
-
-    @classmethod
-    def create_entity(cls, nlp=None, **cfg):
-        if nlp is None:
-            return NeuralEntityRecognizer(cls.create_vocab(nlp), **cfg)
-        else:
-            return NeuralEntityRecognizer(nlp.vocab, **cfg)
-
-    @classmethod
-    def create_pipeline(cls, nlp=None, disable=tuple()):
-        meta = nlp.meta if nlp is not None else {}
-        # Resolve strings, like "cnn", "lstm", etc
-        pipeline = []
-        for entry in meta.get('pipeline', []):
-            if entry in disable or getattr(entry, 'name', entry) in disable:
-                continue
-            factory = cls.Defaults.factories[entry]
-            pipeline.append(factory(nlp, **meta.get(entry, {})))
-        return pipeline
-
-    factories = {
-        'make_doc': create_tokenizer,
-        'tensorizer': lambda nlp, **cfg: [TokenVectorEncoder(nlp.vocab, **cfg)],
-        'tagger': lambda nlp, **cfg: [NeuralTagger(nlp.vocab, **cfg)],
-        'parser': lambda nlp, **cfg: [
-            NeuralDependencyParser(nlp.vocab, **cfg),
-            nonproj.deprojectivize],
-        'ner': lambda nlp, **cfg: [NeuralEntityRecognizer(nlp.vocab, **cfg)],
-        'similarity': lambda nlp, **cfg: [SimilarityHook(nlp.vocab, **cfg)],
-        'textcat': lambda nlp, **cfg: [TextCategorizer(nlp.vocab, **cfg)],
-        # Temporary compatibility -- delete after pivot
-        'token_vectors': lambda nlp, **cfg: [TokenVectorEncoder(nlp.vocab, **cfg)],
-        'tags': lambda nlp, **cfg: [NeuralTagger(nlp.vocab, **cfg)],
-        'dependencies': lambda nlp, **cfg: [
-            NeuralDependencyParser(nlp.vocab, **cfg),
-            nonproj.deprojectivize,
-        ],
-        'entities': lambda nlp, **cfg: [NeuralEntityRecognizer(nlp.vocab, **cfg)],
-    }
-
+    pipe_names = ['tensorizer', 'tagger', 'parser', 'ner']
     token_match = TOKEN_MATCH
     prefixes = tuple(TOKENIZER_PREFIXES)
     suffixes = tuple(TOKENIZER_SUFFIXES)
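
With the component factories gone from BaseDefaults, a language's defaults now only declare which components make up the default pipeline; the components themselves come from Language.factories. A minimal sketch of what this enables, assuming a hypothetical subclass that wants a different default order:

    class CustomDefaults(BaseDefaults):
        # Drop the tensorizer; keep the remaining default components.
        pipe_names = ['tagger', 'parser', 'ner']
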
@@ -152,8 +100,17 @@ class Language(object):
     Defaults = BaseDefaults
     lang = None
 
-    def __init__(self, vocab=True, make_doc=True, pipeline=None,
-                 meta={}, disable=tuple(), **kwargs):
+    factories = {
+        'tokenizer': lambda nlp: nlp.Defaults.create_tokenizer(nlp),
+        'tensorizer': lambda nlp, **cfg: TokenVectorEncoder(nlp.vocab, **cfg),
+        'tagger': lambda nlp, **cfg: NeuralTagger(nlp.vocab, **cfg),
+        'parser': lambda nlp, **cfg: NeuralDependencyParser(nlp.vocab, **cfg),  # nonproj.deprojectivize,
+        'ner': lambda nlp, **cfg: NeuralEntityRecognizer(nlp.vocab, **cfg),
+        'similarity': lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg),
+        'textcat': lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg)
+    }
+
+    def __init__(self, vocab=True, make_doc=True, meta={}, **kwargs):
         """Initialise a Language object.
 
         vocab (Vocab): A `Vocab` object. If `True`, a vocab is created via
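
Each factory now lives on Language itself and returns a single component rather than a list. A hedged sketch of plugging in a custom factory; MyComponent and 'my_component' are illustrative names, not part of the diff:

    from spacy.language import Language

    # Register a factory, then instantiate it through the new create_pipe API.
    Language.factories['my_component'] = (
        lambda nlp, **cfg: MyComponent(nlp.vocab, **cfg))  # MyComponent is hypothetical

    nlp = Language()
    nlp.add_pipe(nlp.create_pipe('my_component'), last=True)
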
@@ -179,28 +136,7 @@ class Language(object):
             factory = self.Defaults.create_tokenizer
             make_doc = factory(self, **meta.get('tokenizer', {}))
         self.tokenizer = make_doc
-        if pipeline is True:
-            self.pipeline = self.Defaults.create_pipeline(self, disable)
-        elif pipeline:
-            # Careful not to do getattr(p, 'name', None) here
-            # If we had disable=[None], we'd disable everything!
-            self.pipeline = [p for p in pipeline
-                             if p not in disable
-                             and getattr(p, 'name', p) not in disable]
-            # Resolve strings, like "cnn", "lstm", etc
-            for i, entry in enumerate(self.pipeline):
-                if entry in self.Defaults.factories:
-                    factory = self.Defaults.factories[entry]
-                    self.pipeline[i] = factory(self, **meta.get(entry, {}))
-        else:
-            self.pipeline = []
-        flat_list = []
-        for pipe in self.pipeline:
-            if isinstance(pipe, list):
-                flat_list.extend(pipe)
-            else:
-                flat_list.append(pipe)
-        self.pipeline = flat_list
+        self.pipeline = []
         self._optimizer = None
 
     @property
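
The constructor no longer accepts pipeline and disable arguments or assembles components itself: every Language starts with an empty pipeline, and components are attached explicitly. A usage sketch under that assumption:

    nlp = Language()
    assert nlp.pipeline == []
    nlp.add_pipe(nlp.create_pipe('tagger'))
    assert nlp.pipe_names == ['tagger']
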
@@ -214,11 +150,7 @@ class Language(object):
         self._meta.setdefault('email', '')
         self._meta.setdefault('url', '')
         self._meta.setdefault('license', '')
-        pipeline = []
-        for component in self.pipeline:
-            if hasattr(component, 'name'):
-                pipeline.append(component.name)
-        self._meta['pipeline'] = pipeline
+        self._meta['pipeline'] = self.pipe_names
         return self._meta
 
     @meta.setter
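
The meta now simply mirrors pipe_names, so a component registered under a name always shows up, even if the callable itself exposes no name attribute (previously such components were silently dropped). A quick sketch:

    nlp.add_pipe(lambda doc: doc, name='noop')
    assert 'noop' in nlp.meta['pipeline']
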
@@ -228,31 +160,133 @@ class Language(object):
     # Conveniences to access pipeline components
     @property
     def tensorizer(self):
-        return self.get_component('tensorizer')
+        return self.get_pipe('tensorizer')
 
     @property
     def tagger(self):
-        return self.get_component('tagger')
+        return self.get_pipe('tagger')
 
     @property
    def parser(self):
-        return self.get_component('parser')
+        return self.get_pipe('parser')
 
     @property
     def entity(self):
-        return self.get_component('ner')
+        return self.get_pipe('ner')
 
     @property
     def matcher(self):
-        return self.get_component('matcher')
+        return self.get_pipe('matcher')
 
-    def get_component(self, name):
-        if self.pipeline in (True, None):
-            return None
-        for proc in self.pipeline:
-            if hasattr(proc, 'name') and proc.name.endswith(name):
-                return proc
-        return None
+    @property
+    def pipe_names(self):
+        """Get names of available pipeline components.
+
+        RETURNS (list): List of component name strings, in order.
+        """
+        return [pipe_name for pipe_name, _ in self.pipeline]
+
+    def get_pipe(self, name):
+        """Get a pipeline component for a given component name.
+
+        name (unicode): Name of pipeline component to get.
+        RETURNS (callable): The pipeline component.
+        """
+        for pipe_name, component in self.pipeline:
+            if pipe_name == name:
+                return component
+        msg = "No component '{}' found in pipeline. Available names: {}"
+        raise KeyError(msg.format(name, self.pipe_names))
+
+    def create_pipe(self, name, config=dict()):
+        """Create a pipeline component from a factory.
+
+        name (unicode): Factory name to look up in `Language.factories`.
+        RETURNS (callable): Pipeline component.
+        """
+        if name not in self.factories:
+            raise KeyError("Can't find factory for '{}'.".format(name))
+        factory = self.factories[name]
+        return factory(self, **config)
+
+    def add_pipe(self, component, name=None, before=None, after=None,
+                 first=None, last=None):
+        """Add a component to the processing pipeline. Valid components are
+        callables that take a `Doc` object, modify it and return it. Only one of
+        before, after, first or last can be set. Default behaviour is "last".
+
+        component (callable): The pipeline component.
+        name (unicode): Name of pipeline component. Overwrites existing
+            component.name attribute if available. If no name is set and
+            the component exposes no name attribute, component.__name__ is
+            used. An error is raised if the name already exists in the pipeline.
+        before (unicode): Component name to insert component directly before.
+        after (unicode): Component name to insert component directly after.
+        first (bool): Insert component first / not first in the pipeline.
+        last (bool): Insert component last / not last in the pipeline.
+
+        EXAMPLE:
+            >>> nlp.add_pipe(component, before='ner')
+            >>> nlp.add_pipe(component, name='custom_name', last=True)
+        """
+        if name is None:
+            name = getattr(component, 'name', component.__name__)
+        if name in self.pipe_names:
+            raise ValueError("'{}' already exists in pipeline.".format(name))
+        if sum([bool(before), bool(after), bool(first), bool(last)]) >= 2:
+            msg = ("Invalid constraints. You can only set one of the "
+                   "following: before, after, first, last.")
+            raise ValueError(msg)
+        pipe = (name, component)
+        if last or not any([first, before, after]):
+            self.pipeline.append(pipe)
+        elif first:
+            self.pipeline.insert(0, pipe)
+        elif before and before in self.pipe_names:
+            self.pipeline.insert(self.pipe_names.index(before), pipe)
+        elif after and after in self.pipe_names:
+            self.pipeline.insert(self.pipe_names.index(after), pipe)
+        else:
+            msg = "Can't find '{}' in pipeline. Available names: {}"
+            unfound = before or after
+            raise ValueError(msg.format(unfound, self.pipe_names))
+
+    def replace_pipe(self, name, component):
+        """Replace a component in the pipeline.
+
+        name (unicode): Name of the component to replace.
+        component (callable): Pipeline component.
+        """
+        if name not in self.pipe_names:
+            msg = "Can't find '{}' in pipeline. Available names: {}"
+            raise ValueError(msg.format(name, self.pipe_names))
+        self.pipeline[self.pipe_names.index(name)] = (name, component)
+
+    def rename_pipe(self, old_name, new_name):
+        """Rename a pipeline component.
+
+        old_name (unicode): Name of the component to rename.
+        new_name (unicode): New name of the component.
+        """
+        if old_name not in self.pipe_names:
+            msg = "Can't find '{}' in pipeline. Available names: {}"
+            raise ValueError(msg.format(old_name, self.pipe_names))
+        if new_name in self.pipe_names:
+            msg = "'{}' already exists in pipeline. Existing names: {}"
+            raise ValueError(msg.format(new_name, self.pipe_names))
+        i = self.pipe_names.index(old_name)
+        self.pipeline[i] = (new_name, self.pipeline[i][1])
+
+    def remove_pipe(self, name):
+        """Remove a component from the pipeline.
+
+        name (unicode): Name of the component to remove.
+        RETURNS (tuple): A (name, component) tuple of the removed component.
+        """
+        if name not in self.pipe_names:
+            msg = "Can't find '{}' in pipeline. Available names: {}"
+            raise ValueError(msg.format(name, self.pipe_names))
+        return self.pipeline.pop(self.pipe_names.index(name))
 
     def __call__(self, text, disable=[]):
         """Apply the pipeline to some text. The text can span multiple sentences,
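
Taken together, the block above is the new pipeline-manipulation API. A hedged usage sketch; custom_component and the model name are illustrative, not from the diff:

    import spacy

    nlp = spacy.load('en')  # assumes an installed 'en' model
    print(nlp.pipe_names)   # e.g. ['tensorizer', 'tagger', 'parser', 'ner']

    def custom_component(doc):  # hypothetical component: a Doc-to-Doc callable
        return doc

    nlp.add_pipe(custom_component, name='custom', before='ner')
    nlp.rename_pipe('custom', 'my_component')
    nlp.replace_pipe('my_component', custom_component)
    tagger = nlp.get_pipe('tagger')
    name, component = nlp.remove_pipe('my_component')
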
@@ -269,8 +303,7 @@ class Language(object):
             ('An', 'NN')
         """
         doc = self.make_doc(text)
-        for proc in self.pipeline:
-            name = getattr(proc, 'name', None)
+        for name, proc in self.pipeline:
             if name in disable:
                 continue
             doc = proc(doc)
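
Because the pipeline now stores (name, component) tuples, per-call disabling matches the registered name directly instead of guessing via getattr. Sketch:

    doc = nlp(u'This is a sentence.', disable=['parser', 'ner'])
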
@@ -308,7 +341,7 @@ class Language(object):
             grads[key] = (W, dW)
         pipes = list(self.pipeline)
         random.shuffle(pipes)
-        for proc in pipes:
+        for name, proc in pipes:
             if not hasattr(proc, 'update'):
                 continue
             proc.update(docs, golds, drop=drop, sgd=get_grads, losses=losses)
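
update still duck-types trainable components: any pipeline entry whose component exposes an update method receives the batch, and everything else is skipped. A hedged sketch of a component that would participate; the class is hypothetical:

    class StatefulComponent(object):
        name = 'stateful'

        def __call__(self, doc):
            return doc

        def update(self, docs, golds, drop=0., sgd=None, losses=None):
            pass  # a real component would update its model weights here

    nlp.add_pipe(StatefulComponent())
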
@@ -322,7 +355,7 @@ class Language(object):
         docs_golds (iterable): Tuples of `Doc` and `GoldParse` objects.
         YIELDS (tuple): Tuples of preprocessed `Doc` and `GoldParse` objects.
         """
-        for proc in self.pipeline:
+        for name, proc in self.pipeline:
             if hasattr(proc, 'preprocess_gold'):
                 docs_golds = proc.preprocess_gold(docs_golds)
         for doc, gold in docs_golds:
@@ -371,7 +404,7 @@ class Language(object):
         else:
             device = None
         link_vectors_to_models(self.vocab)
-        for proc in self.pipeline:
+        for name, proc in self.pipeline:
             if hasattr(proc, 'begin_training'):
                 context = proc.begin_training(get_gold_tuples(),
                                               pipeline=self.pipeline)
@@ -393,7 +426,7 @@ class Language(object):
         docs, golds = zip(*docs_golds)
         docs = list(docs)
         golds = list(golds)
-        for pipe in self.pipeline:
+        for name, pipe in self.pipeline:
             if not hasattr(pipe, 'pipe'):
                 for doc in docs:
                     pipe(doc)
@@ -419,7 +452,7 @@ class Language(object):
        >>> with nlp.use_params(optimizer.averages):
        >>>     nlp.to_disk('/tmp/checkpoint')
         """
-        contexts = [pipe.use_params(params) for pipe
+        contexts = [pipe.use_params(params) for name, pipe
                     in self.pipeline if hasattr(pipe, 'use_params')]
         # TODO: Having trouble with contextlib
         # Workaround: these aren't actually context managers atm.
@@ -466,8 +499,7 @@ class Language(object):
                 yield (doc, context)
             return
         docs = (self.make_doc(text) for text in texts)
-        for proc in self.pipeline:
-            name = getattr(proc, 'name', None)
+        for name, proc in self.pipeline:
             if name in disable:
                 continue
             if hasattr(proc, 'pipe'):
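
The same name-based disabling applies when streaming texts. Sketch:

    texts = [u'First text.', u'Second text.']
    for doc in nlp.pipe(texts, disable=['ner']):
        pass  # each Doc arrives with the 'ner' component skipped
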
@@ -495,14 +527,14 @@ class Language(object):
             ('tokenizer', lambda p: self.tokenizer.to_disk(p, vocab=False)),
             ('meta.json', lambda p: p.open('w').write(json_dumps(self.meta)))
         ))
-        for proc in self.pipeline:
-            if not hasattr(proc, 'name'):
-                continue
-            if proc.name in disable:
+        for name, proc in self.pipeline:
+            if name in disable:
                 continue
             if not hasattr(proc, 'to_disk'):
                 continue
-            serializers[proc.name] = lambda p, proc=proc: proc.to_disk(p, vocab=False)
+            serializers[name] = lambda p, proc=proc: proc.to_disk(p, vocab=False)
         serializers['vocab'] = lambda p: self.vocab.to_disk(p)
         util.to_disk(path, serializers, {p: False for p in disable})
@@ -526,14 +558,12 @@ class Language(object):
             ('tokenizer', lambda p: self.tokenizer.from_disk(p, vocab=False)),
             ('meta.json', lambda p: p.open('w').write(json_dumps(self.meta)))
         ))
-        for proc in self.pipeline:
-            if not hasattr(proc, 'name'):
-                continue
-            if proc.name in disable:
+        for name, proc in self.pipeline:
+            if name in disable:
                 continue
             if not hasattr(proc, 'to_disk'):
                 continue
-            deserializers[proc.name] = lambda p, proc=proc: proc.from_disk(p, vocab=False)
+            deserializers[name] = lambda p, proc=proc: proc.from_disk(p, vocab=False)
         exclude = {p: False for p in disable}
         if not (path / 'vocab').exists():
             exclude['vocab'] = True
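
Serialization is now keyed by pipe name on both sides, so skipping a component by name works symmetrically for saving and loading. A hedged round-trip sketch, assuming nlp2 was constructed with a matching pipeline:

    nlp.to_disk('/tmp/model', disable=['ner'])
    nlp2.from_disk('/tmp/model', disable=['ner'])
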
@@ -552,8 +582,8 @@ class Language(object):
             ('tokenizer', lambda: self.tokenizer.to_bytes(vocab=False)),
             ('meta', lambda: ujson.dumps(self.meta))
         ))
-        for i, proc in enumerate(self.pipeline):
-            if getattr(proc, 'name', None) in disable:
+        for i, (name, proc) in enumerate(self.pipeline):
+            if name in disable:
                 continue
             if not hasattr(proc, 'to_bytes'):
                 continue
@@ -572,8 +602,8 @@ class Language(object):
             ('tokenizer', lambda b: self.tokenizer.from_bytes(b, vocab=False)),
             ('meta', lambda b: self.meta.update(ujson.loads(b)))
         ))
-        for i, proc in enumerate(self.pipeline):
-            if getattr(proc, 'name', None) in disable:
+        for i, (name, proc) in enumerate(self.pipeline):
+            if name in disable:
                 continue
             if not hasattr(proc, 'from_bytes'):
                 continue
@@ -135,7 +135,11 @@ def load_model_from_path(model_path, meta=False, **overrides):
     if not meta:
         meta = get_model_meta(model_path)
     cls = get_lang_class(meta['lang'])
-    nlp = cls(pipeline=meta.get('pipeline', True), meta=meta, **overrides)
+    nlp = cls(meta=meta, **overrides)
+    for name in meta.get('pipeline', []):
+        config = meta.get('pipeline_args', {}).get(name, {})
+        component = nlp.create_pipe(name, config=config)
+        nlp.add_pipe(component, name=name)
     return nlp.from_disk(model_path)
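
With this change, a model's meta.json drives pipeline construction: 'pipeline' lists factory names in order, and the pipeline_args key seen in the diff supplies optional per-component config. A hedged sketch of the shape being consumed; the parser setting is illustrative:

    meta = {
        'lang': 'en',
        'pipeline': ['tensorizer', 'tagger', 'parser', 'ner'],
        'pipeline_args': {'parser': {'beam_width': 1}},
    }
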