From 9dc8043a7eaa4fb81ccd9ca18bb748509a093099 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 24 Sep 2016 14:08:53 +0200 Subject: [PATCH] Refactor Language to use new Defaults class, and work on revised data loading. We're getting rid of sputnik's weird file-system wrapper, and using pathlib. --- spacy/language.py | 268 ++++++++++++++++++++++------------------------ 1 file changed, 131 insertions(+), 137 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index f88f8e4c7..9ba0caf23 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1,31 +1,108 @@ from __future__ import absolute_import -from os import path from warnings import warn -import io +import pathlib try: import ujson as json except ImportError: import json + from .tokenizer import Tokenizer from .vocab import Vocab from .syntax.parser import Parser from .tagger import Tagger from .matcher import Matcher -from .serialize.packer import Packer from . import attrs from . import orth from .syntax.ner import BiluoPushDown from .syntax.arc_eager import ArcEager -from . import util -from . import about from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD -class Language(object): - lang = None +class Defaults(object): + def __init__(self, lang, path): + self.lang = lang + self.path = path + + def Vectors(self): + pass + + def Vocab(self, vectors=None, get_lex_attr=None): + if get_lex_attr is None: + get_lex_attr = self.lex_attrs() + if vectors is None: + vectors = self.Vectors() + return Vocab.load(self.path, get_lex_attr=get_lex_attr, vectors=vectors) + + def Tokenizer(self, vocab): + return Tokenizer.load(self.path, vocab) + + def Tagger(self, vocab): + return Tagger.load(self.path, self.vocab) + + def Parser(self, vocab): + if (self.path / 'deps').exists(): + return Parser.load(self.path / 'deps', vocab, ArcEager) + else: + return None + + def Entity(self, vocab): + if (self.path / 'ner').exists(): + return Parser.load(self.path / 'ner', vocab, BiluoPushDown) + else: + return None + + def Matcher(self, vocab): + return Matcher.load(self.path, vocab) + + def Pipeline(self, nlp): + return [ + nlp.tokenizer, + nlp.tagger, + nlp.parser, + nlp.entity] + + def dep_labels(self): + return {0: {'ROOT': True}} + + def ner_labels(self): + return {0: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True}} + + def lex_attrs(self, *args, **kwargs): + if 'oov_prob' in kwargs: + oov_prob = kwargs.get('oov_prob', -20) + else: + with (self.path / 'vocab' / 'oov_prob').open() as file_: + oov_prob = file_.read().strip() + return { + attrs.LOWER: self.lower, + attrs.NORM: self.norm, + attrs.SHAPE: orth.word_shape, + attrs.PREFIX: self.prefix, + attrs.SUFFIX: self.suffix, + attrs.CLUSTER: self.cluster, + attrs.PROB: lambda string: oov_prob, + attrs.LANG: lambda string: self.lang, + attrs.IS_ALPHA: orth.is_alpha, + attrs.IS_ASCII: orth.is_ascii, + attrs.IS_DIGIT: self.is_digit, + attrs.IS_LOWER: orth.is_lower, + attrs.IS_PUNCT: orth.is_punct, + attrs.IS_SPACE: self.is_space, + attrs.IS_TITLE: orth.is_title, + attrs.IS_UPPER: orth.is_upper, + attrs.IS_BRACKET: orth.is_bracket, + attrs.IS_QUOTE: orth.is_quote, + attrs.IS_LEFT_PUNCT: orth.is_left_punct, + attrs.IS_RIGHT_PUNCT: orth.is_right_punct, + attrs.LIKE_URL: orth.like_url, + attrs.LIKE_NUM: orth.like_number, + attrs.LIKE_EMAIL: orth.like_email, + attrs.IS_STOP: self.is_stop, + attrs.IS_OOV: lambda string: True + } @staticmethod def lower(string): @@ -59,94 +136,27 @@ class Language(object): def is_stop(string): return 0 - @classmethod - def default_lex_attrs(cls, *args, **kwargs): - oov_prob = kwargs.get('oov_prob', -20) - return { - attrs.LOWER: cls.lower, - attrs.NORM: cls.norm, - attrs.SHAPE: orth.word_shape, - attrs.PREFIX: cls.prefix, - attrs.SUFFIX: cls.suffix, - attrs.CLUSTER: cls.cluster, - attrs.PROB: lambda string: oov_prob, - attrs.LANG: lambda string: cls.lang, - attrs.IS_ALPHA: orth.is_alpha, - attrs.IS_ASCII: orth.is_ascii, - attrs.IS_DIGIT: cls.is_digit, - attrs.IS_LOWER: orth.is_lower, - attrs.IS_PUNCT: orth.is_punct, - attrs.IS_SPACE: cls.is_space, - attrs.IS_TITLE: orth.is_title, - attrs.IS_UPPER: orth.is_upper, - attrs.IS_BRACKET: orth.is_bracket, - attrs.IS_QUOTE: orth.is_quote, - attrs.IS_LEFT_PUNCT: orth.is_left_punct, - attrs.IS_RIGHT_PUNCT: orth.is_right_punct, - attrs.LIKE_URL: orth.like_url, - attrs.LIKE_NUM: orth.like_number, - attrs.LIKE_EMAIL: orth.like_email, - attrs.IS_STOP: cls.is_stop, - attrs.IS_OOV: lambda string: True - } - @classmethod - def default_dep_labels(cls): - return {0: {'ROOT': True}} - - @classmethod - def default_ner_labels(cls): - return {0: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True}} - - @classmethod - def default_vocab(cls, package, get_lex_attr=None, vectors_package=None): - if get_lex_attr is None: - if package.has_file('vocab', 'oov_prob'): - with package.open(('vocab', 'oov_prob')) as file_: - oov_prob = float(file_.read().strip()) - get_lex_attr = cls.default_lex_attrs(oov_prob=oov_prob) - else: - get_lex_attr = cls.default_lex_attrs() - if hasattr(package, 'dir_path'): - return Vocab.from_package(package, get_lex_attr=get_lex_attr, - vectors_package=vectors_package) - else: - return Vocab.load(package, get_lex_attr) - - @classmethod - def default_parser(cls, package, vocab): - if hasattr(package, 'dir_path'): - data_dir = package.dir_path('deps') - else: - data_dir = package - if data_dir and path.exists(data_dir): - return Parser.from_dir(data_dir, vocab.strings, ArcEager) - else: - return None - - @classmethod - def default_entity(cls, package, vocab): - if hasattr(package, 'dir_path'): - data_dir = package.dir_path('ner') - else: - data_dir = package - if data_dir and path.exists(data_dir): - return Parser.from_dir(data_dir, vocab.strings, BiluoPushDown) - else: - return None +class Language(object): + '''A text-processing pipeline. Usually you'll load this once per process, and + pass the instance around your program. + ''' + + lang = None def __init__(self, - data_dir=None, - vocab=None, - tokenizer=None, - tagger=None, - parser=None, - entity=None, - matcher=None, - serializer=None, - load_vectors=True, - package=None, - vectors_package=None): + path=None, + vocab=True, + tokenizer=True, + tagger=True, + parser=True, + entity=True, + matcher=True, + serializer=True, + vectors=True, + pipeline=True, + defaults=True, + data_dir=None): """ A model can be specified: @@ -165,44 +175,24 @@ class Language(object): - spacy.load('en_default', via='/my/package/root') - spacy.load('en_default==1.0.0', via='/my/package/root') """ - if package is None: - if data_dir is None: - package = util.get_package_by_name(about.__models__[self.lang]) - else: - package = util.get_package(data_dir) - - if load_vectors is not True: - warn("load_vectors is deprecated", DeprecationWarning) - - if vocab in (None, True): - vocab = self.default_vocab(package, vectors_package=vectors_package) - self.vocab = vocab - if tokenizer in (None, True): - tokenizer = Tokenizer.from_package(package, self.vocab) - self.tokenizer = tokenizer - if tagger in (None, True): - tagger = Tagger.from_package(package, self.vocab) - self.tagger = tagger - if entity in (None, True): - entity = self.default_entity(package, self.vocab) - self.entity = entity - if parser in (None, True): - parser = self.default_parser(package, self.vocab) - self.parser = parser - if matcher in (None, True): - matcher = Matcher.from_package(package, self.vocab) - self.matcher = matcher - self.pipeline = [ - self.tokenizer, - self.tagger, - self.entity, - self.parser, - self.matcher - ] + if data_dir is not None and path is None: + warn("'data_dir' argument now named 'path'. Doing what you mean.") + path = data_dir + if isinstance(path, basestring): + path = pathlib.Path(path) + defaults = defaults if defaults is not True else self.get_defaults(self.path) + + self.vocab = vocab if vocab is not True else defaults.Vocab(vectors=vectors) + self.tokenizer = tokenizer if tokenizer is not True else defaults.Tokenizer(self.vocab) + self.tagger = tagger if tagger is not True else defaults.Tagger(self.vocab) + self.entity = entity if entity is not True else defaults.Entity(self.vocab) + self.parser = parser if parser is not True else defaults.Parser(self.vocab) + self.matcher = matcher if matcher is not True else defaults.Matcher(self.vocab) + self.pipeline = self.pipeline if pipeline is not True else defaults.Pipeline(self) def __reduce__(self): args = ( - None, # data_dir + self.path, self.vocab, self.tokenizer, self.tagger, @@ -255,23 +245,23 @@ class Language(object): for doc in stream: yield doc - def end_training(self, data_dir=None): - if data_dir is None: - data_dir = self.data_dir + def end_training(self, path=None): + if path is None: + path = self.path if self.parser: self.parser.model.end_training() - self.parser.model.dump(path.join(data_dir, 'deps', 'model')) + self.parser.model.dump(path / 'deps' / 'model') if self.entity: self.entity.model.end_training() - self.entity.model.dump(path.join(data_dir, 'ner', 'model')) + self.entity.model.dump(path / 'ner' / 'model') if self.tagger: self.tagger.model.end_training() - self.tagger.model.dump(path.join(data_dir, 'pos', 'model')) + self.tagger.model.dump(path / 'pos' / 'model') - strings_loc = path.join(data_dir, 'vocab', 'strings.json') - with io.open(strings_loc, 'w', encoding='utf8') as file_: + strings_loc = path / 'vocab' / 'strings.json' + with strings_loc.open('w', encoding='utf8') as file_: self.vocab.strings.dump(file_) - self.vocab.dump(path.join(data_dir, 'vocab', 'lexemes.bin')) + self.vocab.dump(path / 'vocab' / 'lexemes.bin') if self.tagger: tagger_freqs = list(self.tagger.freqs[TAG].items()) @@ -289,7 +279,7 @@ class Language(object): else: entity_iob_freqs = [] entity_type_freqs = [] - with open(path.join(data_dir, 'vocab', 'serializer.json'), 'w') as file_: + with (path / 'vocab' / 'serializer.json').open('w') as file_: file_.write( json.dumps([ (TAG, tagger_freqs), @@ -298,3 +288,7 @@ class Language(object): (ENT_TYPE, entity_type_freqs), (HEAD, head_freqs) ])) + + + def get_defaults(self, path): + return Defaults(path)