Tidy up util and helpers

ines 2017-10-27 14:39:09 +02:00
parent d941fc3667
commit ea4a41c8fb
6 changed files with 71 additions and 63 deletions

View File

@@ -87,15 +87,15 @@ def symlink_to(orig, dest):
 def is_config(python2=None, python3=None, windows=None, linux=None, osx=None):
-    return ((python2 == None or python2 == is_python2) and
-            (python3 == None or python3 == is_python3) and
-            (windows == None or windows == is_windows) and
-            (linux == None or linux == is_linux) and
-            (osx == None or osx == is_osx))
+    return ((python2 is None or python2 == is_python2) and
+            (python3 is None or python3 == is_python3) and
+            (windows is None or windows == is_windows) and
+            (linux is None or linux == is_linux) and
+            (osx is None or osx == is_osx))


 def normalize_string_keys(old):
-    '''Given a dictionary, make sure keys are unicode strings, not bytes.'''
+    """Given a dictionary, make sure keys are unicode strings, not bytes."""
     new = {}
     for key, value in old.items():
         if isinstance(key, bytes_):

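For reference, a minimal usage sketch of the two helpers above (assuming they live in spacy.compat, as the compat imports in util.py further down suggest):

    from spacy.compat import is_config, normalize_string_keys

    if is_config(python2=True, windows=False):
        # only True when running on Python 2 and not on Windows
        pass
    meta = normalize_string_keys({b'lang': 'en'})   # roughly -> {'lang': 'en'} with unicode keys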
View File

@@ -24,7 +24,7 @@ def depr_model_download(lang):
 def resolve_load_name(name, **overrides):
-    """Resolve model loading if deprecated path kwarg is specified in overrides.
+    """Resolve model loading if deprecated path kwarg in overrides.

     name (unicode): Name of model to load.
     **overrides: Overrides specified in spacy.load().
@@ -32,8 +32,9 @@ def resolve_load_name(name, **overrides):
     """
     if overrides.get('path') not in (None, False, True):
         name = overrides.get('path')
-        prints("To load a model from a path, you can now use the first argument. "
-               "The model meta is used to load the required Language class.",
-               "OLD: spacy.load('en', path='/some/path')", "NEW: spacy.load('/some/path')",
+        prints("To load a model from a path, you can now use the first "
+               "argument. The model meta is used to load the Language class.",
+               "OLD: spacy.load('en', path='/some/path')",
+               "NEW: spacy.load('/some/path')",
                title="Warning: deprecated argument 'path'")
     return name

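The deprecation warning above corresponds to this change in user code (the path is a placeholder):

    import spacy

    # OLD, deprecated: nlp = spacy.load('en', path='/some/path')
    # NEW:
    nlp = spacy.load('/some/path')   # the model meta determines the Language class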
View File

@@ -264,7 +264,6 @@ GLOSSARY = {
     'nk': 'noun kernel element',
     'nmc': 'numerical component',
     'oa': 'accusative object',
-    'oa': 'second accusative object',
     'oc': 'clausal object',
     'og': 'genitive object',
     'op': 'prepositional object',

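The removed line fixes a duplicate dictionary key: in a Python dict literal the last value for a repeated key silently wins, so the earlier 'accusative object' entry was being overwritten. A minimal illustration:

    d = {'oa': 'accusative object',
         'oa': 'second accusative object'}       # duplicate key, no error raised
    assert d['oa'] == 'second accusative object'  # the first value is silently lost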
View File

@@ -43,8 +43,8 @@ def POS_tree(root, light=False, flat=False):
 def parse_tree(doc, light=False, flat=False):
-    """Makes a copy of the doc, then construct a syntactic parse tree, similar to
-    the one used in displaCy. Generates the POS tree for all sentences in a doc.
+    """Make a copy of the doc and construct a syntactic parse tree similar to
+    displaCy. Generates the POS tree for all sentences in a doc.

     doc (Doc): The doc for parsing.
     RETURNS (dict): The parse tree.
@@ -66,8 +66,9 @@ def parse_tree(doc, light=False, flat=False):
         'NE': '', 'word': 'ate', 'arc': 'ROOT', 'POS_coarse': 'VERB',
         'POS_fine': 'VBD', 'lemma': 'eat'}
     """
     doc_clone = Doc(doc.vocab, words=[w.text for w in doc])
     doc_clone.from_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE],
                          doc.to_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE]))
     merge_ents(doc_clone)  # merge the entities into single tokens first
-    return [POS_tree(sent.root, light=light, flat=flat) for sent in doc_clone.sents]
+    return [POS_tree(sent.root, light=light, flat=flat)
+            for sent in doc_clone.sents]

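A rough usage sketch, assuming the Doc.print_tree() wrapper that spaCy 2.x exposes for this helper and an installed 'en' model:

    import spacy

    nlp = spacy.load('en')
    doc = nlp(u'Alice ate the pizza.')
    trees = doc.print_tree()   # one dict per sentence, with entries like
    # {'word': 'ate', 'arc': 'ROOT', 'POS_coarse': 'VERB', 'POS_fine': 'VBD', 'lemma': 'eat'}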
View File

@@ -1,5 +1,9 @@
+# coding: utf8
+from __future__ import unicode_literals
+
 import functools


 class Underscore(object):
     doc_extensions = {}
     span_extensions = {}

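For context, the added compat header mainly affects Python 2, where bare string literals are bytes by default. A minimal illustration of what unicode_literals changes:

    from __future__ import unicode_literals

    # with the future import, this holds on Python 2 as well as Python 3
    assert isinstance('extension name', type(u''))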
View File

@@ -10,25 +10,27 @@ from pathlib import Path
 import sys
 import textwrap
 import random
-import numpy
-import io
-import dill
 from collections import OrderedDict
 from thinc.neural._classes.model import Model
 import functools
-from .symbols import ORTH
-from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
-from .compat import import_file

 import msgpack
 import msgpack_numpy
 msgpack_numpy.patch()
+import ujson
+
+from .symbols import ORTH
+from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
+from .compat import copy_array, normalize_string_keys, getattr_, import_file


 LANGUAGES = {}
 _data_path = Path(__file__).parent / 'data'
+_PRINT_ENV = False
+
+
+def set_env_log(value):
+    global _PRINT_ENV
+    _PRINT_ENV = value
@@ -38,11 +40,12 @@ def get_lang_class(lang):
     RETURNS (Language): Language class.
     """
     global LANGUAGES
-    if not lang in LANGUAGES:
+    if lang not in LANGUAGES:
         try:
             module = importlib.import_module('.lang.%s' % lang, 'spacy')
         except ImportError:
-            raise ImportError("Can't import language %s from spacy.lang." %lang)
+            msg = "Can't import language %s from spacy.lang."
+            raise ImportError(msg % lang)
         LANGUAGES[lang] = getattr(module, module.__all__[0])
     return LANGUAGES[lang]
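A small usage sketch of the lookup above; unknown language codes raise the ImportError shown:

    from spacy.util import get_lang_class

    cls = get_lang_class('en')   # imports spacy.lang.en on first use, then caches it
    nlp = cls()                  # blank English pipeline
    try:
        get_lang_class('zz')     # hypothetical unknown code
    except ImportError as err:
        print(err)               # "Can't import language zz from spacy.lang."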
@@ -100,14 +103,14 @@ def load_model(name, **overrides):
     data_path = get_data_path()
     if not data_path or not data_path.exists():
         raise IOError("Can't find spaCy data path: %s" % path2str(data_path))
-    if isinstance(name, basestring_):
-        if name in set([d.name for d in data_path.iterdir()]):  # in data dir / shortcut
+    if isinstance(name, basestring_):  # in data dir / shortcut
+        if name in set([d.name for d in data_path.iterdir()]):
             return load_model_from_link(name, **overrides)
         if is_package(name):  # installed as package
             return load_model_from_package(name, **overrides)
         if Path(name).exists():  # path to model data directory
             return load_model_from_path(Path(name), **overrides)
     elif hasattr(name, 'exists'):  # Path or Path-like to model data
         return load_model_from_path(name, **overrides)
     raise IOError("Can't find model '%s'" % name)
@@ -120,7 +123,7 @@ def load_model_from_link(name, **overrides):
     except AttributeError:
         raise IOError(
             "Cant' load '%s'. If you're using a shortcut link, make sure it "
-            "points to a valid model package (not just a data directory)." % name)
+            "points to a valid package (not just a data directory)." % name)
     return cls.load(**overrides)
@@ -164,7 +167,8 @@ def load_model_from_init_py(init_file, **overrides):
     data_dir = '%s_%s-%s' % (meta['lang'], meta['name'], meta['version'])
     data_path = model_path / data_dir
     if not model_path.exists():
-        raise ValueError("Can't find model directory: %s" % path2str(data_path))
+        msg = "Can't find model directory: %s"
+        raise ValueError(msg % path2str(data_path))
     return load_model_from_path(data_path, meta, **overrides)
@@ -176,14 +180,16 @@ def get_model_meta(path):
     """
     model_path = ensure_path(path)
     if not model_path.exists():
-        raise ValueError("Can't find model directory: %s" % path2str(model_path))
+        msg = "Can't find model directory: %s"
+        raise ValueError(msg % path2str(model_path))
     meta_path = model_path / 'meta.json'
     if not meta_path.is_file():
         raise IOError("Could not read meta.json from %s" % meta_path)
     meta = read_json(meta_path)
     for setting in ['lang', 'name', 'version']:
         if setting not in meta or not meta[setting]:
-            raise ValueError("No valid '%s' setting found in model meta.json" % setting)
+            msg = "No valid '%s' setting found in model meta.json"
+            raise ValueError(msg % setting)
     return meta
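The checks above mean every model directory needs a meta.json with non-empty 'lang', 'name' and 'version' fields. A minimal sketch (the path is a placeholder):

    from spacy.util import get_model_meta

    meta = get_model_meta('/path/to/model')
    print(meta['lang'], meta['name'], meta['version'])
    # A meta.json missing any of the three fields raises, e.g.:
    # ValueError: No valid 'version' setting found in model meta.json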
@@ -240,7 +246,7 @@ def get_async(stream, numpy_array):
         return numpy_array
     else:
         array = cupy.ndarray(numpy_array.shape, order='C',
-                            dtype=numpy_array.dtype)
+                             dtype=numpy_array.dtype)
         array.set(numpy_array, stream=stream)
         return array
@@ -274,12 +280,6 @@ def itershuffle(iterable, bufsize=1000):
             raise StopIteration


-_PRINT_ENV = False
-
-
-def set_env_log(value):
-    global _PRINT_ENV
-    _PRINT_ENV = value


 def env_opt(name, default=None):
     if type(default) is float:
         type_convert = float
@@ -305,17 +305,20 @@ def read_regex(path):
     path = ensure_path(path)
     with path.open() as file_:
         entries = file_.read().split('\n')
-    expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
+    expression = '|'.join(['^' + re.escape(piece)
+                           for piece in entries if piece.strip()])
     return re.compile(expression)


 def compile_prefix_regex(entries):
     if '(' in entries:
         # Handle deprecated data
-        expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
+        expression = '|'.join(['^' + re.escape(piece)
+                               for piece in entries if piece.strip()])
         return re.compile(expression)
     else:
-        expression = '|'.join(['^' + piece for piece in entries if piece.strip()])
+        expression = '|'.join(['^' + piece
+                               for piece in entries if piece.strip()])
         return re.compile(expression)
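A usage sketch for the helper above, following the usual tokenizer-customisation pattern (assumes a blank English pipeline):

    from spacy.lang.en import English
    from spacy.util import compile_prefix_regex

    nlp = English()
    prefix_re = compile_prefix_regex(nlp.Defaults.prefixes)
    nlp.tokenizer.prefix_search = prefix_re.search   # plug the compiled regex back in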
@@ -359,16 +362,15 @@ def update_exc(base_exceptions, *addition_dicts):
     exc = dict(base_exceptions)
     for additions in addition_dicts:
         for orth, token_attrs in additions.items():
-            if not all(isinstance(attr[ORTH], unicode_) for attr in token_attrs):
-                msg = "Invalid value for ORTH in exception: key='%s', orths='%s'"
+            if not all(isinstance(attr[ORTH], unicode_)
+                       for attr in token_attrs):
+                msg = "Invalid ORTH value in exception: key='%s', orths='%s'"
                 raise ValueError(msg % (orth, token_attrs))
             described_orth = ''.join(attr[ORTH] for attr in token_attrs)
             if orth != described_orth:
-                raise ValueError("Invalid tokenizer exception: ORTH values "
-                                 "combined don't match original string. "
-                                 "key='%s', orths='%s'" % (orth, described_orth))
-        # overlap = set(exc.keys()).intersection(set(additions))
-        # assert not overlap, overlap
+                msg = ("Invalid tokenizer exception: ORTH values combined "
+                       "don't match original string. key='%s', orths='%s'")
+                raise ValueError(msg % (orth, described_orth))
         exc.update(additions)
     exc = expand_exc(exc, "'", "")
     return exc
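The ORTH check above enforces that the pieces of a tokenizer exception spell out the original string exactly. A small sketch:

    from spacy.symbols import ORTH, LEMMA
    from spacy.util import update_exc

    good = {"don't": [{ORTH: "do"}, {ORTH: "n't", LEMMA: "not"}]}
    exc = update_exc({}, good)    # fine: "do" + "n't" == "don't"

    bad = {"don't": [{ORTH: "do"}, {ORTH: "not"}]}
    # update_exc({}, bad) raises:
    # ValueError: Invalid tokenizer exception: ORTH values combined don't match original string. ...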
@@ -401,17 +403,15 @@ def normalize_slice(length, start, stop, step=None):
         raise ValueError("Stepped slices not supported in Span objects."
                          "Try: list(tokens)[start:stop:step] instead.")
     if start is None:
         start = 0
     elif start < 0:
         start += length
     start = min(length, max(0, start))
     if stop is None:
         stop = length
     elif stop < 0:
         stop += length
     stop = min(length, max(start, stop))
     assert 0 <= start <= stop <= length
     return start, stop
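Worked examples of the clamping logic above, for a hypothetical 5-token Doc:

    from spacy.util import normalize_slice

    assert normalize_slice(5, -3, None) == (2, 5)   # negative start counts from the end
    assert normalize_slice(5, 2, 100) == (2, 5)     # stop is clamped to the length
    assert normalize_slice(5, 4, 2) == (4, 4)       # stop can never precede start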
@@ -428,7 +428,7 @@ def compounding(start, stop, compound):
     >>> assert next(sizes) == 1.5 * 1.5
     """
     def clip(value):
-        return max(value, stop) if (start>stop) else min(value, stop)
+        return max(value, stop) if (start > stop) else min(value, stop)
     curr = float(start)
     while True:
         yield clip(curr)
@@ -438,7 +438,7 @@ def compounding(start, stop, compound):
 def decaying(start, stop, decay):
     """Yield an infinite series of linearly decaying values."""
     def clip(value):
-        return max(value, stop) if (start>stop) else min(value, stop)
+        return max(value, stop) if (start > stop) else min(value, stop)
     nr_upd = 1.
     while True:
         yield clip(start * 1./(1. + decay * nr_upd))
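Both generators above are typically used as training schedules. A brief sketch of the values they produce:

    from spacy.util import compounding, decaying

    batch_sizes = compounding(1., 32., 1.001)    # 1.0, 1.001, 1.002001, ... capped at 32
    learn_rates = decaying(0.005, 0.0005, 1e-4)  # decays from 0.005 towards the 0.0005 floor
    print(next(batch_sizes), next(learn_rates))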
@@ -530,17 +530,19 @@ def print_markdown(data, title=None):
     if isinstance(data, dict):
         data = list(data.items())
-    markdown = ["* **{}:** {}".format(l, unicode_(v)) for l, v in data if not excl_value(v)]
+    markdown = ["* **{}:** {}".format(l, unicode_(v))
+                for l, v in data if not excl_value(v)]
     if title:
         print("\n## {}".format(title))
     print('\n{}\n'.format('\n'.join(markdown)))


 def prints(*texts, **kwargs):
-    """Print formatted message (manual ANSI escape sequences to avoid dependency)
+    """Print formatted message (manual ANSI escape sequences to avoid
+    dependency)

     *texts (unicode): Texts to print. Each argument is rendered as paragraph.
-    **kwargs: 'title' becomes coloured headline. 'exits'=True performs sys exit.
+    **kwargs: 'title' becomes coloured headline. exits=True performs sys exit.
     """
     exits = kwargs.get('exits', None)
     title = kwargs.get('title', None)
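A quick sketch of the two printers above:

    from spacy.util import prints, print_markdown

    prints("Each positional argument is wrapped and rendered as its own paragraph.",
           title="Heads-up")          # exits=1 would also call sys.exit afterwards
    print_markdown({'lang': 'en', 'version': '2.0.0'}, title="Model meta")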
@@ -570,7 +572,8 @@ def _wrap(text, wrap_max=80, indent=4):
 def minify_html(html):
     """Perform a template-specific, rudimentary HTML minification for displaCy.
-    Disclaimer: NOT a general-purpose solution, only removes indentation/newlines.
+    Disclaimer: NOT a general-purpose solution, only removes indentation and
+    newlines.

     html (unicode): Markup to minify.
     RETURNS (unicode): "Minified" HTML.