Tidy up util and helpers

This commit is contained in:
ines 2017-10-27 14:39:09 +02:00
parent d941fc3667
commit ea4a41c8fb
6 changed files with 71 additions and 63 deletions

View File

@ -87,15 +87,15 @@ def symlink_to(orig, dest):
def is_config(python2=None, python3=None, windows=None, linux=None, osx=None):
return ((python2 == None or python2 == is_python2) and
(python3 == None or python3 == is_python3) and
(windows == None or windows == is_windows) and
(linux == None or linux == is_linux) and
(osx == None or osx == is_osx))
return ((python2 is None or python2 == is_python2) and
(python3 is None or python3 == is_python3) and
(windows is None or windows == is_windows) and
(linux is None or linux == is_linux) and
(osx is None or osx == is_osx))
def normalize_string_keys(old):
'''Given a dictionary, make sure keys are unicode strings, not bytes.'''
"""Given a dictionary, make sure keys are unicode strings, not bytes."""
new = {}
for key, value in old.items():
if isinstance(key, bytes_):

View File

@ -24,7 +24,7 @@ def depr_model_download(lang):
def resolve_load_name(name, **overrides):
"""Resolve model loading if deprecated path kwarg is specified in overrides.
"""Resolve model loading if deprecated path kwarg in overrides.
name (unicode): Name of model to load.
**overrides: Overrides specified in spacy.load().
@ -32,8 +32,9 @@ def resolve_load_name(name, **overrides):
"""
if overrides.get('path') not in (None, False, True):
name = overrides.get('path')
prints("To load a model from a path, you can now use the first argument. "
"The model meta is used to load the required Language class.",
"OLD: spacy.load('en', path='/some/path')", "NEW: spacy.load('/some/path')",
prints("To load a model from a path, you can now use the first "
"argument. The model meta is used to load the Language class.",
"OLD: spacy.load('en', path='/some/path')",
"NEW: spacy.load('/some/path')",
title="Warning: deprecated argument 'path'")
return name

View File

@ -264,7 +264,6 @@ GLOSSARY = {
'nk': 'noun kernel element',
'nmc': 'numerical component',
'oa': 'accusative object',
'oa': 'second accusative object',
'oc': 'clausal object',
'og': 'genitive object',
'op': 'prepositional object',

View File

@ -43,8 +43,8 @@ def POS_tree(root, light=False, flat=False):
def parse_tree(doc, light=False, flat=False):
"""Makes a copy of the doc, then construct a syntactic parse tree, similar to
the one used in displaCy. Generates the POS tree for all sentences in a doc.
"""Make a copy of the doc and construct a syntactic parse tree similar to
displaCy. Generates the POS tree for all sentences in a doc.
doc (Doc): The doc for parsing.
RETURNS (dict): The parse tree.
@ -66,8 +66,9 @@ def parse_tree(doc, light=False, flat=False):
'NE': '', 'word': 'ate', 'arc': 'ROOT', 'POS_coarse': 'VERB',
'POS_fine': 'VBD', 'lemma': 'eat'}
"""
doc_clone = Doc(doc.vocab, words=[w.text for w in doc])
doc_clone = Doc(doc.vocab, words=[w.text for w in doc])
doc_clone.from_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE],
doc.to_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE]))
merge_ents(doc_clone) # merge the entities into single tokens first
return [POS_tree(sent.root, light=light, flat=flat) for sent in doc_clone.sents]
return [POS_tree(sent.root, light=light, flat=flat)
for sent in doc_clone.sents]

View File

@ -1,5 +1,9 @@
# coding: utf8
from __future__ import unicode_literals
import functools
class Underscore(object):
doc_extensions = {}
span_extensions = {}

View File

@ -10,25 +10,27 @@ from pathlib import Path
import sys
import textwrap
import random
import numpy
import io
import dill
from collections import OrderedDict
from thinc.neural._classes.model import Model
import functools
from .symbols import ORTH
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
from .compat import import_file
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
import ujson
from .symbols import ORTH
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
from .compat import copy_array, normalize_string_keys, getattr_, import_file
LANGUAGES = {}
_data_path = Path(__file__).parent / 'data'
_PRINT_ENV = False
def set_env_log(value):
global _PRINT_ENV
_PRINT_ENV = value
def get_lang_class(lang):
@ -38,11 +40,12 @@ def get_lang_class(lang):
RETURNS (Language): Language class.
"""
global LANGUAGES
if not lang in LANGUAGES:
if lang not in LANGUAGES:
try:
module = importlib.import_module('.lang.%s' % lang, 'spacy')
except ImportError:
raise ImportError("Can't import language %s from spacy.lang." %lang)
msg = "Can't import language %s from spacy.lang."
raise ImportError(msg % lang)
LANGUAGES[lang] = getattr(module, module.__all__[0])
return LANGUAGES[lang]
@ -100,14 +103,14 @@ def load_model(name, **overrides):
data_path = get_data_path()
if not data_path or not data_path.exists():
raise IOError("Can't find spaCy data path: %s" % path2str(data_path))
if isinstance(name, basestring_):
if name in set([d.name for d in data_path.iterdir()]): # in data dir / shortcut
if isinstance(name, basestring_): # in data dir / shortcut
if name in set([d.name for d in data_path.iterdir()]):
return load_model_from_link(name, **overrides)
if is_package(name): # installed as package
if is_package(name): # installed as package
return load_model_from_package(name, **overrides)
if Path(name).exists(): # path to model data directory
if Path(name).exists(): # path to model data directory
return load_model_from_path(Path(name), **overrides)
elif hasattr(name, 'exists'): # Path or Path-like to model data
elif hasattr(name, 'exists'): # Path or Path-like to model data
return load_model_from_path(name, **overrides)
raise IOError("Can't find model '%s'" % name)
@ -120,7 +123,7 @@ def load_model_from_link(name, **overrides):
except AttributeError:
raise IOError(
"Cant' load '%s'. If you're using a shortcut link, make sure it "
"points to a valid model package (not just a data directory)." % name)
"points to a valid package (not just a data directory)." % name)
return cls.load(**overrides)
@ -164,7 +167,8 @@ def load_model_from_init_py(init_file, **overrides):
data_dir = '%s_%s-%s' % (meta['lang'], meta['name'], meta['version'])
data_path = model_path / data_dir
if not model_path.exists():
raise ValueError("Can't find model directory: %s" % path2str(data_path))
msg = "Can't find model directory: %s"
raise ValueError(msg % path2str(data_path))
return load_model_from_path(data_path, meta, **overrides)
@ -176,14 +180,16 @@ def get_model_meta(path):
"""
model_path = ensure_path(path)
if not model_path.exists():
raise ValueError("Can't find model directory: %s" % path2str(model_path))
msg = "Can't find model directory: %s"
raise ValueError(msg % path2str(model_path))
meta_path = model_path / 'meta.json'
if not meta_path.is_file():
raise IOError("Could not read meta.json from %s" % meta_path)
meta = read_json(meta_path)
for setting in ['lang', 'name', 'version']:
if setting not in meta or not meta[setting]:
raise ValueError("No valid '%s' setting found in model meta.json" % setting)
msg = "No valid '%s' setting found in model meta.json"
raise ValueError(msg % setting)
return meta
@ -240,7 +246,7 @@ def get_async(stream, numpy_array):
return numpy_array
else:
array = cupy.ndarray(numpy_array.shape, order='C',
dtype=numpy_array.dtype)
dtype=numpy_array.dtype)
array.set(numpy_array, stream=stream)
return array
@ -274,12 +280,6 @@ def itershuffle(iterable, bufsize=1000):
raise StopIteration
_PRINT_ENV = False
def set_env_log(value):
global _PRINT_ENV
_PRINT_ENV = value
def env_opt(name, default=None):
if type(default) is float:
type_convert = float
@ -305,17 +305,20 @@ def read_regex(path):
path = ensure_path(path)
with path.open() as file_:
entries = file_.read().split('\n')
expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
expression = '|'.join(['^' + re.escape(piece)
for piece in entries if piece.strip()])
return re.compile(expression)
def compile_prefix_regex(entries):
if '(' in entries:
# Handle deprecated data
expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
expression = '|'.join(['^' + re.escape(piece)
for piece in entries if piece.strip()])
return re.compile(expression)
else:
expression = '|'.join(['^' + piece for piece in entries if piece.strip()])
expression = '|'.join(['^' + piece
for piece in entries if piece.strip()])
return re.compile(expression)
@ -359,16 +362,15 @@ def update_exc(base_exceptions, *addition_dicts):
exc = dict(base_exceptions)
for additions in addition_dicts:
for orth, token_attrs in additions.items():
if not all(isinstance(attr[ORTH], unicode_) for attr in token_attrs):
msg = "Invalid value for ORTH in exception: key='%s', orths='%s'"
if not all(isinstance(attr[ORTH], unicode_)
for attr in token_attrs):
msg = "Invalid ORTH value in exception: key='%s', orths='%s'"
raise ValueError(msg % (orth, token_attrs))
described_orth = ''.join(attr[ORTH] for attr in token_attrs)
if orth != described_orth:
raise ValueError("Invalid tokenizer exception: ORTH values "
"combined don't match original string. "
"key='%s', orths='%s'" % (orth, described_orth))
# overlap = set(exc.keys()).intersection(set(additions))
# assert not overlap, overlap
msg = ("Invalid tokenizer exception: ORTH values combined "
"don't match original string. key='%s', orths='%s'")
raise ValueError(msg % (orth, described_orth))
exc.update(additions)
exc = expand_exc(exc, "'", "")
return exc
@ -401,17 +403,15 @@ def normalize_slice(length, start, stop, step=None):
raise ValueError("Stepped slices not supported in Span objects."
"Try: list(tokens)[start:stop:step] instead.")
if start is None:
start = 0
start = 0
elif start < 0:
start += length
start += length
start = min(length, max(0, start))
if stop is None:
stop = length
stop = length
elif stop < 0:
stop += length
stop += length
stop = min(length, max(start, stop))
assert 0 <= start <= stop <= length
return start, stop
@ -428,7 +428,7 @@ def compounding(start, stop, compound):
>>> assert next(sizes) == 1.5 * 1.5
"""
def clip(value):
return max(value, stop) if (start>stop) else min(value, stop)
return max(value, stop) if (start > stop) else min(value, stop)
curr = float(start)
while True:
yield clip(curr)
@ -438,7 +438,7 @@ def compounding(start, stop, compound):
def decaying(start, stop, decay):
"""Yield an infinite series of linearly decaying values."""
def clip(value):
return max(value, stop) if (start>stop) else min(value, stop)
return max(value, stop) if (start > stop) else min(value, stop)
nr_upd = 1.
while True:
yield clip(start * 1./(1. + decay * nr_upd))
@ -530,17 +530,19 @@ def print_markdown(data, title=None):
if isinstance(data, dict):
data = list(data.items())
markdown = ["* **{}:** {}".format(l, unicode_(v)) for l, v in data if not excl_value(v)]
markdown = ["* **{}:** {}".format(l, unicode_(v))
for l, v in data if not excl_value(v)]
if title:
print("\n## {}".format(title))
print('\n{}\n'.format('\n'.join(markdown)))
def prints(*texts, **kwargs):
"""Print formatted message (manual ANSI escape sequences to avoid dependency)
"""Print formatted message (manual ANSI escape sequences to avoid
dependency)
*texts (unicode): Texts to print. Each argument is rendered as paragraph.
**kwargs: 'title' becomes coloured headline. 'exits'=True performs sys exit.
**kwargs: 'title' becomes coloured headline. exits=True performs sys exit.
"""
exits = kwargs.get('exits', None)
title = kwargs.get('title', None)
@ -570,7 +572,8 @@ def _wrap(text, wrap_max=80, indent=4):
def minify_html(html):
"""Perform a template-specific, rudimentary HTML minification for displaCy.
Disclaimer: NOT a general-purpose solution, only removes indentation/newlines.
Disclaimer: NOT a general-purpose solution, only removes indentation and
newlines.
html (unicode): Markup to minify.
RETURNS (unicode): "Minified" HTML.