* Work on refactoring default arguments to English.__init__

This commit is contained in:
Matthew Honnibal 2015-07-07 15:53:25 +02:00
parent 2d0e99a096
commit 1d2deb4616

View File

@@ -5,7 +5,7 @@ import re
from .. import orth from .. import orth
from ..vocab import Vocab from ..vocab import Vocab
from ..tokenizer import Tokenizer from ..tokenizer import Tokenizer
from ..syntax.parser import Parser from ..syntax import parser
from ..syntax.arc_eager import ArcEager from ..syntax.arc_eager import ArcEager
from ..syntax.ner import BiluoPushDown from ..syntax.ner import BiluoPushDown
from ..tokens import Tokens from ..tokens import Tokens
@@ -16,6 +16,7 @@ from .pos import POS_TAGS
from .attrs import get_flags from .attrs import get_flags
from . import regexes from . import regexes
from ..exceptions import ModelNotLoaded
from ..util import read_lang_data from ..util import read_lang_data
@@ -89,54 +90,44 @@ class English(object):
ParserTransitionSystem = ArcEager ParserTransitionSystem = ArcEager
EntityTransitionSystem = BiluoPushDown EntityTransitionSystem = BiluoPushDown
def __init__(self, data_dir='', Tokenizer=True, Vectors=True, Parser=True, def __init__(self, data_dir='', Tokenizer=Tokenizer.from_dir, Vectors=True,
Tagger=True, Entity=True, Senser=True, load_vectors=True): Parser=True, Tagger=EnPosTagger, Entity=True, load_vectors=True):
if data_dir == '': if data_dir == '':
data_dir = LOCAL_DATA_DIR data_dir = LOCAL_DATA_DIR
self._data_dir = data_dir
# TODO: Deprecation warning # TODO: Deprecation warning
if load_vectors is False: if load_vectors is False:
vectors = False vectors = False
self.vocab = Vocab(data_dir=path.join(data_dir, 'vocab') if data_dir else None, self.vocab = Vocab(data_dir=path.join(data_dir, 'vocab') if data_dir else None,
get_lex_props=get_lex_props, vectors=Vectors) get_lex_props=get_lex_props, load_vectors=Vectors,
pos_tags=POS_TAGS)
if Tokenizer is True:
Tokenizer = tokenizer.Tokenizer
if Tagger is True: if Tagger is True:
Tagger = pos.EnPosTagger Tagger = EnPosTagger
if Parser is True: if Parser is True:
transition_system = self.ParserTransitionSystem transition_system = self.ParserTransitionSystem
Parser = lambda s, d: parser.Parser(s, d, transition_system Parser = lambda s, d: parser.Parser(s, d, transition_system)
if Entity is True: if Entity is True:
transition_system = self.EntityTransitionSystem transition_system = self.EntityTransitionSystem
Entity = lambda s, d: parser.Parser(s, d, transition_system) Entity = lambda s, d: parser.Parser(s, d, transition_system)
if Senser is True:
Senser = wsd.SuperSenseTagger
self.tokenizer = Tokenizer(self.vocab, data_dir) if Tokenizer else None if Tokenizer:
self.tagger = Tagger(self.vocab.strings, data_dir) if Tagger else None self.tokenizer = Tokenizer(self.vocab, path.join(data_dir, 'tokenizer'))
self.parser = Parser(self.vocab.strings, data_dir) if Parser else None
self.entity = Entity(self.vocab.strings, data_dir) if Entity else None
self.senser = Senser(self.vocab.strings, data_dir) if Senser else None
self._data_dir = data_dir
tag_names = list(POS_TAGS.keys())
tag_names.sort()
if data_dir is None:
tok_rules = {}
prefix_re = None
suffix_re = None
infix_re = None
else: else:
tok_data_dir = path.join(data_dir, 'tokenizer') self.tokenizer = None
tok_rules, prefix_re, suffix_re, infix_re = read_lang_data(tok_data_dir) if Tagger:
prefix_re = re.compile(prefix_re) self.tagger = Tagger(self.vocab.strings, data_dir)
suffix_re = re.compile(suffix_re) else:
infix_re = re.compile(infix_re) self.tagger = None
if Parser:
self.tokenizer = Tokenizer(self.vocab, tok_rules, prefix_re, self.parser = Parser(self.vocab.strings, path.join(data_dir, 'deps'))
suffix_re, infix_re, else:
POS_TAGS, tag_names) self.parser = None
if Entity:
self.entity = Entity(self.vocab.strings, path.join(data_dir, 'ner'))
else:
self.entity = None
self.mwe_merger = RegexMerger([ self.mwe_merger = RegexMerger([
('IN', 'O', regexes.MW_PREPOSITIONS_RE), ('IN', 'O', regexes.MW_PREPOSITIONS_RE),
@@ -185,31 +176,22 @@ class English(object):
tokens = self.tokenizer(text) tokens = self.tokenizer(text)
if parse == -1 and tag == False: if parse == -1 and tag == False:
parse = False parse = False
elif parse == -1 and not self.has_parser_model: elif parse == -1 and self.parser is None:
parse = False parse = False
if entity == -1 and tag == False: if entity == -1 and tag == False:
entity = False entity = False
elif entity == -1 and not self.has_entity_model: elif entity == -1 and self.entity is None:
entity = False entity = False
if tag and self.has_tagger_model: if tag:
ModelNotLoaded.check(self.tagger, 'tagger')
self.tagger(tokens) self.tagger(tokens)
if parse == True and not self.has_parser_model: if parse:
msg = ("Received parse=True, but parser model not found.\n\n" ModelNotLoaded.check(self.parser, 'parser')
"Run:\n"
"$ python -m spacy.en.download\n"
"To install the model.")
raise IOError(msg)
if entity == True and not self.has_entity_model:
msg = ("Received entity=True, but entity model not found.\n\n"
"Run:\n"
"$ python -m spacy.en.download\n"
"To install the model.")
raise IOError(msg)
if parse and self.has_parser_model:
self.parser(tokens) self.parser(tokens)
if entity and self.has_entity_model: if entity:
ModelNotLoaded.check(self.entity, 'entity')
self.entity(tokens) self.entity(tokens)
if merge_mwes and self.mwe_merger is not None: if merge_mwes and self.mwe_merger is not None:
self.mwe_merger(tokens) self.mwe_merger(tokens)
return tokens return tokens