From 7d5212f13117f12651b0285d5da2a0d2e5f5f2ba Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 18 Oct 2016 16:18:25 +0200 Subject: [PATCH] Refactor defaults --- spacy/language.py | 240 +++++++++++++++++++--------------------------- 1 file changed, 97 insertions(+), 143 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 681b26da8..1859a36d0 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -34,97 +34,79 @@ from .pipeline import DependencyParser, EntityRecognizer class BaseDefaults(object): - def __init__(self, lang, path): - self.path = path - self.lang = lang - self.lex_attr_getters = dict(self.__class__.lex_attr_getters) - if self.path and (self.path / 'vocab' / 'oov_prob').exists(): - with (self.path / 'vocab' / 'oov_prob').open() as file_: - oov_prob = file_.read().strip() - self.lex_attr_getters[PROB] = lambda string: oov_prob - self.lex_attr_getters[LANG] = lambda string: lang - self.lex_attr_getters[IS_STOP] = lambda string: string in self.stop_words + @classmethod + def create_lemmatizer(cls, nlp=None): + if nlp is None or nlp.path is None: + return Lemmatizer({}, {}, {}) + else: + return Lemmatizer.load(nlp.path) - def Lemmatizer(self): - return Lemmatizer.load(self.path) if self.path else Lemmatizer({}, {}, {}) - - def Vectors(self): + @classmethod + def create_vocab(cls, nlp=None): + lemmatizer = cls.create_lemmatizer(nlp) + if nlp is None or nlp.path is None: + return Vocab(lex_attr_getters=cls.lex_attr_getters, tag_map=cls.tag_map, + lemmatizer=lemmatizer) + else: + return Vocab.load(nlp.path, lex_attr_getters=cls.lex_attr_getters, + tag_map=cls.tag_map, lemmatizer=lemmatizer) + + @classmethod + def add_vectors(cls, nlp=None): return True - def Vocab(self, lex_attr_getters=True, tag_map=True, - lemmatizer=True, serializer_freqs=True, vectors=True): - if lex_attr_getters is True: - lex_attr_getters = self.lex_attr_getters - if tag_map is True: - tag_map = self.tag_map - if lemmatizer is True: - lemmatizer = self.Lemmatizer() - if vectors is True: - vectors = self.Vectors() - if self.path: - return Vocab.load(self.path, lex_attr_getters=lex_attr_getters, - tag_map=tag_map, lemmatizer=lemmatizer, - serializer_freqs=serializer_freqs) + @classmethod + def create_tokenizer(cls, nlp=None): + rules = cls.tokenizer_exceptions + prefix_search = util.compile_prefix_regex(cls.prefixes).search + suffix_search = util.compile_suffix_regex(cls.suffixes).search + infix_finditer = util.compile_infix_regex(cls.infixes).finditer + vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp) + return Tokenizer(nlp.vocab, rules=rules, + prefix_search=prefix_search, suffix_search=suffix_search, + infix_finditer=infix_finditer) + + @classmethod + def create_tagger(cls, nlp=None): + if nlp is None: + return Tagger(cls.create_vocab(), features=cls.tagger_features) + elif nlp.path is None or not (nlp.path / 'ner').exists(): + return Tagger(nlp.vocab, features=cls.tagger_features) else: - return Vocab(lex_attr_getters=lex_attr_getters, tag_map=tag_map, - lemmatizer=lemmatizer, serializer_freqs=serializer_freqs) + return Tagger.load(nlp.path / 'ner', nlp.vocab) - def Tokenizer(self, vocab, rules=None, prefix_search=None, suffix_search=None, - infix_finditer=None): - if rules is None: - rules = self.tokenizer_exceptions - if prefix_search is None: - prefix_search = util.compile_prefix_regex(self.prefixes).search - if suffix_search is None: - suffix_search = util.compile_suffix_regex(self.suffixes).search - if infix_finditer is None: - infix_finditer = util.compile_infix_regex(self.infixes).finditer - if self.path: - return Tokenizer.load(self.path, vocab, rules=rules, - prefix_search=prefix_search, - suffix_search=suffix_search, - infix_finditer=infix_finditer) + @classmethod + def create_parser(cls, nlp=None): + if nlp is None: + return DependencyParser(cls.create_vocab(), features=cls.parser_features) + elif nlp.path is None or not (nlp.path / 'deps').exists(): + return DependencyParser(nlp.vocab, features=cls.parser_features) else: - tokenizer = Tokenizer(vocab, rules=rules, - prefix_search=prefix_search, suffix_search=suffix_search, - infix_finditer=infix_finditer) - return tokenizer + return DependencyParser.load(nlp.path / 'deps', nlp.vocab) - def Tagger(self, vocab, **cfg): - if self.path: - return Tagger.load(self.path / 'pos', vocab) + @classmethod + def create_entity(cls, nlp=None): + if nlp is None: + return EntityRecognizer(cls.create_vocab(), features=cls.entity_features) + elif nlp.path is None or not (nlp.path / 'ner').exists(): + return EntityRecognizer(nlp.vocab, features=cls.entity_features) else: - if 'features' not in cfg: - cfg['features'] = self.parser_features - return Tagger(vocab, **cfg) + return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab) - def Parser(self, vocab, **cfg): - if self.path and (self.path / 'deps').exists(): - return DependencyParser.load(self.path / 'deps', vocab) + @classmethod + def create_matcher(cls, nlp=None): + if nlp is None: + return Matcher(cls.create_vocab()) + elif nlp.path is None or not (nlp.path / 'vocab').exists(): + return Matcher(nlp.vocab) else: - if 'features' not in cfg: - cfg['features'] = self.parser_features - return DependencyParser(vocab, **cfg) + return Matcher.load(nlp.path / 'vocab', nlp.vocab) - def Entity(self, vocab, **cfg): - if self.path and (self.path / 'ner').exists(): - return EntityRecognizer.load(self.path / 'ner', vocab) - else: - if 'features' not in cfg: - cfg['features'] = self.entity_features - return EntityRecognizer(vocab, **cfg) - - def Matcher(self, vocab, **cfg): - if self.path: - return Matcher.load(self.path, vocab) - else: - return Matcher(vocab) - - def MakeDoc(self, nlp, **cfg): - return lambda text: nlp.tokenizer(text) - - def Pipeline(self, nlp, **cfg): + @classmethod + def create_pipeline(self, nlp=None): pipeline = [] + if nlp is None: + return [] if nlp.tagger: pipeline.append(nlp.tagger) if nlp.parser: @@ -147,6 +129,8 @@ class BaseDefaults(object): entity_features = get_templates('ner') + tagger_features = Tagger.feature_templates # TODO -- fix this + stop_words = set() lex_attr_getters = { @@ -240,78 +224,48 @@ class Language(object): yield Trainer(self, gold_tuples) self.end_training() - def __init__(self, - path=True, - vocab=True, - tokenizer=True, - tagger=True, - parser=True, - entity=True, - matcher=True, - serializer=True, - vectors=True, - make_doc=True, - pipeline=True, - defaults=True, - data_dir=None): - """ - A model can be specified: - - 1) by calling a Language subclass - - spacy.en.English() - - 2) by calling a Language subclass with data_dir - - spacy.en.English('my/model/root') - - spacy.en.English(data_dir='my/model/root') - - 3) by package name - - spacy.load('en_default') - - spacy.load('en_default==1.0.0') - - 4) by package name with a relocated package base - - spacy.load('en_default', via='/my/package/root') - - spacy.load('en_default==1.0.0', via='/my/package/root') - """ - if data_dir is not None and path is None: - warn("'data_dir' argument now named 'path'. Doing what you mean.") - path = data_dir + def __init__(self, path=True, **overrides): + if 'data_dir' in overrides and 'path' not in overrides: + raise ValueError("The argument 'data_dir' has been renamed to 'path'") + path = overrides.get('path', True) if isinstance(path, basestring): path = pathlib.Path(path) if path is True: path = util.match_best_version(self.lang, '', util.get_data_path()) + self.path = path - defaults = defaults if defaults is not True else self.get_defaults(self.path) - - self.defaults = defaults - self.vocab = vocab if vocab is not True else defaults.Vocab(vectors=vectors) - self.tokenizer = tokenizer if tokenizer is not True else defaults.Tokenizer(self.vocab) - self.tagger = tagger if tagger is not True else defaults.Tagger(self.vocab) - self.entity = entity if entity is not True else defaults.Entity(self.vocab) - self.parser = parser if parser is not True else defaults.Parser(self.vocab) - self.matcher = matcher if matcher is not True else defaults.Matcher(self.vocab) + + self.vocab = self.Defaults.create_vocab(self) \ + if 'vocab' not in overrides \ + else overrides['vocab'] + self.tokenizer = self.Defaults.create_tokenizer(self) \ + if 'tokenizer' not in overrides \ + else overrides['tokenizer'] + self.tagger = self.Defaults.create_tagger(self) \ + if 'tagger' not in overrides \ + else overrides['tagger'] + self.parser = self.Defaults.create_tagger(self) \ + if 'parser' not in overrides \ + else overrides['parser'] + self.entity = self.Defaults.create_entity(self) \ + if 'entity' not in overrides \ + else overrides['entity'] + self.matcher = self.Defaults.create_matcher(self) \ + if 'matcher' not in overrides \ + else overrides['matcher'] - if make_doc in (None, True, False): - self.make_doc = defaults.MakeDoc(self) + if 'make_doc' in overrides: + self.make_doc = overrides['make_doc'] + elif 'create_make_doc' in overrides: + self.make_doc = overrides['create_make_doc'] else: - self.make_doc = make_doc - if pipeline in (None, False): - self.pipeline = [] - elif pipeline is True: - self.pipeline = defaults.Pipeline(self) + self.make_doc = lambda text: self.tokenizer(text) + if 'pipeline' in overrides: + self.pipeline = overrides['pipeline'] + elif 'create_pipeline' in overrides: + self.pipeline = overrides['create_pipeline'] else: - self.pipeline = pipeline(self) - - def __reduce__(self): - args = ( - self.path, - self.vocab, - self.tokenizer, - self.tagger, - self.parser, - self.entity, - self.matcher - ) - return (self.__class__, args, None, None) + self.pipeline = [self.tagger, self.parser, self.matcher, self.entity] def __call__(self, text, tag=True, parse=True, entity=True): """Apply the pipeline to some text. The text can span multiple sentences,