1
1
mirror of https://github.com/explosion/spaCy.git synced 2025-01-18 05:24:12 +03:00
spaCy/spacy/util.py

626 lines
20 KiB
Python
Raw Normal View History

2017-03-12 15:07:28 +03:00
# coding: utf8
from __future__ import unicode_literals, print_function
import os
2017-04-15 13:13:34 +03:00
import ujson
import pkg_resources
import importlib
2017-04-20 02:22:52 +03:00
import regex as re
from pathlib import Path
2017-03-16 19:08:58 +03:00
import sys
import textwrap
import random
2017-05-31 14:42:39 +03:00
from collections import OrderedDict
import inspect
import warnings
2017-09-21 03:16:35 +03:00
from thinc.neural._classes.model import Model
import functools
2017-11-07 15:20:12 +03:00
import cytoolz
2017-11-10 21:05:18 +03:00
import itertools
2017-10-27 15:39:09 +03:00
from .symbols import ORTH
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
from .compat import import_file
2017-05-29 11:13:42 +03:00
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
2016-03-25 20:54:45 +03:00
# Registry mapping a language code / custom name to its Language class.
# Filled lazily by get_lang_class() and explicitly by set_lang_class().
LANGUAGES = {}
# Current spaCy data directory; starts as the 'data' folder next to this
# module and can be replaced via set_data_path().
_data_path = Path(__file__).parent / 'data'
2017-10-27 15:39:09 +03:00
# When truthy, env_opt() prints where each option value came from.
_PRINT_ENV = False


def set_env_log(value):
    """Switch printing of environment option lookups on or off.

    value: New value for the module-level `_PRINT_ENV` flag.
    """
    global _PRINT_ENV
    _PRINT_ENV = value
2016-03-25 20:54:45 +03:00
def get_lang_class(lang):
    """Import and load a Language class, caching it in `LANGUAGES`.

    lang (unicode): Two-letter language code, e.g. 'en'.
    RETURNS (Language): Language class.
    """
    if lang in LANGUAGES:
        return LANGUAGES[lang]
    try:
        lang_module = importlib.import_module('.lang.%s' % lang, 'spacy')
    except ImportError:
        raise ImportError("Can't import language %s from spacy.lang." % lang)
    # The first name listed in the module's __all__ is used as the class.
    LANGUAGES[lang] = getattr(lang_module, lang_module.__all__[0])
    return LANGUAGES[lang]
def set_lang_class(name, cls):
    """Register a custom Language class so get_lang_class can return it.

    name (unicode): Name of Language class.
    cls (Language): Language class.
    """
    # Item assignment mutates the registry in place; no rebinding needed.
    LANGUAGES[name] = cls
2017-05-09 00:50:45 +03:00
2017-01-10 01:40:26 +03:00
def get_data_path(require_exists=True):
    """Get path to spaCy data directory.

    require_exists (bool): Only return path if it exists, otherwise None.
    RETURNS (Path or None): Data path or None.
    """
    if require_exists and not _data_path.exists():
        return None
    return _data_path
2016-09-24 21:26:17 +03:00
def set_data_path(path):
    """Set path to spaCy data directory.

    path (unicode or Path): Path to new data directory. Strings are
        converted to Path via ensure_path.
    """
    global _data_path
    new_path = ensure_path(path)
    _data_path = new_path
def ensure_path(path):
    """Ensure string is converted to a Path.

    path: Anything. If string, it's converted to Path.
    RETURNS: Path or original argument.
    """
    return Path(path) if isinstance(path, basestring_) else path
2016-09-24 21:26:17 +03:00
def load_model(name, **overrides):
    """Load a model from a shortcut link, package or data path.

    name (unicode): Package name, shortcut link or model path.
    **overrides: Specific overrides, like pipeline components to disable.
    RETURNS (Language): `Language` class with the loaded model.
    RAISES (IOError): If the data path or the model can't be found.
    """
    # Ask for the configured path even if it is missing, so that the error
    # below can name the actual location instead of printing "None".
    data_path = get_data_path(require_exists=False)
    if not data_path or not data_path.exists():
        raise IOError("Can't find spaCy data path: %s" % path2str(data_path))
    if isinstance(name, basestring_):
        # Resolution order: data dir / shortcut link, installed package,
        # then a plain filesystem path to a model data directory.
        if any(entry.name == name for entry in data_path.iterdir()):
            return load_model_from_link(name, **overrides)
        if is_package(name):
            return load_model_from_package(name, **overrides)
        if Path(name).exists():
            return load_model_from_path(Path(name), **overrides)
    elif hasattr(name, 'exists'):  # Path or Path-like to model data
        return load_model_from_path(name, **overrides)
    raise IOError("Can't find model '%s'" % name)
2017-06-05 14:02:31 +03:00
def load_model_from_link(name, **overrides):
    """Load a model from a shortcut link, or directory in spaCy data path.

    name (unicode): Name of the link or directory inside the data path.
    **overrides: Specific overrides, passed on to the package's `load()`.
    RETURNS: Whatever the imported package's `load()` returns.
    RAISES (IOError): If importing the link's __init__.py fails with an
        AttributeError.
    """
    path = get_data_path() / name / '__init__.py'
    try:
        cls = import_file(name, path)
    except AttributeError:
        raise IOError(
            "Can't load '%s'. If you're using a shortcut link, make sure it "
            "points to a valid package (not just a data directory)." % name)
    return cls.load(**overrides)
def load_model_from_package(name, **overrides):
    """Load a model from an installed package.

    name (unicode): Importable name of the model package.
    **overrides: Specific overrides, passed on to the package's `load()`.
    RETURNS: Whatever the package's `load()` returns.
    """
    model_package = importlib.import_module(name)
    return model_package.load(**overrides)
def load_model_from_path(model_path, meta=False, **overrides):
    """Load a model from a data directory path. Creates Language class with
    pipeline from meta.json and then calls from_disk() with path.

    model_path (Path): Model data directory.
    meta (dict or False): Model meta data; read from the directory if falsy.
    **overrides: Passed to the Language class; 'disable' lists pipeline
        component names to skip.
    RETURNS: Result of `nlp.from_disk(model_path)`.
    """
    meta = meta or get_model_meta(model_path)
    lang_cls = get_lang_class(meta['lang'])
    nlp = lang_cls(meta=meta, **overrides)
    pipe_names = meta.get('pipeline', [])
    disabled = overrides.get('disable', [])
    if pipe_names is True:
        # `True` means: use the language's default pipeline.
        pipe_names = nlp.Defaults.pipe_names
    elif pipe_names in (False, None):
        pipe_names = []
    for pipe_name in pipe_names:
        if pipe_name in disabled:
            continue
        pipe_config = meta.get('pipeline_args', {}).get(pipe_name, {})
        nlp.add_pipe(nlp.create_pipe(pipe_name, config=pipe_config),
                     name=pipe_name)
    return nlp.from_disk(model_path)
def load_model_from_init_py(init_file, **overrides):
    """Helper function to use in the `load()` method of a model package's
    __init__.py.

    init_file (unicode): Path to model's __init__.py, i.e. `__file__`.
    **overrides: Specific overrides, like pipeline components to disable.
    RETURNS (Language): `Language` class with loaded model.
    RAISES (ValueError): If the versioned data directory inside the package
        doesn't exist.
    """
    model_path = Path(init_file).parent
    meta = get_model_meta(model_path)
    data_dir = '%s_%s-%s' % (meta['lang'], meta['name'], meta['version'])
    data_path = model_path / data_dir
    # get_model_meta() has already verified that model_path exists; the
    # directory that still needs checking (and that the message reports) is
    # the versioned data directory.
    if not data_path.exists():
        msg = "Can't find model directory: %s"
        raise ValueError(msg % path2str(data_path))
    return load_model_from_path(data_path, meta, **overrides)
def get_model_meta(path):
    """Get model meta.json from a directory path and validate its contents.

    path (unicode or Path): Path to model directory.
    RETURNS (dict): The model's meta data.
    RAISES (ValueError): If the directory is missing, or a required setting
        ('lang', 'name', 'version') is absent or empty.
    RAISES (IOError): If the directory has no meta.json file.
    """
    model_path = ensure_path(path)
    if not model_path.exists():
        msg = "Can't find model directory: %s"
        raise ValueError(msg % path2str(model_path))
    meta_path = model_path / 'meta.json'
    if not meta_path.is_file():
        # Use path2str like the message above, instead of formatting the
        # Path object directly.
        msg = "Could not read meta.json from %s"
        raise IOError(msg % path2str(meta_path))
    meta = read_json(meta_path)
    for setting in ('lang', 'name', 'version'):
        if setting not in meta or not meta[setting]:
            msg = "No valid '%s' setting found in model meta.json"
            raise ValueError(msg % setting)
    return meta
def is_package(name):
    """Check if string maps to a package installed via pip.

    name (unicode): Name of package.
    RETURNS (bool): True if installed package, False if not.
    """
    wanted = name.lower()  # compare package name against lowercase name
    installed = pkg_resources.working_set.by_key.keys()
    # Distribution names may use dashes where module names use underscores.
    return any(dist.lower().replace('-', '_') == wanted for dist in installed)
def get_package_path(name):
    """Get the path to an installed package.

    name (unicode): Package name.
    RETURNS (Path): Path to installed package.
    """
    # The package is imported only to locate it on disk: indirect, but
    # finding it any other way is difficult. The lowercase name is used to
    # be safe.
    package = importlib.import_module(name.lower())
    init_file = Path(package.__file__)
    return init_file.parent
2017-05-09 00:51:15 +03:00
def is_in_jupyter():
    """Check if user is running spaCy from a Jupyter notebook by detecting the
    IPython kernel. Mainly used for the displaCy visualizer.

    RETURNS (bool): True if in Jupyter, False if not.
    """
    try:
        # `get_ipython` isn't defined in this module; looking it up raises
        # NameError when no such global/builtin is available.
        kernel_cfg = get_ipython().config['IPKernelApp']
        if kernel_cfg['parent_appname'] == 'ipython-notebook':
            return True
        return False
    except NameError:
        return False
def get_cuda_stream(require=False):
    """Create a CUDA stream if `CudaStream` is available.

    require (bool): Currently unused by this function.
    RETURNS: A new `CudaStream` instance, or None if `CudaStream` is None.
    """
    if CudaStream is None:
        return None
    return CudaStream()
def get_async(stream, numpy_array):
    """Copy an array into a cupy array using the given stream.

    stream: Stream passed to the cupy array's `set()` call.
    numpy_array: Source array; must provide `shape` and `dtype`.
    RETURNS: The input array unchanged if cupy is None, otherwise a new
        cupy array filled from it.
    """
    if cupy is None:
        return numpy_array
    device_array = cupy.ndarray(numpy_array.shape, order='C',
                                dtype=numpy_array.dtype)
    device_array.set(numpy_array, stream=stream)
    return device_array
2017-05-26 13:37:45 +03:00
def env_opt(name, default=None):
    """Read a numeric option from the environment.

    `$SPACY_<NAME>` (upper-cased name) takes precedence over `$<name>`; if
    neither is set, the default is returned unconverted.

    name (unicode): Option name.
    default: Fallback value. A float default makes the environment value be
        parsed as float, anything else as int.
    RETURNS: The converted environment value, or `default`.
    """
    type_convert = float if type(default) is float else int
    prefixed = 'SPACY_' + name.upper()
    # (environment key, label shown in the log line), in lookup order.
    candidates = ((prefixed, "$SPACY_" + name.upper()), (name, '$' + name))
    for env_key, label in candidates:
        if env_key in os.environ:
            value = type_convert(os.environ[env_key])
            if _PRINT_ENV:
                print(name, "=", repr(value), "via", label)
            return value
    if _PRINT_ENV:
        print(name, '=', repr(default), "by default")
    return default
2016-09-24 21:26:17 +03:00
def read_regex(path):
    """Compile a prefix regex from a file with one literal entry per line.

    path (unicode or Path): File to read.
    RETURNS: Compiled pattern matching any non-blank line's text, escaped
        and anchored with '^'.
    """
    path = ensure_path(path)
    with path.open() as file_:
        lines = file_.read().split('\n')
    alternatives = ['^' + re.escape(line) for line in lines if line.strip()]
    return re.compile('|'.join(alternatives))
def compile_prefix_regex(entries):
    """Compile a regex matching any of the given prefixes at string start.

    entries: Prefix patterns. If the literal entry '(' is present, the
        entries are treated as deprecated plain-text data and escaped.
    RETURNS: Compiled pattern of '^'-anchored alternatives; blank entries
        are skipped.
    """
    if '(' in entries:
        # Handle deprecated data
        prepare = re.escape
    else:
        def prepare(piece):
            return piece
    anchored = ['^' + prepare(piece) for piece in entries if piece.strip()]
    return re.compile('|'.join(anchored))
2016-09-24 21:26:17 +03:00
def compile_suffix_regex(entries):
    """Compile a regex matching any of the given suffixes at string end.

    entries: Suffix patterns; blank entries are skipped.
    RETURNS: Compiled pattern of '$'-anchored alternatives.
    """
    anchored = []
    for piece in entries:
        if piece.strip():
            anchored.append(piece + '$')
    return re.compile('|'.join(anchored))
def compile_infix_regex(entries):
    """Compile a regex matching any of the given infix patterns.

    entries: Infix patterns; blank entries are skipped.
    RETURNS: Compiled pattern of the alternatives joined with '|'.
    """
    kept = (piece for piece in entries if piece.strip())
    return re.compile('|'.join(kept))
2017-06-03 20:44:47 +03:00
def add_lookups(default_func, *lookups):
    """Extend an attribute function with special cases. If a word is in the
    lookups, the value is returned. Otherwise the previous function is used.

    default_func (callable): The default function to execute.
    *lookups (dict): Lookup dictionary mapping string to attribute value.
    RETURNS (callable): Lexical attribute getter.
    """
    # A functools.partial over a module-level function is used rather than
    # a closure so that the returned getter can be pickled.
    getter = functools.partial(_get_attr_unless_lookup, default_func, lookups)
    return getter
def _get_attr_unless_lookup(default_func, lookups, string):
    """Return the value from the first lookup containing `string`, falling
    back to `default_func(string)` when none does.
    """
    matches = (table[string] for table in lookups if string in table)
    for value in matches:
        return value
    return default_func(string)
2017-06-03 20:44:47 +03:00
def update_exc(base_exceptions, *addition_dicts):
    """Update and validate tokenizer exceptions. Will overwrite exceptions.

    base_exceptions (dict): Base exceptions.
    *addition_dicts (dict): Exceptions to add to the base dict, in order.
    RETURNS (dict): Combined tokenizer exceptions.
    RAISES (ValueError): If an ORTH value isn't unicode, or the ORTH values
        of an exception don't join up to its key.
    """
    exc = dict(base_exceptions)
    for additions in addition_dicts:
        for orth, token_attrs in additions.items():
            if not all(isinstance(attr[ORTH], unicode_)
                       for attr in token_attrs):
                msg = "Invalid ORTH value in exception: key='%s', orths='%s'"
                raise ValueError(msg % (orth, token_attrs))
            described_orth = ''.join(attr[ORTH] for attr in token_attrs)
            if orth != described_orth:
                msg = ("Invalid tokenizer exception: ORTH values combined "
                       "don't match original string. key='%s', orths='%s'")
                raise ValueError(msg % (orth, described_orth))
        exc.update(additions)
    # Add a variant of every exception that uses the typographic apostrophe
    # (U+2019) instead of the ASCII one. An empty replacement string here
    # would strip apostrophes instead of adding those variants.
    exc = expand_exc(exc, "'", "\u2019")
    return exc
def expand_exc(excs, search, replace):
    """Find string in tokenizer exceptions, duplicate entry and replace string.
    For example, to add additional versions with typographic apostrophes.

    excs (dict): Tokenizer exceptions.
    search (unicode): String to find and replace.
    replace (unicode): Replacement.
    RETURNS (dict): Combined tokenizer exceptions.
    """
    expanded = dict(excs)
    for token_string, tokens in excs.items():
        if search not in token_string:
            continue
        variants = []
        for token in tokens:
            # Copy each token dict so the original entry is left untouched.
            variant = dict(token)
            variant[ORTH] = variant[ORTH].replace(search, replace)
            variants.append(variant)
        expanded[token_string.replace(search, replace)] = variants
    return expanded
def normalize_slice(length, start, stop, step=None):
    """Clamp slice bounds to a sequence of the given length.

    length (int): Length of the sequence being sliced.
    start (int or None): Slice start; None means 0, negatives count from
        the end.
    stop (int or None): Slice stop; None means `length`, negatives count
        from the end.
    step (int or None): Only None or 1 is supported.
    RETURNS (tuple): `(start, stop)` with 0 <= start <= stop <= length.
    RAISES (ValueError): If a step other than 1 is given.
    """
    if not (step is None or step == 1):
        raise ValueError("Stepped slices not supported in Span objects. "
                         "Try: list(tokens)[start:stop:step] instead.")
    if start is None:
        start = 0
    elif start < 0:
        start += length
    start = min(length, max(0, start))
    if stop is None:
        stop = length
    elif stop < 0:
        stop += length
    # Clamping against `start` makes an inverted range collapse to empty.
    stop = min(length, max(start, stop))
    assert 0 <= start <= stop <= length
    return start, stop
2017-11-07 01:45:36 +03:00
def minibatch(items, size=8):
    """Iterate over batches of items. `size` may be an iterator,
    so that batch-size can vary on each step.

    items (iterable): Items to batch up.
    size (int or iterable): Fixed batch size, or an iterable yielding the
        size to use for each successive batch.
    YIELDS (list): The next batch. Iteration stops when the items run out,
        when the sizes run out, or when a batch comes out empty.
    """
    if isinstance(size, int):
        sizes = itertools.repeat(size)
    else:
        # iter() lets plain iterables (e.g. lists) work as well as iterators.
        sizes = iter(size)
    items = iter(items)
    while True:
        try:
            batch_size = next(sizes)
        except StopIteration:
            # A StopIteration escaping a generator body is turned into a
            # RuntimeError by PEP 479, so end the generator explicitly.
            return
        batch = list(itertools.islice(items, int(batch_size)))
        if not batch:
            return
        yield batch
2017-05-26 00:16:10 +03:00
def compounding(start, stop, compound):
    """Yield an infinite series of compounding values. Each time the
    generator is called, a value is produced by multiplying the previous
    value by the compound rate. Values are clipped so they never pass `stop`.

    EXAMPLE:
        >>> sizes = compounding(1., 10., 1.5)
        >>> assert next(sizes) == 1.
        >>> assert next(sizes) == 1 * 1.5
        >>> assert next(sizes) == 1.5 * 1.5
    """
    # A falling series is bounded from below by `stop`, a rising one from
    # above.
    falling = start > stop
    value = float(start)
    while True:
        if falling:
            yield max(value, stop)
        else:
            yield min(value, stop)
        value *= compound
def decaying(start, stop, decay):
    """Yield an infinite series of decaying values.

    The n-th value (n starting at 1) is `start / (1 + decay * n)`, clipped
    so it never passes `stop`.
    """
    # Clip from below when heading downwards, from above otherwise.
    bound = max if start > stop else min
    step = 1.
    while True:
        yield bound(start * 1. / (1. + decay * step), stop)
        step += 1
2017-11-07 01:45:36 +03:00
def itershuffle(iterable, bufsize=1000):
    """Shuffle an iterator. This works by holding `bufsize` items back
    and yielding them sometime later. Obviously, this is not unbiased
    but should be good enough for batching. Larger bufsize means less bias.
    From https://gist.github.com/andres-erbsen/1307752

    iterable (iterable): Iterator to shuffle.
    bufsize (int): Items to hold back.
    YIELDS (iterable): The shuffled iterator.
    """
    iterable = iter(iterable)
    buf = []
    try:
        while True:
            # Top the buffer up by a random amount, never beyond bufsize.
            # The builtin next() works on both Python 2 and 3.
            for _ in range(random.randint(1, bufsize - len(buf))):
                buf.append(next(iterable))
            random.shuffle(buf)
            # Release a random number of the buffered items.
            for _ in range(random.randint(1, bufsize)):
                if buf:
                    yield buf.pop()
                else:
                    break
    except StopIteration:
        # Source exhausted: flush what is left. The generator then simply
        # ends; re-raising StopIteration here would become a RuntimeError
        # under PEP 479.
        random.shuffle(buf)
        while buf:
            yield buf.pop()
def read_json(location):
    """Open and load JSON from file.

    location (Path): Path to JSON file.
    RETURNS (dict): Loaded JSON content.
    """
    json_path = ensure_path(location)
    with json_path.open('r', encoding='utf8') as file_:
        content = ujson.load(file_)
    return content
def get_raw_input(description, default=False):
    """Get user input from the command line via raw_input / input.

    description (unicode): Text to display before prompt.
    default (unicode or False/None): Default value to display with prompt.
    RETURNS (unicode): User input.
    """
    if default:
        additional = ' (default: %s)' % default
    else:
        additional = ''
    return input_(' %s%s: ' % (description, additional))
2017-05-29 11:13:42 +03:00
def to_bytes(getters, exclude):
    """Serialize the values produced by `getters` into msgpack bytes.

    getters (dict): Maps each key to a zero-argument callable returning the
        value to store.
    exclude: Container of keys to leave out.
    RETURNS: The msgpack-encoded data.
    """
    serialized = OrderedDict(
        (key, getter()) for key, getter in getters.items()
        if key not in exclude)
    return msgpack.dumps(serialized, use_bin_type=True, encoding='utf8')
2017-05-29 11:13:42 +03:00
def from_bytes(bytes_data, setters, exclude):
    """Decode msgpack bytes and hand each stored value to its setter.

    bytes_data: The msgpack-encoded data.
    setters (dict): Maps each key to a callable taking the decoded value.
    exclude: Container of keys to skip.
    RETURNS: The decoded message.
    """
    msg = msgpack.loads(bytes_data, encoding='utf8')
    for key, setter in setters.items():
        # Skip excluded keys, and keys that weren't serialized.
        if key in exclude or key not in msg:
            continue
        setter(msg[key])
    return msg
2017-05-31 14:42:39 +03:00
def to_disk(path, writers, exclude):
    """Write data to a directory, one writer per key.

    path (unicode or Path): Target directory; created if it doesn't exist.
    writers (dict): Maps each key to a callable taking the path `path / key`.
    exclude: Container of keys to skip.
    RETURNS (Path): The target directory.
    """
    target_dir = ensure_path(path)
    if not target_dir.exists():
        target_dir.mkdir()
    for key, writer in writers.items():
        if key in exclude:
            continue
        writer(target_dir / key)
    return target_dir
def from_disk(path, readers, exclude):
    """Read data from a directory, one reader per key.

    path (unicode or Path): Source directory.
    readers (dict): Maps each key to a callable taking the path `path / key`.
    exclude: Container of keys to skip.
    RETURNS (Path): The source directory.
    """
    source_dir = ensure_path(path)
    for key, reader in readers.items():
        if key in exclude:
            continue
        reader(source_dir / key)
    return source_dir
def deprecated(message, filter='always'):
    """Show a deprecation warning.

    message (unicode): The message to display.
    filter (unicode): Filter value.
    """
    # The warning is attributed to the last entry of the call stack.
    outer_frame = inspect.stack()[-1]
    filename, lineno = outer_frame[1], outer_frame[2]
    with warnings.catch_warnings():
        warnings.simplefilter(filter, DeprecationWarning)
        warnings.warn_explicit(message, DeprecationWarning, filename, lineno)
2017-05-08 00:25:29 +03:00
def print_table(data, title=None):
    """Print data in table format.

    data (dict or list of tuples): Label/value pairs.
    title (unicode or None): Title, will be printed above.
    """
    rows = list(data.items()) if isinstance(data, dict) else data
    # One left-aligned, 15-character column per field of the first row.
    row_template = ' {:<15}' * len(rows[0])
    lines = [row_template.format(label, unicode_(value))
             for label, value in rows]
    if title:
        print('\n \033[93m{}\033[0m'.format(title))
    print('\n{}\n'.format('\n'.join(lines)))
2017-05-08 00:25:29 +03:00
def print_markdown(data, title=None):
    """Print data in GitHub-flavoured Markdown format for issues etc.

    data (dict or list of tuples): Label/value pairs.
    title (unicode or None): Title, will be rendered as headline 2.
    """
    def excl_value(value):
        # contains path, i.e. personal info
        return isinstance(value, basestring_) and Path(value).exists()

    pairs = list(data.items()) if isinstance(data, dict) else data
    bullets = []
    for label, value in pairs:
        if excl_value(value):
            continue
        bullets.append("* **{}:** {}".format(label, unicode_(value)))
    if title:
        print("\n## {}".format(title))
    print('\n{}\n'.format('\n'.join(bullets)))
2017-05-08 03:00:37 +03:00
def prints(*texts, **kwargs):
    """Print formatted message (manual ANSI escape sequences to avoid
    dependency)

    *texts (unicode): Texts to print. Each argument is rendered as paragraph.
    **kwargs: 'title' becomes coloured headline. 'exits', if not None, is
        passed to sys.exit() after printing.
    """
    exit_code = kwargs.get('exits', None)
    title = kwargs.get('title', None)
    if title:
        headline = '\033[93m{}\033[0m\n'.format(_wrap(title))
    else:
        headline = ''
    paragraphs = [_wrap(text) for text in texts]
    print('\n{}{}\n'.format(headline, '\n\n'.join(paragraphs)))
    if exit_code is not None:
        sys.exit(exit_code)
2017-05-08 00:25:29 +03:00
def _wrap(text, wrap_max=80, indent=4):
    """Wrap text at given width using textwrap module.

    text (unicode): Text to wrap. If it's a Path, it's converted to string.
    wrap_max (int): Maximum line length (indent is deducted).
    indent (int): Number of spaces for indentation.
    RETURNS (unicode): Wrapped text.
    """
    padding = ' ' * indent
    if isinstance(text, Path):
        text = path2str(text)
    wrapper = textwrap.TextWrapper(
        width=wrap_max - len(padding),
        initial_indent=padding,
        subsequent_indent=padding,
        break_long_words=False,
        break_on_hyphens=False)
    return wrapper.fill(text)
2017-05-14 18:50:23 +03:00
def minify_html(html):
    """Perform a template-specific, rudimentary HTML minification for displaCy.
    Disclaimer: NOT a general-purpose solution, only removes indentation and
    newlines.

    html (unicode): Markup to minify.
    RETURNS (unicode): "Minified" HTML.
    """
    # Only drop four-space indentation units. Removing every single space
    # would also delete the spaces separating tag names, attributes and
    # words, corrupting the markup.
    return html.strip().replace('    ', '').replace('\n', '')
2017-09-21 03:16:35 +03:00
def use_gpu(gpu_id):
    """Switch `Model` over to cupy-backed ops on the given device.

    gpu_id: Device identifier passed to `cupy.cuda.device.Device`.
    RETURNS: The activated device, or None if cupy can't be imported.
    """
    try:
        import cupy.cuda.device
    except ImportError:
        return None
    from thinc.neural.ops import CupyOps
    gpu_device = cupy.cuda.device.Device(gpu_id)
    gpu_device.use()
    # Set both the ops instance and the ops class on Model.
    Model.ops = CupyOps()
    Model.Ops = CupyOps
    return gpu_device