Refactor defaults

This commit is contained in:
Matthew Honnibal 2016-10-18 16:18:25 +02:00
parent a45a9d5092
commit 7d5212f131

View File

@ -34,97 +34,79 @@ from .pipeline import DependencyParser, EntityRecognizer
class BaseDefaults(object):
def __init__(self, lang, path):
self.path = path
self.lang = lang
self.lex_attr_getters = dict(self.__class__.lex_attr_getters)
if self.path and (self.path / 'vocab' / 'oov_prob').exists():
with (self.path / 'vocab' / 'oov_prob').open() as file_:
oov_prob = file_.read().strip()
self.lex_attr_getters[PROB] = lambda string: oov_prob
self.lex_attr_getters[LANG] = lambda string: lang
self.lex_attr_getters[IS_STOP] = lambda string: string in self.stop_words
@classmethod
def create_lemmatizer(cls, nlp=None):
if nlp is None or nlp.path is None:
return Lemmatizer({}, {}, {})
else:
return Lemmatizer.load(nlp.path)
def Lemmatizer(self):
return Lemmatizer.load(self.path) if self.path else Lemmatizer({}, {}, {})
def Vectors(self):
@classmethod
def create_vocab(cls, nlp=None):
lemmatizer = cls.create_lemmatizer(nlp)
if nlp is None or nlp.path is None:
return Vocab(lex_attr_getters=cls.lex_attr_getters, tag_map=cls.tag_map,
lemmatizer=lemmatizer)
else:
return Vocab.load(nlp.path, lex_attr_getters=cls.lex_attr_getters,
tag_map=cls.tag_map, lemmatizer=lemmatizer)
@classmethod
def add_vectors(cls, nlp=None):
return True
def Vocab(self, lex_attr_getters=True, tag_map=True,
lemmatizer=True, serializer_freqs=True, vectors=True):
if lex_attr_getters is True:
lex_attr_getters = self.lex_attr_getters
if tag_map is True:
tag_map = self.tag_map
if lemmatizer is True:
lemmatizer = self.Lemmatizer()
if vectors is True:
vectors = self.Vectors()
if self.path:
return Vocab.load(self.path, lex_attr_getters=lex_attr_getters,
tag_map=tag_map, lemmatizer=lemmatizer,
serializer_freqs=serializer_freqs)
@classmethod
def create_tokenizer(cls, nlp=None):
rules = cls.tokenizer_exceptions
prefix_search = util.compile_prefix_regex(cls.prefixes).search
suffix_search = util.compile_suffix_regex(cls.suffixes).search
infix_finditer = util.compile_infix_regex(cls.infixes).finditer
vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp)
return Tokenizer(nlp.vocab, rules=rules,
prefix_search=prefix_search, suffix_search=suffix_search,
infix_finditer=infix_finditer)
@classmethod
def create_tagger(cls, nlp=None):
if nlp is None:
return Tagger(cls.create_vocab(), features=cls.tagger_features)
elif nlp.path is None or not (nlp.path / 'ner').exists():
return Tagger(nlp.vocab, features=cls.tagger_features)
else:
return Vocab(lex_attr_getters=lex_attr_getters, tag_map=tag_map,
lemmatizer=lemmatizer, serializer_freqs=serializer_freqs)
return Tagger.load(nlp.path / 'ner', nlp.vocab)
def Tokenizer(self, vocab, rules=None, prefix_search=None, suffix_search=None,
infix_finditer=None):
if rules is None:
rules = self.tokenizer_exceptions
if prefix_search is None:
prefix_search = util.compile_prefix_regex(self.prefixes).search
if suffix_search is None:
suffix_search = util.compile_suffix_regex(self.suffixes).search
if infix_finditer is None:
infix_finditer = util.compile_infix_regex(self.infixes).finditer
if self.path:
return Tokenizer.load(self.path, vocab, rules=rules,
prefix_search=prefix_search,
suffix_search=suffix_search,
infix_finditer=infix_finditer)
@classmethod
def create_parser(cls, nlp=None):
if nlp is None:
return DependencyParser(cls.create_vocab(), features=cls.parser_features)
elif nlp.path is None or not (nlp.path / 'deps').exists():
return DependencyParser(nlp.vocab, features=cls.parser_features)
else:
tokenizer = Tokenizer(vocab, rules=rules,
prefix_search=prefix_search, suffix_search=suffix_search,
infix_finditer=infix_finditer)
return tokenizer
return DependencyParser.load(nlp.path / 'deps', nlp.vocab)
def Tagger(self, vocab, **cfg):
if self.path:
return Tagger.load(self.path / 'pos', vocab)
@classmethod
def create_entity(cls, nlp=None):
if nlp is None:
return EntityRecognizer(cls.create_vocab(), features=cls.entity_features)
elif nlp.path is None or not (nlp.path / 'ner').exists():
return EntityRecognizer(nlp.vocab, features=cls.entity_features)
else:
if 'features' not in cfg:
cfg['features'] = self.parser_features
return Tagger(vocab, **cfg)
return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab)
def Parser(self, vocab, **cfg):
if self.path and (self.path / 'deps').exists():
return DependencyParser.load(self.path / 'deps', vocab)
@classmethod
def create_matcher(cls, nlp=None):
if nlp is None:
return Matcher(cls.create_vocab())
elif nlp.path is None or not (nlp.path / 'vocab').exists():
return Matcher(nlp.vocab)
else:
if 'features' not in cfg:
cfg['features'] = self.parser_features
return DependencyParser(vocab, **cfg)
return Matcher.load(nlp.path / 'vocab', nlp.vocab)
def Entity(self, vocab, **cfg):
if self.path and (self.path / 'ner').exists():
return EntityRecognizer.load(self.path / 'ner', vocab)
else:
if 'features' not in cfg:
cfg['features'] = self.entity_features
return EntityRecognizer(vocab, **cfg)
def Matcher(self, vocab, **cfg):
if self.path:
return Matcher.load(self.path, vocab)
else:
return Matcher(vocab)
def MakeDoc(self, nlp, **cfg):
return lambda text: nlp.tokenizer(text)
def Pipeline(self, nlp, **cfg):
@classmethod
def create_pipeline(self, nlp=None):
pipeline = []
if nlp is None:
return []
if nlp.tagger:
pipeline.append(nlp.tagger)
if nlp.parser:
@ -147,6 +129,8 @@ class BaseDefaults(object):
entity_features = get_templates('ner')
tagger_features = Tagger.feature_templates # TODO -- fix this
stop_words = set()
lex_attr_getters = {
@ -240,78 +224,48 @@ class Language(object):
yield Trainer(self, gold_tuples)
self.end_training()
def __init__(self,
path=True,
vocab=True,
tokenizer=True,
tagger=True,
parser=True,
entity=True,
matcher=True,
serializer=True,
vectors=True,
make_doc=True,
pipeline=True,
defaults=True,
data_dir=None):
"""
A model can be specified:
1) by calling a Language subclass
- spacy.en.English()
2) by calling a Language subclass with data_dir
- spacy.en.English('my/model/root')
- spacy.en.English(data_dir='my/model/root')
3) by package name
- spacy.load('en_default')
- spacy.load('en_default==1.0.0')
4) by package name with a relocated package base
- spacy.load('en_default', via='/my/package/root')
- spacy.load('en_default==1.0.0', via='/my/package/root')
"""
if data_dir is not None and path is None:
warn("'data_dir' argument now named 'path'. Doing what you mean.")
path = data_dir
def __init__(self, path=True, **overrides):
if 'data_dir' in overrides and 'path' not in overrides:
raise ValueError("The argument 'data_dir' has been renamed to 'path'")
path = overrides.get('path', True)
if isinstance(path, basestring):
path = pathlib.Path(path)
if path is True:
path = util.match_best_version(self.lang, '', util.get_data_path())
self.path = path
defaults = defaults if defaults is not True else self.get_defaults(self.path)
self.defaults = defaults
self.vocab = vocab if vocab is not True else defaults.Vocab(vectors=vectors)
self.tokenizer = tokenizer if tokenizer is not True else defaults.Tokenizer(self.vocab)
self.tagger = tagger if tagger is not True else defaults.Tagger(self.vocab)
self.entity = entity if entity is not True else defaults.Entity(self.vocab)
self.parser = parser if parser is not True else defaults.Parser(self.vocab)
self.matcher = matcher if matcher is not True else defaults.Matcher(self.vocab)
self.vocab = self.Defaults.create_vocab(self) \
if 'vocab' not in overrides \
else overrides['vocab']
self.tokenizer = self.Defaults.create_tokenizer(self) \
if 'tokenizer' not in overrides \
else overrides['tokenizer']
self.tagger = self.Defaults.create_tagger(self) \
if 'tagger' not in overrides \
else overrides['tagger']
self.parser = self.Defaults.create_tagger(self) \
if 'parser' not in overrides \
else overrides['parser']
self.entity = self.Defaults.create_entity(self) \
if 'entity' not in overrides \
else overrides['entity']
self.matcher = self.Defaults.create_matcher(self) \
if 'matcher' not in overrides \
else overrides['matcher']
if make_doc in (None, True, False):
self.make_doc = defaults.MakeDoc(self)
if 'make_doc' in overrides:
self.make_doc = overrides['make_doc']
elif 'create_make_doc' in overrides:
self.make_doc = overrides['create_make_doc']
else:
self.make_doc = make_doc
if pipeline in (None, False):
self.pipeline = []
elif pipeline is True:
self.pipeline = defaults.Pipeline(self)
self.make_doc = lambda text: self.tokenizer(text)
if 'pipeline' in overrides:
self.pipeline = overrides['pipeline']
elif 'create_pipeline' in overrides:
self.pipeline = overrides['create_pipeline']
else:
self.pipeline = pipeline(self)
def __reduce__(self):
args = (
self.path,
self.vocab,
self.tokenizer,
self.tagger,
self.parser,
self.entity,
self.matcher
)
return (self.__class__, args, None, None)
self.pipeline = [self.tagger, self.parser, self.matcher, self.entity]
def __call__(self, text, tag=True, parse=True, entity=True):
"""Apply the pipeline to some text. The text can span multiple sentences,