* Fix merge conflicts for headers branch

This commit is contained in:
Matthew Honnibal 2015-12-27 17:46:25 +01:00
commit 8b61d45ed0
23 changed files with 263 additions and 210 deletions

View File

@ -55,7 +55,13 @@ install:
build_script:
# Build the compiled extension
- "%CMD_IN_ENV% python build.py pip"
- "%CMD_IN_ENV% python setup.py build_ext --inplace"
- ps: appveyor\download.ps1
- "tar -xzf corpora/en/wordnet.tar.gz"
- "%CMD_IN_ENV% python bin/init_model.py en lang_data/ corpora/ data"
- "cp package.json data"
- "%CMD_IN_ENV% sputnik build data en_default.sputnik"
- "%CMD_IN_ENV% sputnik install en_default.sputnik"
test_script:
# Run the project tests

View File

@ -17,8 +17,21 @@ env:
- PIP_DATE=2015-10-01 MODE=setup-develop
install:
- pip install --disable-pip-version-check -U pip
- python build.py prepare $PIP_DATE
- "pip install --upgrade setuptools"
- "pip install cython fabric fabtools"
- "pip install -r requirements.txt"
- "python setup.py build_ext --inplace"
- "mkdir -p corpora/en"
- "cd corpora/en"
- "wget --no-check-certificate http://wordnetcode.princeton.edu/3.0/WordNet-3.0.tar.gz"
- "tar -xzf WordNet-3.0.tar.gz"
- "mv WordNet-3.0 wordnet"
- "cd ../../"
- "export PYTHONPATH=`pwd`"
- "python bin/init_model.py en lang_data/ corpora/ data"
- "cp package.json data"
- "sputnik build data en_default.sputnik"
- "sputnik install en_default.sputnik"
script:
- python build.py $MODE;

17
package.json Normal file
View File

@ -0,0 +1,17 @@
{
"name": "en_default",
"version": "0.100.0",
"description": "english default model",
"license": "public domain",
"include": [
"deps/*",
"ner/*",
"pos/*",
"tokenizer/*",
"vocab/*",
"wordnet/*"
],
"compatibility": {
"spacy": "==0.100.0"
}
}

View File

@ -10,4 +10,4 @@ plac
six
ujson
cloudpickle
sputnik == 0.6.2
sputnik == 0.6.3

View File

@ -6,6 +6,4 @@ from ..language import Language
class German(Language):
@classmethod
def default_data_dir(cls):
return path.join(path.dirname(__file__), 'data')
pass

View File

@ -4,8 +4,6 @@ from os import path
from ..language import Language
LOCAL_DATA_DIR = path.join(path.dirname(__file__), 'data')
# improved list from Stone, Denis, Kwantes (2010)
STOPWORDS = """
@ -35,9 +33,9 @@ your yours yourself yourselves
STOPWORDS = set(w for w in STOPWORDS.split() if w)
class English(Language):
@classmethod
def default_data_dir(cls):
return LOCAL_DATA_DIR
def __init__(self, **kwargs):
kwargs['lang'] = 'en'
super(English, self).__init__(**kwargs)
@staticmethod
def is_stop(string):

View File

@ -8,8 +8,11 @@ from sputnik import Sputnik
def migrate(path):
data_path = os.path.join(path, 'data')
if os.path.isdir(data_path) and not os.path.islink(data_path):
shutil.rmtree(data_path)
if os.path.isdir(data_path):
if os.path.islink(data_path):
os.unlink(data_path)
else:
shutil.rmtree(data_path)
for filename in os.listdir(path):
if filename.endswith('.tgz'):
os.unlink(os.path.join(path, filename))
@ -53,9 +56,6 @@ def main(data_size='all', force=False):
# FIXME clean up old-style packages
migrate(path)
# FIXME supply spacy with an old-style data dir
link(package, os.path.join(path, 'data'))
if __name__ == '__main__':
plac.call(main)

View File

@ -6,6 +6,4 @@ from ..language import Language
class Finnish(Language):
@classmethod
def default_data_dir(cls):
return path.join(path.dirname(__file__), 'data')
pass

View File

@ -6,6 +6,4 @@ from ..language import Language
class Italian(Language):
@classmethod
def default_data_dir(cls):
return path.join(path.dirname(__file__), 'data')
pass

View File

@ -20,6 +20,7 @@ from .syntax.ner import BiluoPushDown
from .syntax.arc_eager import ArcEager
from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD
from .util import get_package
class Language(object):
@ -100,7 +101,7 @@ class Language(object):
return 0
@classmethod
def default_lex_attrs(cls, data_dir=None):
def default_lex_attrs(cls):
return {
attrs.LOWER: cls.lower,
attrs.NORM: cls.norm,
@ -134,79 +135,96 @@ class Language(object):
return {0: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True}}
@classmethod
def default_data_dir(cls):
return path.join(path.dirname(__file__), 'data')
@classmethod
def default_vocab(cls, data_dir=None, get_lex_attr=None):
if data_dir is None:
data_dir = cls.default_data_dir()
def default_vocab(cls, package=None, get_lex_attr=None):
if package is None:
package = get_package()
if get_lex_attr is None:
get_lex_attr = cls.default_lex_attrs(data_dir)
return Vocab.from_dir(
path.join(data_dir, 'vocab'),
get_lex_attr=get_lex_attr)
get_lex_attr = cls.default_lex_attrs()
return Vocab.from_package(package, get_lex_attr=get_lex_attr)
@classmethod
def default_tokenizer(cls, vocab, data_dir):
if path.exists(data_dir):
return Tokenizer.from_dir(vocab, data_dir)
else:
return Tokenizer(vocab, {}, None, None, None)
@classmethod
def default_tagger(cls, vocab, data_dir):
if path.exists(data_dir):
return Tagger.from_dir(data_dir, vocab)
else:
return None
@classmethod
def default_parser(cls, vocab, data_dir):
if path.exists(data_dir):
def default_parser(cls, package, vocab):
data_dir = package.dir_path('deps', require=False)
if data_dir and path.exists(data_dir):
return Parser.from_dir(data_dir, vocab.strings, ArcEager)
else:
return None
@classmethod
def default_entity(cls, vocab, data_dir):
if path.exists(data_dir):
def default_entity(cls, package, vocab):
data_dir = package.dir_path('ner', require=False)
if data_dir and path.exists(data_dir):
return Parser.from_dir(data_dir, vocab.strings, BiluoPushDown)
else:
return None
@classmethod
def default_matcher(cls, vocab, data_dir):
if path.exists(data_dir):
return Matcher.from_dir(data_dir, vocab)
else:
return None
def __init__(self, **kwargs):
"""
a model can be specified:
1) by a path to the model directory (DEPRECATED)
- Language(data_dir='path/to/data')
2) by a language identifier (and optionally a package root dir)
- Language(lang='en')
- Language(lang='en', data_dir='spacy/data')
3) by a model name/version (and optionally a package root dir)
- Language(model='en_default')
- Language(model='en_default ==1.0.0')
- Language(model='en_default <1.1.0, data_dir='spacy/data')
"""
data_dir = kwargs.pop('data_dir', None)
lang = kwargs.pop('lang', None)
model = kwargs.pop('model', None)
vocab = kwargs.pop('vocab', None)
tokenizer = kwargs.pop('tokenizer', None)
tagger = kwargs.pop('tagger', None)
parser = kwargs.pop('parser', None)
entity = kwargs.pop('entity', None)
matcher = kwargs.pop('matcher', None)
serializer = kwargs.pop('serializer', None)
load_vectors = kwargs.pop('load_vectors', True)
# support non-package data dirs
if data_dir and path.exists(path.join(data_dir, 'vocab')):
class Package(object):
def __init__(self, root):
self.root = root
def has_file(self, *path_parts):
return path.exists(path.join(self.root, *path_parts))
def file_path(self, *path_parts, **kwargs):
return path.join(self.root, *path_parts)
def dir_path(self, *path_parts, **kwargs):
return path.join(self.root, *path_parts)
def load_utf8(self, func, *path_parts, **kwargs):
with io.open(self.file_path(path.join(*path_parts)),
mode='r', encoding='utf8') as f:
return func(f)
warn("using non-package data_dir", DeprecationWarning)
package = Package(data_dir)
else:
package = get_package(name=model, data_path=data_dir)
def __init__(self, data_dir=None, vocab=None, tokenizer=None, tagger=None,
parser=None, entity=None, matcher=None, serializer=None,
load_vectors=True):
if load_vectors is not True:
warn("load_vectors is deprecated", DeprecationWarning)
if data_dir in (None, True):
data_dir = self.default_data_dir()
if vocab in (None, True):
vocab = self.default_vocab(data_dir)
self.vocab = self.default_vocab(package)
if tokenizer in (None, True):
tokenizer = self.default_tokenizer(vocab, data_dir=path.join(data_dir, 'tokenizer'))
self.tokenizer = Tokenizer.from_package(package, self.vocab)
if tagger in (None, True):
tagger = self.default_tagger(vocab, data_dir=path.join(data_dir, 'pos'))
self.tagger = Tagger.from_package(package, self.vocab)
if entity in (None, True):
entity = self.default_entity(vocab, data_dir=path.join(data_dir, 'ner'))
self.entity = self.default_entity(package, self.vocab)
if parser in (None, True):
parser = self.default_parser(vocab, data_dir=path.join(data_dir, 'deps'))
self.parser = self.default_parser(package, self.vocab)
if matcher in (None, True):
matcher = self.default_matcher(vocab, data_dir=data_dir)
self.vocab = vocab
self.tokenizer = tokenizer
self.tagger = tagger
self.parser = parser
self.entity = entity
self.matcher = matcher
self.matcher = Matcher.from_package(package, self.vocab)
def __reduce__(self):
return (self.__class__,

View File

@ -12,16 +12,21 @@ from .parts_of_speech import NOUN, VERB, ADJ, PUNCT
class Lemmatizer(object):
@classmethod
def from_dir(cls, data_dir):
def from_package(cls, package):
index = {}
exc = {}
for pos in ['adj', 'noun', 'verb']:
index[pos] = read_index(path.join(data_dir, 'wordnet', 'index.%s' % pos))
exc[pos] = read_exc(path.join(data_dir, 'wordnet', '%s.exc' % pos))
if path.exists(path.join(data_dir, 'vocab', 'lemma_rules.json')):
rules = json.load(codecs.open(path.join(data_dir, 'vocab', 'lemma_rules.json'), encoding='utf_8'))
else:
rules = {}
index[pos] = package.load_utf8(read_index,
'wordnet', 'index.%s' % pos,
default=set()) # TODO: really optional?
exc[pos] = package.load_utf8(read_exc,
'wordnet', '%s.exc' % pos,
default={}) # TODO: really optional?
rules = package.load_utf8(json.load,
'vocab', 'lemma_rules.json',
default={}) # TODO: really optional?
return cls(index, exc, rules)
def __init__(self, index, exceptions, rules):
@ -70,11 +75,9 @@ def lemmatize(string, index, exceptions, rules):
return set(forms)
def read_index(loc):
def read_index(fileobj):
index = set()
if not path.exists(loc):
return index
for line in codecs.open(loc, 'r', 'utf8'):
for line in fileobj:
if line.startswith(' '):
continue
pieces = line.split()
@ -84,11 +87,9 @@ def read_index(loc):
return index
def read_exc(loc):
def read_exc(fileobj):
exceptions = {}
if not path.exists(loc):
return exceptions
for line in codecs.open(loc, 'r', 'utf8'):
for line in fileobj:
if line.startswith(' '):
continue
pieces = line.split()

View File

@ -169,14 +169,11 @@ cdef class Matcher:
cdef object _patterns
@classmethod
def from_dir(cls, data_dir, Vocab vocab):
patterns_loc = path.join(data_dir, 'vocab', 'gazetteer.json')
if path.exists(patterns_loc):
patterns_data = open(patterns_loc).read()
patterns = json.loads(patterns_data)
return cls(vocab, patterns)
else:
return cls(vocab, {})
def from_package(cls, package, Vocab vocab):
patterns = package.load_utf8(json.load,
'vocab', 'gazetteer.json',
default={}) # TODO: really optional?
return cls(vocab, patterns)
def __init__(self, vocab, patterns):
self.vocab = vocab

View File

@ -146,15 +146,19 @@ cdef class Tagger:
return cls(vocab, model)
@classmethod
def from_dir(cls, data_dir, vocab):
if path.exists(path.join(data_dir, 'templates.json')):
templates = json.loads(open(path.join(data_dir, 'templates.json')))
else:
templates = cls.default_templates()
def from_package(cls, package, vocab):
# TODO: templates.json deprecated? not present in latest package
templates = cls.default_templates()
# templates = package.load_utf8(json.load,
# 'pos', 'templates.json',
# default=cls.default_templates())
model = TaggerModel(vocab.morphology.n_tags,
ConjunctionExtracter(N_CONTEXT_FIELDS, templates))
if path.exists(path.join(data_dir, 'model')):
model.load(path.join(data_dir, 'model'))
if package.has_file('pos', 'model'): # TODO: really optional?
model.load(package.file_path('pos', 'model'))
return cls(vocab, model)
def __init__(self, Vocab vocab, TaggerModel model):

View File

@ -1,12 +1,11 @@
from spacy.en import English
import pytest
from spacy.en import English, LOCAL_DATA_DIR
import os
@pytest.fixture(scope="session")
def EN():
data_dir = os.environ.get('SPACY_DATA', LOCAL_DATA_DIR)
return English(data_dir=data_dir)
return English()
def pytest_addoption(parser):

View File

@ -10,7 +10,6 @@ from spacy.en import English
from spacy.vocab import Vocab
from spacy.tokens.doc import Doc
from spacy.tokenizer import Tokenizer
from spacy.en import LOCAL_DATA_DIR
from os import path
from spacy.attrs import ORTH, SPACY, TAG, DEP, HEAD

View File

@ -1,9 +1,8 @@
import pytest
from spacy.en import English, LOCAL_DATA_DIR
from spacy.en import English
import os
@pytest.fixture(scope="session")
def en_nlp():
data_dir = os.environ.get('SPACY_DATA', LOCAL_DATA_DIR)
return English(data_dir=data_dir)
return English()

View File

@ -4,31 +4,33 @@ import io
import pickle
from spacy.lemmatizer import Lemmatizer, read_index, read_exc
from spacy.en import LOCAL_DATA_DIR
from os import path
from spacy.util import get_package
import pytest
def test_read_index():
wn = path.join(LOCAL_DATA_DIR, 'wordnet')
index = read_index(path.join(wn, 'index.noun'))
@pytest.fixture
def package():
return get_package()
@pytest.fixture
def lemmatizer(package):
return Lemmatizer.from_package(package)
def test_read_index(package):
index = package.load_utf8(read_index, 'wordnet', 'index.noun')
assert 'man' in index
assert 'plantes' not in index
assert 'plant' in index
def test_read_exc():
wn = path.join(LOCAL_DATA_DIR, 'wordnet')
exc = read_exc(path.join(wn, 'verb.exc'))
def test_read_exc(package):
exc = package.load_utf8(read_exc, 'wordnet', 'verb.exc')
assert exc['was'] == ('be',)
@pytest.fixture
def lemmatizer():
return Lemmatizer.from_dir(path.join(LOCAL_DATA_DIR))
def test_noun_lemmas(lemmatizer):
do = lemmatizer.noun

View File

@ -2,16 +2,15 @@ from __future__ import unicode_literals
import pytest
import gc
from spacy.en import English, LOCAL_DATA_DIR
from spacy.en import English
import os
data_dir = os.environ.get('SPACY_DATA', LOCAL_DATA_DIR)
# Let this have its own instances, as we have to be careful about memory here
# that's the point, after all
@pytest.mark.models
def get_orphan_token(text, i):
nlp = English(data_dir=data_dir)
nlp = English()
tokens = nlp(text)
gc.collect()
token = tokens[i]
@ -41,7 +40,7 @@ def _orphan_from_list(toks):
@pytest.mark.models
def test_list_orphans():
# Test case from NSchrading
nlp = English(data_dir=data_dir)
nlp = English()
samples = ["a", "test blah wat okay"]
lst = []
for sample in samples:

View File

@ -5,9 +5,8 @@ import os
@pytest.fixture(scope='session')
def nlp():
from spacy.en import English, LOCAL_DATA_DIR
data_dir = os.environ.get('SPACY_DATA', LOCAL_DATA_DIR)
return English(data_dir=data_dir)
from spacy.en import English
return English()
@pytest.fixture()

View File

@ -10,9 +10,8 @@ def token(doc):
def test_load_resources_and_process_text():
from spacy.en import English, LOCAL_DATA_DIR
data_dir = os.environ.get('SPACY_DATA', LOCAL_DATA_DIR)
nlp = English(data_dir=data_dir)
from spacy.en import English
nlp = English()
doc = nlp('Hello, world. Here are two sentences.')

View File

@ -41,8 +41,8 @@ cdef class Tokenizer:
return (self.__class__, args, None, None)
@classmethod
def from_dir(cls, Vocab vocab, data_dir):
rules, prefix_re, suffix_re, infix_re = read_lang_data(data_dir)
def from_package(cls, package, Vocab vocab):
rules, prefix_re, suffix_re, infix_re = read_lang_data(package)
prefix_re = re.compile(prefix_re)
suffix_re = re.compile(suffix_re)
infix_re = re.compile(infix_re)

View File

@ -1,10 +1,24 @@
from os import path
import os
import io
import json
import re
from sputnik import Sputnik
from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
DATA_DIR = path.join(path.dirname(__file__), '..', 'data')
def get_package(name=None, data_path=None):
if data_path is None:
if os.environ.get('SPACY_DATA'):
data_path = os.environ.get('SPACY_DATA')
else:
data_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), 'data'))
sputnik = Sputnik('spacy', '0.100.0') # TODO: retrieve version
pool = sputnik.pool(data_path)
return pool.get(name or 'en_default')
def normalize_slice(length, start, stop, step=None):
@ -31,67 +45,63 @@ def utf8open(loc, mode='r'):
return io.open(loc, mode, encoding='utf8')
def read_lang_data(data_dir):
with open(path.join(data_dir, 'specials.json')) as file_:
tokenization = json.load(file_)
prefix = read_prefix(data_dir)
suffix = read_suffix(data_dir)
infix = read_infix(data_dir)
def read_lang_data(package):
tokenization = package.load_utf8(json.load, 'tokenizer', 'specials.json')
prefix = package.load_utf8(read_prefix, 'tokenizer', 'prefix.txt')
suffix = package.load_utf8(read_suffix, 'tokenizer', 'suffix.txt')
infix = package.load_utf8(read_infix, 'tokenizer', 'infix.txt')
return tokenization, prefix, suffix, infix
def read_prefix(data_dir):
with utf8open(path.join(data_dir, 'prefix.txt')) as file_:
entries = file_.read().split('\n')
expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
def read_prefix(fileobj):
entries = fileobj.read().split('\n')
expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
return expression
def read_suffix(data_dir):
with utf8open(path.join(data_dir, 'suffix.txt')) as file_:
entries = file_.read().split('\n')
expression = '|'.join([piece + '$' for piece in entries if piece.strip()])
def read_suffix(fileobj):
entries = fileobj.read().split('\n')
expression = '|'.join([piece + '$' for piece in entries if piece.strip()])
return expression
def read_infix(data_dir):
with utf8open(path.join(data_dir, 'infix.txt')) as file_:
entries = file_.read().split('\n')
expression = '|'.join([piece for piece in entries if piece.strip()])
def read_infix(fileobj):
entries = fileobj.read().split('\n')
expression = '|'.join([piece for piece in entries if piece.strip()])
return expression
def read_tokenization(lang):
loc = path.join(DATA_DIR, lang, 'tokenization')
entries = []
seen = set()
with utf8open(loc) as file_:
for line in file_:
line = line.strip()
if line.startswith('#'):
continue
if not line:
continue
pieces = line.split()
chunk = pieces.pop(0)
assert chunk not in seen, chunk
seen.add(chunk)
entries.append((chunk, list(pieces)))
if chunk[0].isalpha() and chunk[0].islower():
chunk = chunk[0].title() + chunk[1:]
pieces[0] = pieces[0][0].title() + pieces[0][1:]
seen.add(chunk)
entries.append((chunk, pieces))
return entries
# def read_tokenization(lang):
# loc = path.join(DATA_DIR, lang, 'tokenization')
# entries = []
# seen = set()
# with utf8open(loc) as file_:
# for line in file_:
# line = line.strip()
# if line.startswith('#'):
# continue
# if not line:
# continue
# pieces = line.split()
# chunk = pieces.pop(0)
# assert chunk not in seen, chunk
# seen.add(chunk)
# entries.append((chunk, list(pieces)))
# if chunk[0].isalpha() and chunk[0].islower():
# chunk = chunk[0].title() + chunk[1:]
# pieces[0] = pieces[0][0].title() + pieces[0][1:]
# seen.add(chunk)
# entries.append((chunk, pieces))
# return entries
def read_detoken_rules(lang): # Deprecated?
loc = path.join(DATA_DIR, lang, 'detokenize')
entries = []
with utf8open(loc) as file_:
for line in file_:
entries.append(line.strip())
return entries
# def read_detoken_rules(lang): # Deprecated?
# loc = path.join(DATA_DIR, lang, 'detokenize')
# entries = []
# with utf8open(loc) as file_:
# for line in file_:
# entries.append(line.strip())
# return entries
def align_tokens(ref, indices): # Deprecated, surely?

View File

@ -47,28 +47,27 @@ cdef class Vocab:
'''A map container for a language's LexemeC structs.
'''
@classmethod
def from_dir(cls, data_dir, get_lex_attr=None):
if not path.exists(data_dir):
raise IOError("Directory %s not found -- cannot load Vocab." % data_dir)
if not path.isdir(data_dir):
raise IOError("Path %s is a file, not a dir -- cannot load Vocab." % data_dir)
def from_package(cls, package, get_lex_attr=None):
tag_map = package.load_utf8(json.load,
'vocab', 'tag_map.json')
lemmatizer = Lemmatizer.from_package(package)
serializer_freqs = package.load_utf8(json.load,
'vocab', 'serializer.json',
require=False) # TODO: really optional?
tag_map = json.load(open(path.join(data_dir, 'tag_map.json')))
lemmatizer = Lemmatizer.from_dir(path.join(data_dir, '..'))
if path.exists(path.join(data_dir, 'serializer.json')):
serializer_freqs = json.load(open(path.join(data_dir, 'serializer.json')))
else:
serializer_freqs = None
cdef Vocab self = cls(get_lex_attr=get_lex_attr, tag_map=tag_map,
lemmatizer=lemmatizer, serializer_freqs=serializer_freqs)
if path.exists(path.join(data_dir, 'strings.json')):
with io.open(path.join(data_dir, 'strings.json'), 'r', encoding='utf8') as file_:
self.strings.load(file_)
self.load_lexemes(path.join(data_dir, 'lexemes.bin'))
if path.exists(path.join(data_dir, 'vec.bin')):
self.vectors_length = self.load_vectors_from_bin_loc(path.join(data_dir, 'vec.bin'))
if package.has_file('vocab', 'strings.json'): # TODO: really optional?
package.load_utf8(self.strings.load, 'vocab', 'strings.json')
self.load_lexemes(package.file_path('vocab', 'lexemes.bin'))
if package.has_file('vocab', 'vec.bin'): # TODO: really optional?
self.vectors_length = self.load_vectors_from_bin_loc(
package.file_path('vocab', 'vec.bin'))
return self
def __init__(self, get_lex_attr=None, tag_map=None, lemmatizer=None, serializer_freqs=None):