Add spacy.blank() method that doesn't load data. Don't try to load data if path is falsey

Matthew Honnibal 2016-09-26 11:07:46 +02:00
parent ae202e7a60
commit 722199acb8
2 changed files with 49 additions and 14 deletions

spacy/__init__.py

@@ -19,6 +19,23 @@
 set_lang_class(de.German.lang, de.German)
 set_lang_class(zh.Chinese.lang, zh.Chinese)
+
+
+def blank(name, vocab=None, tokenizer=None, parser=None, tagger=None, entity=None,
+          matcher=None, serializer=None, vectors=None, pipeline=None):
+    target_name, target_version = util.split_data_name(name)
+    cls = get_lang_class(target_name)
+    return cls(
+        path=None,
+        vectors=vectors,
+        vocab=vocab,
+        tokenizer=tokenizer,
+        tagger=tagger,
+        parser=parser,
+        entity=entity,
+        matcher=matcher,
+        pipeline=pipeline,
+        serializer=serializer)

 def load(name, vocab=True, tokenizer=True, parser=True, tagger=True, entity=True,
          matcher=True, serializer=True, vectors=True, pipeline=True, via=None):
     if via is None:
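For context, a minimal usage sketch of the new entry point. This is hedged: the diff only defines blank() itself, and the 'en' shortcut plus the keyword arguments are read off the signature above rather than documented behaviour.

    import spacy

    # Build an English pipeline without loading any model data from disk.
    # Each component factory in language.py (below) falls back to its
    # blank/empty construction path because no data path is passed through.
    nlp = spacy.blank('en')

    # Contrast with spacy.load('en'), which resolves a data directory and
    # loads the vocab, tagger, parser and entity recognizer from it.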

spacy/language.py

@@ -36,7 +36,7 @@ class BaseDefaults(object):
         self.path = path
         self.lang = lang
         self.lex_attr_getters = dict(self.__class__.lex_attr_getters)
-        if (self.path / 'vocab' / 'oov_prob').exists():
+        if self.path and (self.path / 'vocab' / 'oov_prob').exists():
             with (self.path / 'vocab' / 'oov_prob').open() as file_:
                 oov_prob = file_.read().strip()
             self.lex_attr_getters[PROB] = lambda string: oov_prob
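The added guard matters because self.path is a pathlib.Path when model data is available and a falsey value such as None otherwise; applying the / operator to None raises a TypeError before .exists() can even run. A standalone illustration of the short-circuit:

    from pathlib import Path

    def read_oov_prob(path):
        # `and` short-circuits, so the Path arithmetic is never attempted
        # when path is None (or any other falsey value).
        if path and (path / 'vocab' / 'oov_prob').exists():
            with (path / 'vocab' / 'oov_prob').open() as file_:
                return file_.read().strip()
        return None

    print(read_oov_prob(None))               # None; no TypeError raised
    print(read_oov_prob(Path('/no/model')))  # None; directory doesn't exist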
@@ -44,7 +44,7 @@ class BaseDefaults(object):
             self.lex_attr_getters[IS_STOP] = lambda string: string in self.stop_words

     def Lemmatizer(self):
-        return Lemmatizer.load(self.path)
+        return Lemmatizer.load(self.path) if self.path else Lemmatizer({}, {}, {})

     def Vectors(self):
         return True
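On the fallback branch: assuming the Lemmatizer signature of the period, Lemmatizer(index, exceptions, rules), the three empty dicts produce a lemmatizer with no lookup tables at all, so a blank pipeline lemmatizes tokens to (roughly) their surface forms:

    # Hedged sketch; the positional meaning of the three dicts is assumed.
    from spacy.lemmatizer import Lemmatizer

    lemmatizer = Lemmatizer({}, {}, {})  # no index, no exceptions, no rules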
@@ -59,9 +59,13 @@ class BaseDefaults(object):
         lemmatizer = self.Lemmatizer()
         if vectors is True:
             vectors = self.Vectors()
-        return Vocab.load(self.path, lex_attr_getters=lex_attr_getters,
-                          tag_map=tag_map, lemmatizer=lemmatizer,
-                          serializer_freqs=serializer_freqs)
+        if self.path:
+            return Vocab.load(self.path, lex_attr_getters=lex_attr_getters,
+                              tag_map=tag_map, lemmatizer=lemmatizer,
+                              serializer_freqs=serializer_freqs)
+        else:
+            return Vocab(lex_attr_getters=lex_attr_getters, tag_map=tag_map,
+                         lemmatizer=lemmatizer, serializer_freqs=serializer_freqs)

     def Tokenizer(self, vocab, rules=None, prefix_search=None, suffix_search=None,
                   infix_finditer=None):
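Every factory in the hunk below (Tokenizer, Tagger, Parser, Entity, Matcher) repeats the same load-or-blank shape. A generic sketch of the pattern, with hypothetical stand-in names rather than spaCy API:

    def make_component(defaults, vocab, Component):
        # defaults.path is a Path when model data exists, falsey otherwise.
        if defaults.path:
            return Component.load(defaults.path, vocab)  # trained, from disk
        else:
            return Component.blank(vocab)                # empty, data-free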
@@ -73,27 +77,41 @@ class BaseDefaults(object):
             suffix_search = util.compile_suffix_regex(self.suffixes).search
         if infix_finditer is None:
             infix_finditer = util.compile_infix_regex(self.infixes).finditer
-        return Tokenizer(vocab, rules=rules,
-                         prefix_search=prefix_search, suffix_search=suffix_search,
-                         infix_finditer=infix_finditer)
+        if self.path:
+            return Tokenizer.load(self.path, vocab, rules=rules,
+                                  prefix_search=prefix_search,
+                                  suffix_search=suffix_search,
+                                  infix_finditer=infix_finditer)
+        else:
+            return Tokenizer(vocab, rules=rules,
+                             prefix_search=prefix_search, suffix_search=suffix_search,
+                             infix_finditer=infix_finditer)

     def Tagger(self, vocab):
-        return Tagger.load(self.path / 'pos', vocab)
+        if self.path:
+            return Tagger.load(self.path / 'pos', vocab)
+        else:
+            return Tagger.blank(vocab, Tagger.default_templates(self.lang))

     def Parser(self, vocab):
-        if (self.path / 'deps').exists():
+        if self.path:
             return Parser.load(self.path / 'deps', vocab, ArcEager)
         else:
-            return None
+            return Parser.blank(vocab, ArcEager,
+                                Parser.default_templates('%s-parser' % self.lang))

     def Entity(self, vocab):
-        if (self.path / 'ner').exists():
+        if self.path and (self.path / 'ner').exists():
             return Parser.load(self.path / 'ner', vocab, BiluoPushDown)
         else:
-            return None
+            return Parser.blank(vocab, BiluoPushDown,
+                                Parser.default_templates('%s-entity' % self.lang))

     def Matcher(self, vocab):
-        return Matcher.load(self.path, vocab)
+        if self.path:
+            return Matcher.load(self.path, vocab)
+        else:
+            return Matcher(vocab)

     def Pipeline(self, nlp):
         return [