untangle data_path/via

This commit is contained in:
Henning Peters 2016-01-16 12:23:45 +01:00
parent 6d1a3af343
commit 235f094534
12 changed files with 51 additions and 58 deletions

View File

@ -6,6 +6,8 @@ import shutil
import plac
import sputnik
from sputnik.package_list import (PackageNotFoundException,
CompatiblePackageNotFoundException)
from .. import about
@ -22,28 +24,21 @@ def migrate(path):
os.unlink(os.path.join(path, filename))
def link(package, path):
if os.path.exists(path):
if os.path.isdir(path):
shutil.rmtree(path)
else:
os.unlink(path)
if not hasattr(os, 'symlink'): # not supported by win+py27
shutil.copytree(package.dir_path('data'), path)
else:
os.symlink(package.dir_path('data'), path)
@plac.annotations(
force=("Force overwrite", "flag", "f", bool),
)
def main(data_size='all', force=False):
path = os.path.dirname(os.path.abspath(__file__))
if force:
sputnik.purge(about.__name__, about.__version__)
try:
sputnik.package(about.__name__, about.__version__, about.__default_model__)
print("Model already installed. Please run 'python -m "
"spacy.en.download --force' to reinstall.", file=sys.stderr)
sys.exit(1)
except PackageNotFoundException, CompatiblePackageNotFoundException:
pass
package = sputnik.install(about.__name__, about.__version__, about.__default_model__)
try:
@ -54,7 +49,7 @@ def main(data_size='all', force=False):
sys.exit(1)
# FIXME clean up old-style packages
migrate(path)
migrate(os.path.dirname(os.path.abspath(__file__)))
print("Model successfully installed.", file=sys.stderr)

View File

@ -155,7 +155,6 @@ class Language(object):
return Parser.from_dir(data_dir, vocab.strings, BiluoPushDown)
def __init__(self,
via=None,
data_dir=None,
vocab=None,
tokenizer=None,
@ -172,9 +171,9 @@ class Language(object):
1) by calling a Language subclass
- spacy.en.English()
2) by calling a Language subclass with via (previously: data_dir)
2) by calling a Language subclass with data_dir
- spacy.en.English('my/model/root')
- spacy.en.English(via='my/model/root')
- spacy.en.English(data_dir='my/model/root')
3) by package name
- spacy.load('en_default')
@ -185,15 +184,11 @@ class Language(object):
- spacy.load('en_default==1.0.0', via='/my/package/root')
"""
if data_dir is not None and via is None:
warn("Use of data_dir is deprecated, use via instead.", DeprecationWarning)
via = data_dir
if package is None:
if via is None:
if data_dir is None:
package = util.get_package_by_name()
else:
package = util.get_package(via)
package = util.get_package(data_dir)
if load_vectors is not True:
warn("load_vectors is deprecated", DeprecationWarning)

View File

@ -170,8 +170,8 @@ cdef class Matcher:
cdef object _patterns
@classmethod
def load(cls, via, Vocab vocab):
return cls.from_package(get_package(via), vocab=vocab)
def load(cls, data_dir, Vocab vocab):
return cls.from_package(get_package(data_dir), vocab=vocab)
@classmethod
def from_package(cls, package, Vocab vocab):

View File

@ -148,8 +148,8 @@ cdef class Tagger:
return cls(vocab, model)
@classmethod
def load(cls, via, vocab):
return cls.from_package(get_package(via), vocab=vocab)
def load(cls, data_dir, vocab):
return cls.from_package(get_package(data_dir), vocab=vocab)
@classmethod
def from_package(cls, pkg, vocab):

View File

@ -7,11 +7,11 @@ import os
@pytest.fixture(scope="session")
def EN():
if os.environ.get('SPACY_DATA'):
data_path = os.environ.get('SPACY_DATA')
data_dir = os.environ.get('SPACY_DATA')
else:
data_path = None
print("Load EN from %s" % data_path)
return English(data_dir=data_path)
data_dir = None
print("Load EN from %s" % data_dir)
return English(data_dir=data_dir)
def pytest_addoption(parser):

View File

@ -13,6 +13,7 @@ from spacy.tokenizer import Tokenizer
from os import path
import os
from spacy import util
from spacy.attrs import ORTH, SPACY, TAG, DEP, HEAD
from spacy.serialize.packer import Packer
@ -21,11 +22,13 @@ from spacy.serialize.bits import BitArray
@pytest.fixture
def vocab():
if os.environ.get('SPACY_DATA'):
data_path = os.environ.get('SPACY_DATA')
data_dir = os.environ.get('SPACY_DATA')
if data_dir is None:
package = util.get_package_by_name()
else:
data_path = None
vocab = English.default_vocab(package=data_path)
package = util.get_package(data_dir)
vocab = English.default_vocab(package=package)
lex = vocab['dog']
assert vocab[vocab.strings['dog']].orth_ == 'dog'
lex = vocab['the']

View File

@ -5,23 +5,23 @@ import io
import pickle
from spacy.lemmatizer import Lemmatizer, read_index, read_exc
from spacy.util import get_package
from spacy import util
import pytest
@pytest.fixture
def package():
if os.environ.get('SPACY_DATA'):
data_path = os.environ.get('SPACY_DATA')
data_dir = os.environ.get('SPACY_DATA')
if data_dir is None:
return util.get_package_by_name()
else:
data_path = None
return get_package(data_path=data_path)
return util.get_package(data_dir)
@pytest.fixture
def lemmatizer(package):
return Lemmatizer.load(package)
return Lemmatizer.from_package(package)
def test_read_index(package):

View File

@ -7,10 +7,10 @@ import os
def nlp():
from spacy.en import English
if os.environ.get('SPACY_DATA'):
data_path = os.environ.get('SPACY_DATA')
data_dir = os.environ.get('SPACY_DATA')
else:
data_path = None
return English(data_dir=data_path)
data_dir = None
return English(data_dir=data_dir)
@pytest.fixture()

View File

@ -11,13 +11,13 @@ def token(doc):
def test_load_resources_and_process_text():
if os.environ.get('SPACY_DATA'):
data_path = os.environ.get('SPACY_DATA')
data_dir = os.environ.get('SPACY_DATA')
else:
data_path = None
print("Load EN from %s" % data_path)
data_dir = None
print("Load EN from %s" % data_dir)
from spacy.en import English
nlp = English(data_dir=data_path)
nlp = English(data_dir=data_dir)
doc = nlp('Hello, world. Here are two sentences.')

View File

@ -42,8 +42,8 @@ cdef class Tokenizer:
return (self.__class__, args, None, None)
@classmethod
def load(cls, via, Vocab vocab):
return cls.from_package(get_package(via), vocab=vocab)
def load(cls, data_dir, Vocab vocab):
return cls.from_package(get_package(data_dir), vocab=vocab)
@classmethod
def from_package(cls, package, Vocab vocab):

View File

@ -14,10 +14,10 @@ from . import about
from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
def get_package():
if not isinstance(via, six.string_types):
raise RuntimeError('via must be a string')
return DirPackage(via)
def get_package(data_dir):
if not isinstance(data_dir, six.string_types):
raise RuntimeError('data_dir must be a string')
return DirPackage(data_dir)
def get_package_by_name(name=None, via=None):

View File

@ -48,8 +48,8 @@ cdef class Vocab:
'''A map container for a language's LexemeC structs.
'''
@classmethod
def load(cls, via, get_lex_attr=None):
return cls.from_package(get_package(via), get_lex_attr=get_lex_attr)
def load(cls, data_dir, get_lex_attr=None):
return cls.from_package(get_package(data_dir), get_lex_attr=get_lex_attr)
@classmethod
def from_package(cls, package, get_lex_attr=None):