* Refactor __init__ for simplicity. Allow parse=True, tag=True, etc. flags to be passed at top-level. Do not lazy-load parser.

This commit is contained in:
Matthew Honnibal 2015-07-08 12:35:29 +02:00
parent 4d24d513ad
commit 4e4fac452b

View File

@@ -5,9 +5,10 @@ import re
from .. import orth from .. import orth
from ..vocab import Vocab from ..vocab import Vocab
from ..tokenizer import Tokenizer from ..tokenizer import Tokenizer
from ..syntax import parser
from ..syntax.arc_eager import ArcEager from ..syntax.arc_eager import ArcEager
from ..syntax.ner import BiluoPushDown from ..syntax.ner import BiluoPushDown
from ..syntax.parser import ParserFactory
from ..tokens import Tokens from ..tokens import Tokens
from ..multi_words import RegexMerger from ..multi_words import RegexMerger
@@ -36,10 +37,7 @@ def get_lex_props(string):
'sentiment': 0 'sentiment': 0
} }
if_model_present = -1
LOCAL_DATA_DIR = path.join(path.dirname(__file__), 'data')
parse_if_model_present = -1
class English(object): class English(object):
@@ -63,45 +61,23 @@ class English(object):
Start with nothing loaded: Start with nothing loaded:
>>> nlp = English(data_dir=None) >>> nlp = English(data_dir=None)
Keyword args:
data_dir (unicode):
A path to a directory from which to load the pipeline;
or '', to load default; or None, to load nothing.
Tokenizer (bool or callable):
desc
Vectors (bool or callable):
desc
Parser (bool or callable):
desc
Tagger (bool or callable):
desc
Entity (bool or callable):
desc
Senser (bool or callable):
desc
""" """
ParserTransitionSystem = ArcEager ParserTransitionSystem = ArcEager
EntityTransitionSystem = BiluoPushDown EntityTransitionSystem = BiluoPushDown
def __init__(self, data_dir='', Tokenizer=Tokenizer.from_dir, Vectors=True, def __init__(self,
Parser=True, Tagger=EnPosTagger, Entity=True, load_vectors=True): data_dir=path.join(path.dirname(__file__), 'data'),
if data_dir == '': Tokenizer=Tokenizer.from_dir,
data_dir = LOCAL_DATA_DIR Tagger=EnPosTagger,
Parser=ParserFactory(ParserTransitionSystem),
Entity=ParserFactory(EntityTransitionSystem),
load_vectors=True
):
self._data_dir = data_dir self._data_dir = data_dir
# TODO: Deprecation warning
if load_vectors is False:
vectors = False
self.vocab = Vocab(data_dir=path.join(data_dir, 'vocab') if data_dir else None, self.vocab = Vocab(data_dir=path.join(data_dir, 'vocab') if data_dir else None,
get_lex_props=get_lex_props, load_vectors=Vectors, get_lex_props=get_lex_props, load_vectors=load_vectors,
pos_tags=POS_TAGS) pos_tags=POS_TAGS)
if Tagger is True: if Tagger is True:
Tagger = EnPosTagger Tagger = EnPosTagger
@@ -112,10 +88,8 @@ class English(object):
transition_system = self.EntityTransitionSystem transition_system = self.EntityTransitionSystem
Entity = lambda s, d: parser.Parser(s, d, transition_system) Entity = lambda s, d: parser.Parser(s, d, transition_system)
if Tokenizer: self.tokenizer = Tokenizer(self.vocab, path.join(data_dir, 'tokenizer'))
self.tokenizer = Tokenizer(self.vocab, path.join(data_dir, 'tokenizer'))
else:
self.tokenizer = None
if Tagger: if Tagger:
self.tagger = Tagger(self.vocab.strings, data_dir) self.tagger = Tagger(self.vocab.strings, data_dir)
else: else:
@@ -128,33 +102,20 @@ class English(object):
self.entity = Entity(self.vocab.strings, path.join(data_dir, 'ner')) self.entity = Entity(self.vocab.strings, path.join(data_dir, 'ner'))
else: else:
self.entity = None self.entity = None
self.mwe_merger = RegexMerger([ self.mwe_merger = RegexMerger([
('IN', 'O', regexes.MW_PREPOSITIONS_RE), ('IN', 'O', regexes.MW_PREPOSITIONS_RE),
('CD', 'TIME', regexes.TIME_RE), ('CD', 'TIME', regexes.TIME_RE),
('NNP', 'DATE', regexes.DAYS_RE), ('NNP', 'DATE', regexes.DAYS_RE),
('CD', 'MONEY', regexes.MONEY_RE)]) ('CD', 'MONEY', regexes.MONEY_RE)])
def __call__(self, text, tag=True, parse=parse_if_model_present, def __call__(self, text, tag=True, parse=True, entity=True):
entity=parse_if_model_present, merge_mwes=False):
"""Apply the pipeline to some text. The text can span multiple sentences, """Apply the pipeline to some text. The text can span multiple sentences,
and can contain arbitrary whitespace. Alignment into the original string and can contain arbitrary whitespace. Alignment into the original string
is preserved.
The tagger and parser are lazy-loaded the first time they are required.
Loading the parser model usually takes 5-10 seconds.
Args: Args:
text (unicode): The text to be processed. text (unicode): The text to be processed.
Keyword args:
tag (bool): Whether to add part-of-speech tags to the text. Also
sets morphological analysis and lemmas.
parse (True, False, -1): Whether to add labelled syntactic dependencies.
-1 (default) is "guess": It will guess True if tag=True and the
model has been installed.
Returns: Returns:
tokens (spacy.tokens.Tokens): tokens (spacy.tokens.Tokens):
@@ -164,36 +125,13 @@ class English(object):
>>> tokens[0].orth_, tokens[0].head.tag_ >>> tokens[0].orth_, tokens[0].head.tag_
('An', 'NN') ('An', 'NN')
""" """
if parse == True and tag == False:
msg = ("Incompatible arguments: tag=False, parse=True"
"Part-of-speech tags are required for parsing.")
raise ValueError(msg)
if entity == True and tag == False:
msg = ("Incompatible arguments: tag=False, entity=True"
"Part-of-speech tags are required for entity recognition.")
raise ValueError(msg)
tokens = self.tokenizer(text) tokens = self.tokenizer(text)
if parse == -1 and tag == False: if self.tagger and tag:
parse = False
elif parse == -1 and self.parser is None:
parse = False
if entity == -1 and tag == False:
entity = False
elif entity == -1 and self.entity is None:
entity = False
if tag:
ModelNotLoaded.check(self.tagger, 'tagger')
self.tagger(tokens) self.tagger(tokens)
if parse: if self.parser and parse:
ModelNotLoaded.check(self.parser, 'parser')
self.parser(tokens) self.parser(tokens)
if entity: if self.entity and entity:
ModelNotLoaded.check(self.entity, 'entity')
self.entity(tokens) self.entity(tokens)
if merge_mwes and self.mwe_merger is not None:
self.mwe_merger(tokens)
return tokens return tokens
@property @property