2017-03-12 15:07:28 +03:00
|
|
|
|
# coding: utf8
|
2017-03-15 19:35:57 +03:00
|
|
|
|
from __future__ import unicode_literals, print_function
|
2017-04-15 13:05:47 +03:00
|
|
|
|
|
2017-05-18 12:36:53 +03:00
|
|
|
|
import os
|
2017-04-15 13:13:34 +03:00
|
|
|
|
import ujson
|
2017-05-08 00:24:51 +03:00
|
|
|
|
import pip
|
|
|
|
|
import importlib
|
2017-04-20 02:22:52 +03:00
|
|
|
|
import regex as re
|
2017-04-15 13:05:47 +03:00
|
|
|
|
from pathlib import Path
|
2017-03-16 19:08:58 +03:00
|
|
|
|
import sys
|
2017-03-15 19:35:57 +03:00
|
|
|
|
import textwrap
|
2017-05-21 17:05:05 +03:00
|
|
|
|
import random
|
2017-05-29 14:54:18 +03:00
|
|
|
|
import numpy
|
2017-05-29 16:40:45 +03:00
|
|
|
|
import io
|
2017-05-30 01:52:08 +03:00
|
|
|
|
import dill
|
2017-05-31 14:42:39 +03:00
|
|
|
|
from collections import OrderedDict
|
2017-03-15 19:35:57 +03:00
|
|
|
|
|
2017-05-29 11:13:42 +03:00
|
|
|
|
import msgpack
|
|
|
|
|
import msgpack_numpy
|
|
|
|
|
msgpack_numpy.patch()
|
2017-05-29 14:42:55 +03:00
|
|
|
|
import ujson
|
2017-05-29 11:13:42 +03:00
|
|
|
|
|
2017-05-08 16:42:12 +03:00
|
|
|
|
from .symbols import ORTH
|
2017-05-18 15:12:45 +03:00
|
|
|
|
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
|
2017-05-31 23:21:44 +03:00
|
|
|
|
from .compat import copy_array, normalize_string_keys, getattr_
|
2017-03-21 00:48:32 +03:00
|
|
|
|
|
|
|
|
|
|
2016-03-25 20:54:45 +03:00
|
|
|
|
# Registry of Language classes, keyed by language code (filled lazily by
# get_lang_class) or by a custom name (registered via set_lang_class).
LANGUAGES = dict()

# Default spaCy data directory: a "data" folder next to this module. It can
# be replaced at runtime via set_data_path().
_data_path = Path(__file__).parent.joinpath('data')
|
2016-03-25 20:54:45 +03:00
|
|
|
|
|
|
|
|
|
|
2017-05-14 02:31:10 +03:00
|
|
|
|
def get_lang_class(lang):
    """Import and load a Language class.

    The module-level LANGUAGES registry is consulted first. On a miss, the
    module `spacy.lang.<lang>` is imported, and the attribute named by the
    first entry of its `__all__` is cached in LANGUAGES and returned.

    lang (unicode): Two-letter language code, e.g. 'en'.
    RETURNS (Language): Language class.
    RAISES (ImportError): If `spacy.lang.<lang>` can't be imported.
    """
    if lang in LANGUAGES:
        return LANGUAGES[lang]
    try:
        lang_module = importlib.import_module('.lang.%s' % lang, 'spacy')
    except ImportError:
        raise ImportError("Can't import language %s from spacy.lang." % lang)
    LANGUAGES[lang] = getattr(lang_module, lang_module.__all__[0])
    return LANGUAGES[lang]
|
|
|
|
|
|
|
|
|
|
|
2017-05-14 02:31:10 +03:00
|
|
|
|
def set_lang_class(name, cls):
    """Register a custom Language class so get_lang_class(name) returns it.

    An existing registry entry under the same name is replaced.

    name (unicode): Name of Language class.
    cls (Language): Language class.
    """
    LANGUAGES.update({name: cls})
|
2017-05-09 00:50:45 +03:00
|
|
|
|
|
|
|
|
|
|
2017-01-10 01:40:26 +03:00
|
|
|
|
def get_data_path(require_exists=True):
    """Get path to spaCy data directory.

    require_exists (bool): If True, return None when the directory doesn't
        exist on disk. If False, return the configured path unconditionally.
    RETURNS (Path or None): Data path or None.
    """
    if require_exists and not _data_path.exists():
        return None
    return _data_path
|
2016-09-24 21:26:17 +03:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_data_path(path):
    """Point spaCy at a different data directory.

    path (unicode or Path): Path to new data directory. Strings are converted
        to Path via ensure_path; any other value is stored as-is.
    """
    global _data_path
    new_path = ensure_path(path)
    _data_path = new_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_path(path):
    """Convert a string to a Path; pass anything else through unchanged.

    path: Anything. If string, it's converted to Path.
    RETURNS: Path or original argument.
    """
    if not isinstance(path, basestring_):
        return path
    return Path(path)
|
2016-09-24 21:26:17 +03:00
|
|
|
|
|
|
|
|
|
|
2017-05-29 15:10:10 +03:00
|
|
|
|
def load_model(name, **overrides):
    """Load a model from a shortcut link, package or data path.

    String names are resolved in this order: (1) an entry in the spaCy data
    directory (e.g. a shortcut link), (2) an installed package, (3) a path on
    disk. Path-like objects (anything with an `exists` attribute) are treated
    as model data directories.

    name (unicode or Path): Package name, shortcut link or model path.
    **overrides: Specific overrides, like pipeline components to disable.
    RETURNS (Language): `Language` class with the loaded model.
    RAISES (IOError): If the spaCy data path or the model can't be found.
    """
    data_path = get_data_path()
    if not data_path or not data_path.exists():
        raise IOError("Can't find spaCy data path: %s" % path2str(data_path))
    if isinstance(name, basestring_):
        if (data_path / name).exists():  # in data dir or shortcut
            return _load_model_from_link(data_path / name, **overrides)
        if is_package(name):  # installed as package
            return _load_model_from_package(name, **overrides)
        if Path(name).exists():  # path to model data directory
            return _load_model_from_path(Path(name), **overrides)
    elif hasattr(name, 'exists'):  # Path or Path-like to model data
        return _load_model_from_path(name, **overrides)
    raise IOError("Can't find model '%s'" % name)


def _load_model_from_link(link_path, **overrides):
    """Execute the `__init__.py` of a data-dir entry and call its `load()`.

    link_path (Path): Directory (or shortcut link) inside the data path.
    **overrides: Passed through to the package's `load()` function.
    RETURNS (Language): `Language` class with the loaded model.
    RAISES (IOError): If the directory has no `__init__.py`.
    """
    # `import importlib` alone does not guarantee the `util` submodule is
    # loaded, so import it explicitly (locally, as it only exists on Py3).
    import importlib.util
    init_file = link_path / '__init__.py'
    if not init_file.exists():
        raise IOError("Can't find model package in %s" % path2str(link_path))
    # spec_from_file_location needs the module's source file, not the
    # package directory; the loader can't execute a directory.
    spec = importlib.util.spec_from_file_location('model', path2str(init_file))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.load(**overrides)


def _load_model_from_package(name, **overrides):
    """Import an installed model package by name and call its `load()`."""
    package = importlib.import_module(name)
    return package.load(**overrides)


def _load_model_from_path(model_path, **overrides):
    """Build a Language from a model data directory containing meta.json."""
    meta = get_model_meta(model_path)
    cls = get_lang_class(meta['lang'])
    nlp = cls(pipeline=meta.get('pipeline', True), meta=meta)
    return nlp.from_disk(model_path, **overrides)
|
|
|
|
|
|
|
|
|
|
|
2017-05-29 15:10:10 +03:00
|
|
|
|
def load_model_from_init_py(init_file, **overrides):
    """Helper function to use in the `load()` method of a model package's
    __init__.py.

    The model data is expected in a subdirectory named
    `<lang>_<name>-<version>` (taken from the package's meta.json) next to
    the __init__.py.

    init_file (unicode): Path to model's __init__.py, i.e. `__file__`.
    **overrides: Specific overrides, like pipeline components to disable.
    RETURNS (Language): `Language` class with loaded model.
    RAISES (ValueError): If the model data directory doesn't exist.
    """
    model_path = Path(init_file).parent
    meta = get_model_meta(model_path)
    data_dir = '%s_%s-%s' % (meta['lang'], meta['name'], meta['version'])
    data_path = model_path / data_dir
    # get_model_meta already raised if model_path were missing, so the
    # meaningful check here is on the data directory we're about to load.
    if not data_path.exists():
        raise ValueError("Can't find model directory: %s" % path2str(data_path))
    cls = get_lang_class(meta['lang'])
    nlp = cls(pipeline=meta.get('pipeline', True), meta=meta)
    return nlp.from_disk(data_path, **overrides)
|
2017-05-28 01:22:00 +03:00
|
|
|
|
|
|
|
|
|
|
2017-05-29 15:10:10 +03:00
|
|
|
|
def get_model_meta(path):
    """Get model meta.json from a directory path and validate its contents.

    The meta data must contain the keys 'lang', 'name' and 'version'.

    path (unicode or Path): Path to model directory.
    RETURNS (dict): The model's meta data.
    RAISES (ValueError): If the model directory doesn't exist.
    RAISES (IOError): If meta.json is missing or lacks a required key.
    """
    model_path = ensure_path(path)
    if not model_path.exists():
        raise ValueError("Can't find model directory: %s" % path2str(model_path))
    meta_path = model_path / 'meta.json'
    if not meta_path.is_file():
        raise IOError("Could not read meta.json from %s" % meta_path)
    meta = read_json(meta_path)
    # Report the first missing key, checked in this fixed order.
    missing = [key for key in ('lang', 'name', 'version') if key not in meta]
    if missing:
        raise IOError('No %s setting found in model meta.json' % missing[0])
    return meta
|
2017-05-28 01:22:00 +03:00
|
|
|
|
|
|
|
|
|
|
2017-05-13 22:22:49 +03:00
|
|
|
|
def is_package(name):
    """Check if string maps to a package installed via pip.

    Installed project names are compared after replacing dashes with
    underscores, so a project 'en-core-web-sm' matches 'en_core_web_sm'.
    The comparison is case-sensitive.

    name (unicode): Name of package.
    RETURNS (bool): True if installed package, False if not.
    """
    # NOTE(review): pip.get_installed_distributions() is not a public pip API
    # and was removed in pip 10; consider pkg_resources.working_set instead.
    installed = pip.get_installed_distributions()
    return any(dist.project_name.replace('-', '_') == name for dist in installed)
|
|
|
|
|
|
|
|
|
|
|
2017-05-28 01:22:00 +03:00
|
|
|
|
def get_package_path(name):
    """Get the path to an installed package.

    name (unicode): Package name.
    RETURNS (Path): Directory containing the package's `__file__`.
    """
    # Importing the module just to locate it is indirect, but otherwise it's
    # very difficult to find the package on disk.
    module = importlib.import_module(name)
    module_file = Path(module.__file__)
    return module_file.parent
|
2017-05-09 00:51:15 +03:00
|
|
|
|
|
|
|
|
|
|
2017-05-18 15:13:14 +03:00
|
|
|
|
def is_in_jupyter():
    """Check if user is running spaCy from a Jupyter notebook by detecting the
    IPython kernel. Mainly used for the displaCy visualizer.

    RETURNS (bool): True if in Jupyter, False if not.
    """
    try:
        # get_ipython is only defined when running under IPython.
        cfg = get_ipython().config  # noqa: F821
        return cfg['IPKernelApp']['parent_appname'] == 'ipython-notebook'
    except (NameError, KeyError):
        # NameError: not running under IPython at all.
        # KeyError: an IPython config without IPKernelApp/parent_appname
        # (e.g. presumably a non-kernel shell) - previously this escaped.
        return False
|
|
|
|
|
|
|
|
|
|
|
2017-05-14 01:37:53 +03:00
|
|
|
|
def get_cuda_stream(require=False):
    """Create a new CUDA stream, or return None if CudaStream is unavailable.

    require (bool): Currently ignored.
    RETURNS (CudaStream or None): A fresh stream, or None.
    """
    # TODO: Error and tell to install chainer if not found
    # Requires GPU
    if CudaStream is None:
        return None
    return CudaStream()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_async(stream, numpy_array):
    """Copy a numpy array into a cupy array on `stream`, if cupy is available.

    stream: Stream passed to cupy's `ndarray.set` (unused without cupy).
    numpy_array (numpy.ndarray): Host array to transfer.
    RETURNS: A C-ordered cupy.ndarray with the same shape and dtype, or the
        original `numpy_array` when cupy is None.
    """
    if cupy is not None:
        device_array = cupy.ndarray(numpy_array.shape, order='C',
                                    dtype=numpy_array.dtype)
        device_array.set(numpy_array, stream=stream)
        return device_array
    return numpy_array
|
|
|
|
|
|
2017-05-26 13:37:45 +03:00
|
|
|
|
|
2017-05-21 17:05:05 +03:00
|
|
|
|
def itershuffle(iterable, bufsize=1000):
    """Shuffle an iterator. This works by holding `bufsize` items back
    and yielding them sometime later. Obviously, this is not unbiased –
    but should be good enough for batching. Larger bufsize means less bias.
    From https://gist.github.com/andres-erbsen/1307752

    iterable (iterable): Iterator to shuffle.
    bufsize (int): Items to hold back.
    YIELDS (iterable): The shuffled iterator.
    """
    iterator = iter(iterable)
    buf = []
    while True:
        # Top the buffer up with a random number of items (at most bufsize).
        # next() works on both Py2 and Py3, unlike the Py2-only `.next()`.
        try:
            for _ in range(random.randint(1, bufsize - len(buf))):
                buf.append(next(iterator))
        except StopIteration:
            break
        random.shuffle(buf)
        # Release a random number of items from the shuffled buffer.
        for _ in range(random.randint(1, bufsize)):
            if not buf:
                break
            yield buf.pop()
    # Input exhausted: drain whatever is left. Returning normally (instead of
    # `raise StopIteration`) avoids a RuntimeError under PEP 479.
    random.shuffle(buf)
    while buf:
        yield buf.pop()
|
|
|
|
|
|
2017-05-18 12:36:53 +03:00
|
|
|
|
|
2017-05-31 15:14:11 +03:00
|
|
|
|
# When True, env_opt() prints every option it resolves and where the value
# came from (environment variable or default).
_PRINT_ENV = False


def set_env_log(value):
    """Toggle printing of option values resolved by env_opt().

    value (bool): Whether env_opt() should print resolved values.
    """
    global _PRINT_ENV
    _PRINT_ENV = value
|
|
|
|
|
|
|
|
|
|
|
2017-05-18 12:36:53 +03:00
|
|
|
|
def env_opt(name, default=None):
    """Read a numeric option from the environment, falling back to a default.

    `SPACY_<NAME>` (upper-cased) takes precedence over `<name>` itself.
    Values are converted with float() if `default` is a float, otherwise
    with int().

    name (unicode): Option name.
    default: Value returned if neither variable is set (not converted).
    RETURNS: The converted environment value, or `default`.
    """
    convert = float if type(default) is float else int
    prefixed = 'SPACY_' + name.upper()
    if prefixed in os.environ:
        value = convert(os.environ[prefixed])
        source = "$SPACY_" + name.upper()
    elif name in os.environ:
        value = convert(os.environ[name])
        source = '$' + name
    else:
        if _PRINT_ENV:
            print(name, '=', repr(default), "by default")
        return default
    if _PRINT_ENV:
        print(name, "=", repr(value), "via", source)
    return value
|
2017-05-14 01:37:53 +03:00
|
|
|
|
|
|
|
|
|
|
2016-09-24 21:26:17 +03:00
|
|
|
|
def read_regex(path):
    """Build a regex that matches any line of a file as a literal prefix.

    path (unicode or Path): Text file with one entry per line. Blank or
        whitespace-only lines are ignored; entries are regex-escaped.
    RETURNS (regex object): Compiled `^entry1|^entry2|...` pattern.
    """
    path = ensure_path(path)
    # Decode explicitly as UTF-8 (as read_json does) instead of relying on
    # the platform's default encoding.
    with path.open(encoding='utf8') as file_:
        entries = file_.read().split('\n')
    expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
    return re.compile(expression)
|
|
|
|
|
|
|
|
|
|
|
2016-09-25 15:49:53 +03:00
|
|
|
|
def compile_prefix_regex(entries):
    """Compile tokenizer prefix patterns into one regex anchored at the start.

    entries (list): Prefix patterns. If the literal entry '(' is present, the
        list is treated as deprecated plain-string data and every entry is
        escaped; otherwise entries are used as regex source as-is. Blank
        entries are skipped.
    RETURNS (regex object): Compiled `^p1|^p2|...` pattern.
    """
    deprecated_format = '(' in entries  # Handle deprecated data
    pieces = []
    for piece in entries:
        if not piece.strip():
            continue
        pieces.append(re.escape(piece) if deprecated_format else piece)
    return re.compile('|'.join('^' + piece for piece in pieces))
|
2016-09-24 21:26:17 +03:00
|
|
|
|
|
|
|
|
|
|
2016-09-25 15:49:53 +03:00
|
|
|
|
def compile_suffix_regex(entries):
    """Compile tokenizer suffix patterns into one regex anchored at the end.

    entries (list): Suffix patterns (regex source); blank entries are skipped.
    RETURNS (regex object): Compiled `p1$|p2$|...` pattern.
    """
    pattern = '|'.join('%s$' % piece for piece in entries if piece.strip())
    return re.compile(pattern)
|
|
|
|
|
|
|
|
|
|
|
2016-09-25 15:49:53 +03:00
|
|
|
|
def compile_infix_regex(entries):
    """Compile tokenizer infix patterns into a single alternation regex.

    entries (list): Infix patterns (regex source); blank entries are skipped.
    RETURNS (regex object): Compiled `p1|p2|...` pattern.
    """
    kept = [piece for piece in entries if piece.strip()]
    return re.compile('|'.join(kept))
|
|
|
|
|
|
|
|
|
|
|
2017-05-08 16:42:12 +03:00
|
|
|
|
def update_exc(base_exceptions, *addition_dicts):
    """Update and validate tokenizer exceptions. Will overwrite exceptions.

    Every added entry must consist of token dicts whose ORTH values are
    unicode and concatenate back to the entry's key. After merging, variants
    with a typographic apostrophe (’) are added for keys containing "'".

    base_exceptions (dict): Base exceptions.
    *addition_dicts (dict): Exceptions to add to the base dict, in order.
    RETURNS (dict): Combined tokenizer exceptions.
    RAISES (ValueError): If an added entry has invalid ORTH values.
    """
    merged = dict(base_exceptions)
    for additions in addition_dicts:
        for orth, token_attrs in additions.items():
            if not all(isinstance(attr[ORTH], unicode_) for attr in token_attrs):
                msg = "Invalid value for ORTH in exception: key='%s', orths='%s'"
                raise ValueError(msg % (orth, token_attrs))
            described_orth = ''.join(attr[ORTH] for attr in token_attrs)
            if described_orth != orth:
                raise ValueError("Invalid tokenizer exception: ORTH values "
                                 "combined don't match original string. "
                                 "key='%s', orths='%s'" % (orth, described_orth))
        # Later dicts win on key collisions.
        merged.update(additions)
    return expand_exc(merged, "'", "’")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def expand_exc(excs, search, replace):
    """Find string in tokenizer exceptions, duplicate entry and replace string.
    For example, to add additional versions with typographic apostrophes.

    The input dict is not modified; original entries are kept and the
    rewritten copies are added alongside them.

    excs (dict): Tokenizer exceptions.
    search (unicode): String to find and replace.
    replace (unicode): Replacement.
    RETURNS (dict): Combined tokenizer exceptions.
    """
    def _with_replacement(token):
        # Copy the token dict so the original exception stays untouched.
        fixed = dict(token)
        fixed[ORTH] = fixed[ORTH].replace(search, replace)
        return fixed

    expanded = dict(excs)
    expanded.update({
        key.replace(search, replace): [_with_replacement(tok) for tok in tokens]
        for key, tokens in excs.items() if search in key})
    return expanded
|
2017-05-08 16:42:12 +03:00
|
|
|
|
|
|
|
|
|
|
2015-10-07 11:25:35 +03:00
|
|
|
|
def normalize_slice(length, start, stop, step=None):
    """Clamp slice bounds to a sequence of `length` items.

    Negative indices count from the end, out-of-range bounds are clipped to
    [0, length], and `stop` is never smaller than `start`.

    length (int): Length of the sequence being sliced.
    start (int or None): Start index, or None for 0.
    stop (int or None): Stop index, or None for `length`.
    step (int or None): Only None or 1 are supported.
    RETURNS (tuple): `(start, stop)` with `0 <= start <= stop <= length`.
    RAISES (ValueError): If a step other than None or 1 is given.
    """
    if not (step is None or step == 1):
        # The trailing space keeps the two sentences from running together.
        raise ValueError("Stepped slices not supported in Span objects. "
                         "Try: list(tokens)[start:stop:step] instead.")
    if start is None:
        start = 0
    elif start < 0:
        start += length
    start = min(length, max(0, start))
    if stop is None:
        stop = length
    elif stop < 0:
        stop += length
    stop = min(length, max(start, stop))
    assert 0 <= start <= stop <= length
    return start, stop
|
|
|
|
|
|
|
|
|
|
|
2017-05-26 00:16:10 +03:00
|
|
|
|
def compounding(start, stop, compound):
    """Yield an infinite series of compounding values. Each time the
    generator is called, a value is produced by multiplying the previous
    value by the compound rate. Values are clipped at `stop` (from above if
    start <= stop, from below if start > stop).

    EXAMPLE:
      >>> sizes = compounding(1., 10., 1.5)
      >>> assert next(sizes) == 1.
      >>> assert next(sizes) == 1 * 1.5
      >>> assert next(sizes) == 1.5 * 1.5
    """
    # Pick the clipping direction once: approach `stop` from start's side.
    bound = max if start > stop else min
    current = float(start)
    while True:
        yield bound(current, stop)
        current *= compound
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decaying(start, stop, decay):
    """Yield an infinite series of decaying values.

    The t-th value (t = 1, 2, ...) is `start / (1 + decay * t)` (hyperbolic,
    not linear, decay), clipped at `stop` from whichever side `start` is on.
    Note the first value already uses t = 1, so it is below `start` when
    decay > 0.
    """
    bound = max if start > stop else min
    step = 1.
    while True:
        yield bound(start * 1. / (1. + decay * step), stop)
        step += 1
|
|
|
|
|
|
|
|
|
|
|
2017-04-16 14:03:28 +03:00
|
|
|
|
def read_json(location):
    """Open and load JSON from file.

    location (Path): Path to JSON file, read as UTF-8 text.
    RETURNS (dict): Loaded JSON content.
    """
    with location.open('r', encoding='utf8') as file_:
        content = ujson.load(file_)
    return content
|
|
|
|
|
|
|
|
|
|
|
2017-03-21 00:48:56 +03:00
|
|
|
|
def get_raw_input(description, default=False):
    """Get user input from the command line via raw_input / input.

    description (unicode): Text to display before prompt.
    default (unicode or False/None): Default value to display with prompt.
        Only shown if truthy; it is not substituted for empty input.
    RETURNS (unicode): User input.
    """
    suffix = (' (default: %s)' % default) if default else ''
    return input_(' %s%s: ' % (description, suffix))
|
|
|
|
|
|
|
|
|
|
|
2017-05-29 11:13:42 +03:00
|
|
|
|
def to_bytes(getters, exclude):
    """Serialize the outputs of `getters` to a msgpack bytestring.

    getters (dict): Maps keys to zero-argument callables producing the data.
    exclude (container): Keys to skip.
    RETURNS (bytes): msgpack-encoded mapping of key -> getter output, in the
        iteration order of `getters`.
    """
    serialized = OrderedDict(
        (key, getter()) for key, getter in getters.items() if key not in exclude)
    return msgpack.dumps(serialized)
|
2017-05-29 11:13:42 +03:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def from_bytes(bytes_data, setters, exclude):
    """Deserialize a msgpack bytestring and feed each entry to its setter.

    bytes_data (bytes): Data produced by to_bytes.
    setters (dict): Maps keys to callables receiving the deserialized value.
    exclude (container): Keys to skip.
    RETURNS (dict): The full deserialized message.
    """
    msg = msgpack.loads(bytes_data)
    for key, setter in setters.items():
        if key in exclude:
            continue
        setter(msg[key])
    return msg
|
|
|
|
|
|
|
|
|
|
|
2017-05-31 14:42:39 +03:00
|
|
|
|
def to_disk(path, writers, exclude):
    """Write each component to `path / key` using its writer.

    path (unicode or Path): Output directory; created (non-recursively) if
        missing.
    writers (dict): Maps keys to callables that receive a Path to write to.
    exclude (container): Keys to skip.
    RETURNS (Path): The output directory.
    """
    out_dir = ensure_path(path)
    if not out_dir.exists():
        out_dir.mkdir()
    for key, writer in writers.items():
        if key in exclude:
            continue
        writer(out_dir / key)
    return out_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def from_disk(path, readers, exclude):
    """Read each component from `path / key` using its reader.

    path (unicode or Path): Input directory.
    readers (dict): Maps keys to callables that receive a Path to read from.
    exclude (container): Keys to skip.
    RETURNS (Path): The input directory.
    """
    in_dir = ensure_path(path)
    for key, reader in readers.items():
        if key in exclude:
            continue
        reader(in_dir / key)
    return in_dir
|
|
|
|
|
|
|
|
|
|
|
2017-05-30 01:52:08 +03:00
|
|
|
|
# This stuff really belongs in thinc -- but I expect
# to refactor how all this works in thinc anyway.
# What a mess!
def model_to_bytes(model):
    """Serialize the weights of a model tree to a msgpack bytestring.

    Layers are visited breadth-first starting at `model` and following each
    layer's `_layers`. Every layer with a `_mem` attribute contributes one
    entry with its dims, optional seed and parameters (entries with
    `row == 1` in `_mem._offsets` are skipped).

    model: The (thinc) model to serialize.
    RETURNS (bytes): msgpack-encoded `{'weights': [...]}`.
    """
    weights = []
    queue = [model]
    for layer in queue:  # `queue` grows while iterating -> breadth-first
        if hasattr(layer, '_mem'):
            entry = {
                'dims': normalize_string_keys(getattr(layer, '_dims', {})),
                'params': []}
            weights.append(entry)
            if hasattr(layer, 'seed'):
                entry['seed'] = layer.seed
            for (id_, name), (start, row, shape) in layer._mem._offsets.items():
                if row == 1:
                    continue
                param = layer._mem.get((id_, name))
                if not isinstance(layer._mem.weights, numpy.ndarray):
                    # Non-numpy (presumably GPU/cupy) storage: copy to host.
                    param = param.get()
                entry['params'].append({
                    'name': name,
                    'offset': start,
                    'shape': shape,
                    'value': param,
                })
        if hasattr(layer, '_layers'):
            queue.extend(layer._layers)
    return msgpack.dumps({'weights': weights})
|
2017-05-29 11:13:42 +03:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def model_from_bytes(model, bytes_data):
    """Load weights produced by model_to_bytes into `model` in place.

    The model tree is walked in the same breadth-first order as
    model_to_bytes, so the n-th layer with a `_mem` attribute receives the
    n-th serialized entry.

    model: The (thinc) model to update.
    bytes_data (bytes): Output of model_to_bytes.
    """
    weights = msgpack.loads(bytes_data)['weights']
    queue = [model]
    entry_index = 0
    for layer in queue:
        if hasattr(layer, '_mem'):
            entry = weights[entry_index]
            if 'seed' in entry:
                layer.seed = entry['seed']
            for dim, value in entry['dims'].items():
                setattr(layer, dim, value)
            for param in entry['params']:
                dest = getattr_(layer, param['name'])
                copy_array(dest, param['value'])
            entry_index += 1
        if hasattr(layer, '_layers'):
            queue.extend(layer._layers)
|
2017-05-29 21:23:11 +03:00
|
|
|
|
|
2017-05-29 02:37:57 +03:00
|
|
|
|
|
2017-05-08 00:25:29 +03:00
|
|
|
|
def print_table(data, title=None):
    """Print data in table format.

    data (dict or list of tuples): Label/value pairs. Rows may have more than
        two columns; the column count is taken from the first row. An empty
        input prints an empty table.
    title (unicode or None): Title, will be printed above.
    """
    if isinstance(data, dict):
        data = list(data.items())
    if data:
        # One left-aligned 15-char column per cell of the first row.
        tpl_row = ' {:<15}' * len(data[0])
        table = '\n'.join([tpl_row.format(*row) for row in data])
    else:
        table = ''
    if title:
        print('\n \033[93m{}\033[0m'.format(title))
    print('\n{}\n'.format(table))
|
2017-03-18 15:00:14 +03:00
|
|
|
|
|
|
|
|
|
|
2017-05-08 00:25:29 +03:00
|
|
|
|
def print_markdown(data, title=None):
    """Print data in GitHub-flavoured Markdown format for issues etc.

    Values that point to an existing file or directory are left out, so
    local paths (personal info) don't end up in the output.

    data (dict or list of tuples): Label/value pairs.
    title (unicode or None): Title, will be rendered as headline 2.
    """
    def excl_value(value):
        # Exclude values that name an existing path. Non-path values (ints,
        # None, ...) make Path() raise TypeError, and unusual strings may
        # make exists() raise ValueError/OSError; such values are kept.
        try:
            return Path(value).exists()  # contains path (personal info)
        except (TypeError, ValueError, OSError):
            return False

    if isinstance(data, dict):
        data = list(data.items())
    markdown = ["* **{}:** {}".format(label, value)
                for label, value in data if not excl_value(value)]
    if title:
        print("\n## {}".format(title))
    print('\n{}\n'.format('\n'.join(markdown)))
|
2017-03-18 15:00:14 +03:00
|
|
|
|
|
|
|
|
|
|
2017-05-08 03:00:37 +03:00
|
|
|
|
def prints(*texts, **kwargs):
    """Print formatted message (manual ANSI escape sequences to avoid dependency)

    *texts (unicode): Texts to print. Each argument is rendered as paragraph.
    **kwargs: 'title' becomes a coloured headline. 'exits', if given and not
        None, is passed to sys.exit() after printing (so exits=True exits with
        status 1, and exits=0 or exits=False exit with status 0).
    """
    title = kwargs.get('title', None)
    exit_code = kwargs.get('exits', None)
    header = '\033[93m{}\033[0m\n'.format(_wrap(title)) if title else ''
    body = '\n\n'.join(_wrap(text) for text in texts)
    print('\n{}{}\n'.format(header, body))
    if exit_code is not None:
        sys.exit(exit_code)
|
2017-03-15 19:35:57 +03:00
|
|
|
|
|
|
|
|
|
|
2017-05-08 00:25:29 +03:00
|
|
|
|
def _wrap(text, wrap_max=80, indent=4):
|
2017-05-13 22:22:49 +03:00
|
|
|
|
"""Wrap text at given width using textwrap module.
|
|
|
|
|
|
2017-05-14 02:30:29 +03:00
|
|
|
|
text (unicode): Text to wrap. If it's a Path, it's converted to string.
|
|
|
|
|
wrap_max (int): Maximum line length (indent is deducted).
|
|
|
|
|
indent (int): Number of spaces for indentation.
|
|
|
|
|
RETURNS (unicode): Wrapped text.
|
2017-04-16 14:42:34 +03:00
|
|
|
|
"""
|
2017-05-08 00:25:29 +03:00
|
|
|
|
indent = indent * ' '
|
2017-03-15 19:35:57 +03:00
|
|
|
|
wrap_width = wrap_max - len(indent)
|
2017-05-08 00:25:29 +03:00
|
|
|
|
if isinstance(text, Path):
|
|
|
|
|
text = path2str(text)
|
2017-03-15 19:35:57 +03:00
|
|
|
|
return textwrap.fill(text, width=wrap_width, initial_indent=indent,
|
2017-05-08 00:25:29 +03:00
|
|
|
|
subsequent_indent=indent, break_long_words=False,
|
|
|
|
|
break_on_hyphens=False)
|
2017-05-14 18:50:23 +03:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def minify_html(html):
    """Perform a template-specific, rudimentary HTML minification for displaCy.
    Disclaimer: NOT a general-purpose solution, only removes indentation/newlines.

    html (unicode): Markup to minify.
    RETURNS (unicode): "Minified" HTML.
    """
    # Strip only leading indentation on each line and join the lines.
    # Spaces inside a line (between attributes, in text) must survive,
    # otherwise e.g. '<div class="x">' would turn into '<divclass="x">'.
    return ''.join(line.lstrip() for line in html.strip().split('\n'))
|