Finish refactoring data loading

Matthew Honnibal 2016-09-24 20:26:17 +02:00
parent 83e364188c
commit fd65cf6cbb
17 changed files with 220 additions and 141 deletions

View File

@@ -7,31 +7,26 @@ from . import de
from . import zh
_data_path = pathlib.Path(__file__).parent / 'data'
set_lang_class(en.English.lang, en.English)
set_lang_class(de.German.lang, de.German)
set_lang_class(zh.Chinese.lang, zh.Chinese)
def get_data_path():
return _data_path
def set_data_path(path):
global _data_path
if isinstance(path, basestring):
path = pathlib.Path(path)
_data_path = path
def load(name, vocab=None, tokenizer=None, parser=None, tagger=None, entity=None,
matcher=None, serializer=None, vectors=None, via=None):
def load(name, vocab=True, tokenizer=True, parser=True, tagger=True, entity=True,
matcher=True, serializer=True, vectors=True, via=None):
if via is None:
via = get_data_path()
cls = get_lang_class(name)
via = util.get_data_path()
target_name, target_version = util.split_data_name(name)
path = util.match_best_version(target_name, target_version, via)
if isinstance(vectors, basestring):
vectors_name, vectors_version = util.split_data_name(vectors)
vectors = util.match_best_version(vectors_name, vectors_version, via)
cls = get_lang_class(target_name)
return cls(
via,
path,
vectors=vectors,
vocab=vocab,
tokenizer=tokenizer,

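The refactored top-level loader resolves a possibly versioned data name (split_data_name, match_best_version) against the data path before instantiating the language class, and the pipeline components now default to True rather than None. A minimal usage sketch, assuming an en-1.1.0 directory exists under the data path (directory name and location are illustrative):

import spacy

nlp = spacy.load('en')                          # picks the best-matching en-* data directory
nlp = spacy.load('en-1.1.0', parser=False)      # pin a version; hypothetically skip the parser
nlp = spacy.load('en', via='/opt/spacy-data')   # resolve against a custom data location
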
View File

@@ -2,15 +2,19 @@ from libc.stdio cimport fopen, fclose, fread, fwrite, FILE
cdef class CFile:
def __init__(self, loc, mode):
def __init__(self, loc, mode, on_open_error=None):
if isinstance(mode, unicode):
mode_str = mode.encode('ascii')
else:
mode_str = mode
loc = str(loc)
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
self.fp = fopen(<char*>bytes_loc, mode_str)
if self.fp == NULL:
raise IOError("Could not open binary file %s" % bytes_loc)
if on_open_error is not None:
on_open_error()
else:
raise IOError("Could not open binary file %s" % bytes_loc)
self.is_open = True
def __dealloc__(self):

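CFile gains an on_open_error callback that is invoked instead of raising when fopen() fails. A sketch of the call pattern (the module path and file location are assumptions for illustration):

from spacy.cfile import CFile    # assumed module path

def warn_missing():
    print("lexemes.bin not found; continuing without binary lexemes")

cfile = CFile('/opt/spacy-data/en-1.1.0/vocab/lexemes.bin', 'rb',
              on_open_error=warn_missing)
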
View File

@@ -3,15 +3,22 @@ from __future__ import unicode_literals, print_function
from os import path
from ..language import Language
from ..vocab import Vocab
from ..attrs import LANG
class German(Language):
lang = 'de'
class Defaults(Language.Defaults):
def Vocab(self, vectors=None, lex_attr_getters=None):
if lex_attr_getters is None:
lex_attr_getters = dict(self.lex_attr_getters)
if vectors is None:
vectors = self.Vectors()
# set a dummy lemmatizer for now that simply returns the same string
# until the morphology is done for German
return Vocab.load(self.path, get_lex_attr=lex_attr_getters, vectors=vectors,
lemmatizer=False)
@classmethod
def default_vocab(cls, package, get_lex_attr=None, vectors_package=None):
vocab = super(German,cls).default_vocab(package,get_lex_attr,vectors_package)
# set a dummy lemmatizer for now that simply returns the same string
# until the morphology is done for German
vocab.morphology.lemmatizer = lambda string,pos: set([string])
return vocab
stop_words = set()

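German's Defaults now passes lemmatizer=False to Vocab.load, which (per the Vocab changes at the end of this diff) installs a dummy lemmatizer that returns the input string unchanged. The effective behaviour, as a short sketch:

# Dummy lemmatizer used while German morphology is unfinished:
lemmatize = lambda string, pos: set((string,))
assert lemmatize('Häuser', 'NOUN') == {'Häuser'}    # pos is ignored
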
View File

@@ -43,23 +43,6 @@ def read_lang_data(package):
return tokenization, prefix, suffix, infix
def read_prefix(fileobj):
entries = fileobj.read().split('\n')
expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
return expression
def read_suffix(fileobj):
entries = fileobj.read().split('\n')
expression = '|'.join([piece + '$' for piece in entries if piece.strip()])
return expression
def read_infix(fileobj):
entries = fileobj.read().split('\n')
expression = '|'.join([piece for piece in entries if piece.strip()])
return expression
def align_tokens(ref, indices): # Deprecated, surely?
start = 0

View File

@@ -5,37 +5,35 @@ from os import path
from ..language import Language
# improved list from Stone, Denis, Kwantes (2010)
STOPWORDS = """
a about above across after afterwards again against all almost alone along already also although always am among amongst amoungst amount an and another any anyhow anyone anything anyway anywhere are around as at back be
became because become becomes becoming been before beforehand behind being below beside besides between beyond bill both bottom but by call can
cannot cant co computer con could couldnt cry de describe
detail did didn do does doesn doing don done down due during
each eg eight either eleven else elsewhere empty enough etc even ever every everyone everything everywhere except few fifteen
fify fill find fire first five for former formerly forty found four from front full further get give go
had has hasnt have he hence her here hereafter hereby herein hereupon hers herself him himself his how however hundred i ie
if in inc indeed interest into is it its itself keep last latter latterly least less ltd
just
kg km
made make many may me meanwhile might mill mine more moreover most mostly move much must my myself name namely
neither never nevertheless next nine no nobody none noone nor not nothing now nowhere of off
often on once one only onto or other others otherwise our ours ourselves out over own part per
perhaps please put rather re
quite
rather really regarding
same say see seem seemed seeming seems serious several she should show side since sincere six sixty so some somehow someone something sometime sometimes somewhere still such system take ten
than that the their them themselves then thence there thereafter thereby therefore therein thereupon these they thick thin third this those though three through throughout thru thus to together too top toward towards twelve twenty two un under
until up unless upon us used using
various very very via
was we well were what whatever when whence whenever where whereafter whereas whereby wherein whereupon wherever whether which while whither who whoever whole whom whose why will with within without would yet you
your yours yourself yourselves
"""
STOPWORDS = set(w for w in STOPWORDS.split() if w)
class English(Language):
lang = 'en'
@staticmethod
def is_stop(string):
return 1 if string.lower() in STOPWORDS else 0
class Defaults(Language.Defaults):
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
# improved list from Stone, Denis, Kwantes (2010)
stop_words = set("""
a about above across after afterwards again against all almost alone along already also although always am among amongst amoungst amount an and another any anyhow anyone anything anyway anywhere are around as at back be
became because become becomes becoming been before beforehand behind being below beside besides between beyond bill both bottom but by call can
cannot cant co computer con could couldnt cry de describe
detail did didn do does doesn doing don done down due during
each eg eight either eleven else elsewhere empty enough etc even ever every everyone everything everywhere except few fifteen
fify fill find fire first five for former formerly forty found four from front full further get give go
had has hasnt have he hence her here hereafter hereby herein hereupon hers herself him himself his how however hundred i ie
if in inc indeed interest into is it its itself keep last latter latterly least less ltd
just
kg km
made make many may me meanwhile might mill mine more moreover most mostly move much must my myself name namely
neither never nevertheless next nine no nobody none noone nor not nothing now nowhere of off
often on once one only onto or other others otherwise our ours ourselves out over own part per
perhaps please put rather re
quite
rather really regarding
same say see seem seemed seeming seems serious several she should show side since sincere six sixty so some somehow someone something sometime sometimes somewhere still such system take ten
than that the their them themselves then thence there thereafter thereby therefore therein thereupon these they thick thin third this those though three through throughout thru thus to together too top toward towards twelve twenty two un under
until up unless upon us used using
various very very via
was we well were what whatever when whence whenever where whereafter whereas whereby wherein whereupon wherever whether which while whither who whoever whole whom whose why will with within without would yet you
your yours yourself yourselves
""".split())

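The stop list moves from a module-level STOPWORDS constant plus a static is_stop() method onto English.Defaults.stop_words; BaseDefaults (next file) wires it into the lexical attribute getters via IS_STOP. The shape of that wiring, in a standalone sketch with illustrative names:

# Illustrative pattern only, not spaCy's actual classes or attribute keys.
class Defaults:
    stop_words = {'the', 'a', 'of'}

    def __init__(self):
        self.lex_attr_getters = {}
        self.lex_attr_getters['IS_STOP'] = lambda string: string in self.stop_words

assert Defaults().lex_attr_getters['IS_STOP']('the')
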
View File

@@ -1,4 +1,5 @@
from __future__ import absolute_import
from __future__ import unicode_literals
from warnings import warn
import pathlib
@@ -17,36 +18,38 @@ from . import attrs
from . import orth
from .syntax.ner import BiluoPushDown
from .syntax.arc_eager import ArcEager
from . import util
from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD
from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP
class Defaults(object):
class BaseDefaults(object):
def __init__(self, lang, path):
self.lang = lang
self.path = path
self.lang = lang
self.lex_attr_getters = dict(self.__class__.lex_attr_getters)
if (self.path / 'vocab' / 'oov_prob').exists():
with (self.path / 'vocab' / 'oov_prob').open() as file_:
oov_prob = file_.read().strip()
self.lex_attr_getters['PROB'] = lambda string: oov_prob
self.lex_attr_getters['LANG'] = lambda string: self.lang,
self.lex_attr_getters[PROB] = lambda string: oov_prob
self.lex_attr_getters[LANG] = lambda string: lang
self.lex_attr_getters[IS_STOP] = lambda string: string in self.stop_words
def Vectors(self):
pass
return True
def Vocab(self, vectors=None, lex_attr_getters=None):
if lex_attr_getters is None:
lex_attr_getters = dict(self.lex_attr_getters)
if vectors is None:
vectors = self.Vectors()
return Vocab.load(self.path, get_lex_attr=get_lex_attr, vectors=vectors)
return Vocab.load(self.path, get_lex_attr=self.lex_attr_getters, vectors=vectors)
def Tokenizer(self, vocab):
return Tokenizer.load(self.path, vocab)
def Tagger(self, vocab):
return Tagger.load(self.path, self.vocab)
return Tagger.load(self.path / 'pos', vocab)
def Parser(self, vocab):
if (self.path / 'deps').exists():
@@ -74,6 +77,9 @@ class Defaults(object):
ner_labels = {0: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True}}
stop_words = set()
lex_attr_getters = {
attrs.LOWER: lambda string: string.lower(),
attrs.NORM: lambda string: string,
@@ -101,11 +107,12 @@ class Defaults(object):
}
class Language(object):
'''A text-processing pipeline. Usually you'll load this once per process, and
pass the instance around your program.
'''
Defaults = Defaults
Defaults = BaseDefaults
lang = None
def __init__(self,
@@ -144,6 +151,8 @@ class Language(object):
path = data_dir
if isinstance(path, basestring):
path = pathlib.Path(path)
if path is None:
path = util.match_best_version(self.lang, '', util.get_data_path())
self.path = path
defaults = defaults if defaults is not True else self.get_defaults(self.path)
@@ -256,4 +265,4 @@ class Language(object):
def get_defaults(self, path):
return Defaults(self.lang, path)
return self.Defaults(self.lang, path)

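Defaults becomes BaseDefaults and is attached to the class as Language.Defaults, so each language overrides it as a nested class and get_defaults() picks up the subclass automatically. A standalone sketch of the pattern:

# Illustration of the Defaults-as-nested-class pattern, not the real classes.
class BaseDefaults(object):
    stop_words = set()

    def __init__(self, lang, path):
        self.lang = lang
        self.path = path

class Language(object):
    Defaults = BaseDefaults
    lang = None

    def get_defaults(self, path):
        return self.Defaults(self.lang, path)

class English(Language):
    lang = 'en'

    class Defaults(BaseDefaults):
        stop_words = {'the', 'a'}

assert English().get_defaults('/data/en').stop_words == {'the', 'a'}
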
View File

@@ -1,6 +1,6 @@
from __future__ import unicode_literals, print_function
from os import path
import codecs
import pathlib
try:
import ujson as json
@@ -12,19 +12,24 @@ from .parts_of_speech import NOUN, VERB, ADJ, PUNCT
class Lemmatizer(object):
@classmethod
def load(cls, via):
return cls.from_package(get_package(via))
@classmethod
def from_package(cls, pkg):
def load(cls, path):
index = {}
exc = {}
for pos in ['adj', 'noun', 'verb']:
with pkg.open(('wordnet', 'index.%s' % pos), default=None) as file_:
index[pos] = read_index(file_) if file_ is not None else set()
with pkg.open(('wordnet', '%s.exc' % pos), default=None) as file_:
exc[pos] = read_exc(file_) if file_ is not None else {}
rules = pkg.load_json(('vocab', 'lemma_rules.json'), default={})
pos_index_path = path / 'wordnet' / 'index.{pos}'.format(pos=pos)
if pos_index_path.exists():
with pos_index_path.open() as file_:
index[pos] = read_index(file_)
else:
index[pos] = set()
pos_exc_path = path / 'wordnet' / '{pos}.exc'.format(pos=pos)
if pos_exc_path.exists():
with pos_exc_path.open() as file_:
exc[pos] = read_exc(file_)
else:
exc[pos] = {}
with (path / 'vocab' / 'lemma_rules.json').open() as file_:
rules = json.load(file_)
return cls(index, exc, rules)
def __init__(self, index, exceptions, rules):

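Lemmatizer.load() now takes a pathlib path instead of a Sputnik package, reading wordnet/index.<pos>, wordnet/<pos>.exc and vocab/lemma_rules.json beneath it. A hedged usage sketch (the data directory is an assumption):

from pathlib import Path
from spacy.lemmatizer import Lemmatizer

lemmatizer = Lemmatizer.load(Path('/opt/spacy-data/en-1.1.0'))
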
View File

@@ -197,9 +197,11 @@ cdef class Matcher:
@classmethod
def load(cls, path, vocab):
if (path / 'patterns.json').exists():
with (path / 'patterns.json').open() as file_:
if (path / 'gazetteer.json').exists():
with (path / 'gazetteer.json').open() as file_:
patterns = json.load(file_)
else:
patterns = {}
return cls(vocab, patterns)
def __init__(self, vocab, patterns={}):

View File

@@ -17,5 +17,6 @@ cdef class Parser:
cdef readonly Vocab vocab
cdef readonly ParserModel model
cdef readonly TransitionSystem moves
cdef readonly object cfg
cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) nogil

View File

@@ -81,12 +81,12 @@ cdef class Parser:
@classmethod
def load(cls, path, Vocab vocab, moves_class):
with (path / 'config.json').open() as file_:
cfg = json.loads(file_)
cfg = json.load(file_)
moves = moves_class(vocab.strings, cfg['labels'])
templates = get_templates(cfg['features'])
model = ParserModel(templates)
if (path / 'model').exists():
model.load(path / 'model')
model.load(str(path / 'model'))
return cls(vocab, moves, model, **cfg)
def __init__(self, Vocab vocab, transition_system, ParserModel model, **cfg):

View File

@@ -1,5 +1,5 @@
import json
from os import path
import pathlib
from collections import defaultdict
from libc.string cimport memset
@@ -102,10 +102,6 @@ cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil:
cdef class Tagger:
"""A part-of-speech tagger for English"""
@classmethod
def read_config(cls, data_dir):
return json.load(open(path.join(data_dir, 'pos', 'config.json')))
@classmethod
def default_templates(cls):
return (
@@ -146,15 +142,16 @@ cdef class Tagger:
@classmethod
def load(cls, path, vocab):
if (path / 'pos' / 'templates.json').exists():
with (path / 'pos' / 'templates.json').open() as file_:
path = path if not isinstance(path, basestring) else pathlib.Path(path)
if (path / 'templates.json').exists():
with (path / 'templates.json').open() as file_:
templates = json.load(file_)
else:
templates = cls.default_templates()
model = TaggerModel(templates)
if (path / 'pos' / 'model').exists():
model.load(path / 'pos' / 'model')
if (path / 'model').exists():
model.load(str(path / 'model'))
return cls(vocab, model)
def __init__(self, Vocab vocab, TaggerModel model):

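Tagger.load() is now handed the pos subdirectory directly (BaseDefaults.Tagger passes path / 'pos') and accepts either a string or a pathlib path. A sketch under an assumed data layout:

# Assumed layout: /opt/spacy-data/en-1.1.0/pos/{templates.json,model}
from pathlib import Path
from spacy.tagger import Tagger
from spacy.vocab import Vocab

path = Path('/opt/spacy-data/en-1.1.0')
vocab = Vocab.load(path)
tagger = Tagger.load(path / 'pos', vocab)
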
View File

@@ -22,13 +22,13 @@ from spacy.serialize.bits import BitArray
@pytest.fixture
def vocab():
data_dir = os.environ.get('SPACY_DATA')
if data_dir is None:
package = util.get_package_by_name('en')
path = os.environ.get('SPACY_DATA')
if path is None:
path = util.match_best_version('en', None, util.get_data_path())
else:
package = util.get_package(data_dir)
path = util.match_best_version('en', None, path)
vocab = English.default_vocab(package=package)
vocab = English.Defaults('en', path).Vocab()
lex = vocab['dog']
assert vocab[vocab.strings['dog']].orth_ == 'dog'
lex = vocab['the']
@@ -40,7 +40,7 @@ def vocab():
@pytest.fixture
def tokenizer(vocab):
null_re = re.compile(r'!!!!!!!!!')
tokenizer = Tokenizer(vocab, {}, null_re, null_re, null_re)
tokenizer = Tokenizer(vocab, {}, null_re.search, null_re.search, null_re.finditer)
return tokenizer

View File

@@ -11,29 +11,26 @@ import pytest
@pytest.fixture
def package():
data_dir = os.environ.get('SPACY_DATA')
if data_dir is None:
return util.get_package_by_name('en')
else:
return util.get_package(data_dir)
def path():
return util.match_best_version('en', None,
os.environ.get('SPACY_DATA', util.get_data_path()))
@pytest.fixture
def lemmatizer(package):
return Lemmatizer.from_package(package)
def lemmatizer(path):
return Lemmatizer.load(path)
def test_read_index(package):
with package.open(('wordnet', 'index.noun')) as file_:
def test_read_index(path):
with (path / 'wordnet' / 'index.noun').open() as file_:
index = read_index(file_)
assert 'man' in index
assert 'plantes' not in index
assert 'plant' in index
def test_read_exc(package):
with package.open(('wordnet', 'verb.exc')) as file_:
def test_read_exc(path):
with (path / 'wordnet' / 'verb.exc').open() as file_:
exc = read_exc(file_)
assert exc['was'] == ('be',)

View File

@@ -16,9 +16,9 @@ cdef class Tokenizer:
cdef PreshMap _specials
cpdef readonly Vocab vocab
cdef object _prefix_re
cdef object _suffix_re
cdef object _infix_re
cdef public object prefix_search
cdef public object suffix_search
cdef public object infix_finditer
cdef object _rules
cpdef Doc tokens_from_list(self, list strings)

View File

@@ -46,11 +46,11 @@ cdef class Tokenizer:
with (path / 'tokenizer' / 'specials.json').open() as file_:
rules = json.load(file_)
if prefix_search is None:
prefix_search = util.read_regex(path / 'tokenizer' / 'prefix.txt').search
prefix_search = util.read_prefix_regex(path / 'tokenizer' / 'prefix.txt').search
if suffix_search is None:
suffix_search = util.read_regex(path / 'tokenizer' / 'suffix.txt').search
suffix_search = util.read_suffix_regex(path / 'tokenizer' / 'suffix.txt').search
if infix_finditer is None:
infix_finditer = util.read_regex(path / 'tokenizer' / 'infix.txt').finditer
infix_finditer = util.read_infix_regex(path / 'tokenizer' / 'infix.txt').finditer
return cls(vocab, rules, prefix_search, suffix_search, infix_finditer)
@@ -297,6 +297,7 @@ cdef class Tokenizer:
def find_suffix(self, unicode string):
match = self.suffix_search(string)
print("Suffix", match, string)
return (match.end() - match.start()) if match is not None else 0
def _load_special_tokenization(self, special_cases):

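Tokenizer.load() now builds its matchers with the specialised util.read_prefix_regex, read_suffix_regex and read_infix_regex helpers (added in util.py below), and any of the three can still be overridden with a callable. A sketch under an assumed data layout:

# Assumed layout: /opt/spacy-data/en-1.1.0/tokenizer/{specials.json,prefix.txt,suffix.txt,infix.txt}
import re
from pathlib import Path
from spacy.tokenizer import Tokenizer
from spacy.vocab import Vocab

path = Path('/opt/spacy-data/en-1.1.0')
vocab = Vocab.load(path)
tokenizer = Tokenizer.load(path, vocab)     # matchers read from the tokenizer/ files
custom = Tokenizer.load(path, vocab,
                        infix_finditer=re.compile(r'[-~]').finditer)   # override just one
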
View File

@@ -3,12 +3,14 @@ import io
import json
import re
import os.path
import pathlib
import six
from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
LANGUAGES = {}
_data_path = pathlib.Path(__file__).parent / 'data'
def set_lang_class(name, cls):
@@ -23,6 +25,81 @@ def get_lang_class(name):
return LANGUAGES[lang]
def get_data_path():
return _data_path
def set_data_path(path):
global _data_path
if isinstance(path, basestring):
path = pathlib.Path(path)
_data_path = path
def match_best_version(target_name, target_version, path):
path = path if not isinstance(path, basestring) else pathlib.Path(path)
matches = []
for data_name in path.iterdir():
name, version = split_data_name(data_name.parts[-1])
if name == target_name and constraint_match(target_version, version):
matches.append((tuple(float(v) for v in version.split('.')), data_name))
if matches:
return pathlib.Path(max(matches)[1])
else:
return None
def split_data_name(name):
return name.split('-', 1) if '-' in name else (name, '')
def constraint_match(constraint_string, version):
# From http://github.com/spacy-io/sputnik
if not constraint_string:
return True
constraints = [c.strip() for c in constraint_string.split(',') if c.strip()]
for c in constraints:
if not re.match(r'[><=][=]?\d+(\.\d+)*', c):
raise ValueError('invalid constraint: %s' % c)
return all(semver.match(version, c) for c in constraints)
def read_regex(path):
path = path if not isinstance(path, basestring) else pathlib.Path(path)
with path.open() as file_:
entries = file_.read().split('\n')
expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
return re.compile(expression)
def read_prefix_regex(path):
path = path if not isinstance(path, basestring) else pathlib.Path(path)
with path.open() as file_:
entries = file_.read().split('\n')
expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
return re.compile(expression)
def read_suffix_regex(path):
path = path if not isinstance(path, basestring) else pathlib.Path(path)
with path.open() as file_:
entries = file_.read().split('\n')
expression = '|'.join([piece + '$' for piece in entries if piece.strip()])
return re.compile(expression)
def read_infix_regex(path):
path = path if not isinstance(path, basestring) else pathlib.Path(path)
with path.open() as file_:
entries = file_.read().split('\n')
expression = '|'.join([piece for piece in entries if piece.strip()])
return re.compile(expression)
def normalize_slice(length, start, stop, step=None):
if not (step is None or step == 1):
raise ValueError("Stepped slices not supported in Span objects."

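util now owns the data path and the name/version resolution helpers: split_data_name() splits on the first hyphen, constraint_match() applies comma-separated semver constraints, and match_best_version() returns the highest matching data directory. Expected behaviour, sketched with hypothetical directory contents:

from spacy import util

util.split_data_name('en-1.1.0')    # -> ('en', '1.1.0')
util.split_data_name('en')          # -> ('en', '')

# With a data path containing en-1.0.0/ and en-1.1.0/ (hypothetical):
best = util.match_best_version('en', '>=1.0.0', util.get_data_path())
# -> Path('.../en-1.1.0'), or None if nothing matches
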
View File

@@ -47,19 +47,21 @@ cdef class Vocab:
'''A map container for a language's LexemeC structs.
'''
@classmethod
def load(cls, path, get_lex_attr=None, vectors=True, lemmatizer=None):
def load(cls, path, get_lex_attr=None, vectors=True, lemmatizer=True):
if (path / 'vocab' / 'tag_map.json').exists():
with (path / 'vocab' / 'tag_map.json').open() as file_:
tag_map = json.loads(file_)
tag_map = json.load(file_)
else:
tag_map = {}
if lemmatizer is None:
if lemmatizer is True:
lemmatizer = Lemmatizer.load(path)
elif not lemmatizer:
lemmatizer = lambda string, pos: set((string,))
if (path / 'vocab' / 'serializer.json').exists():
with (path / 'vocab' / 'serializer.json').open() as file_:
serializer_freqs = json.loads(file_)
serializer_freqs = json.load(file_)
else:
serializer_freqs = {}
@@ -72,7 +74,8 @@ cdef class Vocab:
if vectors is True:
vectors = lambda self_: self_.load_vectors_from_bin_loc(path / 'vocab' / 'vec.bin')
self.vectors_length = vectors(self)
if vectors:
self.vectors_length = vectors(self)
return self
def __init__(self, get_lex_attr=None, tag_map=None, lemmatizer=None, serializer_freqs=None):
@@ -101,6 +104,7 @@ cdef class Vocab:
self.length = 1
self._serializer = None
print("Vocab lang", self.lang)
property serializer:
def __get__(self):
@@ -113,7 +117,7 @@ cdef class Vocab:
def __get__(self):
langfunc = None
if self.get_lex_attr:
langfunc = self.get_lex_attr.get(LANG,None)
langfunc = self.get_lex_attr.get(LANG, None)
return langfunc('_') if langfunc else ''
def __len__(self):
@@ -261,9 +265,8 @@ cdef class Vocab:
fp.close()
def load_lexemes(self, loc):
if not path.exists(loc):
raise IOError('LexemeCs file not found at %s' % loc)
fp = CFile(loc, 'rb')
fp = CFile(loc, 'rb',
on_open_error=lambda: IOError('LexemeCs file not found at %s' % loc))
cdef LexemeC* lexeme
cdef hash_t key
cdef unicode py_str
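
Vocab.load()'s lemmatizer argument now acts as a three-way switch: True loads the WordNet lemmatizer from the path, a falsy value installs the identity dummy, and any other object is used as given; vectors can likewise be disabled. A sketch with an assumed data directory:

from pathlib import Path
from spacy.vocab import Vocab

path = Path('/opt/spacy-data/en-1.1.0')
vocab = Vocab.load(path)                        # lemmatizer=True: load from <path>/wordnet
vocab = Vocab.load(path, lemmatizer=False)      # dummy lemmatizer: returns {string}
vocab = Vocab.load(path, vectors=False)         # skip reading vocab/vec.bin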