spaCy/spacy/util.py

362 lines
12 KiB
Python
Raw Normal View History

2017-03-12 15:07:28 +03:00
# coding: utf8
from __future__ import unicode_literals, print_function
2017-04-15 13:13:34 +03:00
import ujson
import pip
import importlib
2017-04-20 02:22:52 +03:00
import regex as re
from pathlib import Path
2017-03-16 19:08:58 +03:00
import sys
import textwrap
from .symbols import ORTH
from .compat import path2str, basestring_, input_, unicode_
2016-03-25 20:54:45 +03:00
# Registry of Language classes, filled by get_lang_class / set_lang_class.
LANGUAGES = {}
# Default data directory: the 'data' folder next to this module.
_data_path = Path(__file__).parent / 'data'

# cupy is optional: each name falls back to None when its import fails,
# and the helpers that use them check for None first.
try:
    import cupy
except ImportError:
    cupy = None

try:
    from cupy.cuda.stream import Stream as CudaStream
except ImportError:
    CudaStream = None
2016-03-25 20:54:45 +03:00
def get_lang_class(lang):
    """Look up a Language class by code, importing it on first use.

    Classes are cached in the module-level LANGUAGES dict, so the matching
    spacy.lang submodule is only imported when a code is not cached yet.

    lang (unicode): Two-letter language code, e.g. 'en'.
    RETURNS (Language): Language class.
    """
    if lang in LANGUAGES:
        return LANGUAGES[lang]
    try:
        lang_module = importlib.import_module('.lang.%s' % lang, 'spacy')
    except ImportError:
        raise ImportError("Can't import language %s from spacy.lang." %lang)
    # The class is taken from the first name listed in the module's __all__.
    lang_cls = getattr(lang_module, lang_module.__all__[0])
    LANGUAGES[lang] = lang_cls
    return lang_cls
def set_lang_class(name, cls):
    """Register a custom Language class so get_lang_class can return it.

    An existing entry under the same name is replaced.

    name (unicode): Name of Language class.
    cls (Language): Language class.
    """
    LANGUAGES.update({name: cls})
2017-05-09 00:50:45 +03:00
2017-01-10 01:40:26 +03:00
def get_data_path(require_exists=True):
    """Return the currently configured spaCy data directory.

    require_exists (bool): Only return path if it exists, otherwise None.
    RETURNS (Path or None): Data path or None.
    """
    if require_exists and not _data_path.exists():
        return None
    return _data_path
2016-09-24 21:26:17 +03:00
def set_data_path(path):
    """Replace the module-wide spaCy data directory.

    path (unicode or Path): Path to new data directory. Strings are
        converted to Path via ensure_path.
    """
    global _data_path
    new_path = ensure_path(path)
    _data_path = new_path
def ensure_path(path):
    """Convert a string to a Path and pass anything else through untouched.

    path: Anything. If string, it's converted to Path.
    RETURNS: Path or original argument.
    """
    return Path(path) if isinstance(path, basestring_) else path
2016-09-24 21:26:17 +03:00
2017-05-09 00:51:15 +03:00
def resolve_model_path(name):
    """Resolve a model name or string to a model path.

    A string is tried, in order, as an entry inside the data directory
    (model folder or shortcut link), as a package installed via pip, and
    finally as a filesystem path. An object with an `exists` attribute is
    returned as-is without checking that it exists.

    name (unicode): Package name, shortcut link or model path.
    RETURNS (Path): Path to model data directory.
    RAISES (IOError): If the data directory or the model can't be found.
    """
    # Fetch the path without the existence filter: get_data_path() returns
    # None for a missing directory, which made the error below always read
    # "None" instead of naming the directory that is missing.
    data_path = get_data_path(require_exists=False)
    if not data_path or not data_path.exists():
        raise IOError("Can't find spaCy data path: %s" % path2str(data_path))
    if isinstance(name, basestring_):
        if (data_path / name).exists():  # in data dir or shortcut link
            return (data_path / name)
        if is_package(name):  # installed as a package
            return get_model_package_path(name)
        if Path(name).exists():  # path to model
            return Path(name)
    elif hasattr(name, 'exists'):  # Path or Path-like object
        return name
    raise IOError("Can't find model '%s'" % name)
def is_package(name):
    """Check if string maps to a package installed via pip.

    Project names are compared with '-' replaced by '_'.

    name (unicode): Name of package.
    RETURNS (bool): True if installed package, False if not.
    """
    # NOTE(review): get_installed_distributions looks like a pip-internal
    # helper -- confirm it exists in every pip version that must be supported.
    installed = pip.get_installed_distributions()
    return any(dist.project_name.replace('-', '_') == name
               for dist in installed)
def get_model_package_path(package_name):
    """Locate the data directory of a model package installed via pip.

    The result is <install dir>/<package_name>/<package_name>-<version>,
    with the version read from the package's meta.json.

    package_name (unicode): Name of installed package.
    RETURNS (Path): Path to model data directory.
    """
    # The package is imported only to find out where it lives on disk.
    # That is indirect, but Python's installation and import rules make it
    # very difficult to locate a package any other way.
    module = importlib.import_module(package_name)
    install_dir = Path(module.__file__).parent.parent
    package_dir = install_dir / package_name
    meta = parse_package_meta(package_dir)
    return package_dir / ('%s-%s' % (package_name, meta['version']))
def parse_package_meta(package_path, require=True):
    """Load the meta.json that sits directly inside a model package.

    package_path (Path): Path to model package directory.
    require (bool): If True, raise error if no meta.json is found.
    RETURNS (dict or None): Model meta.json data or None.
    """
    meta_path = package_path / 'meta.json'
    if meta_path.is_file():
        return read_json(meta_path)
    if require:
        raise IOError("Could not read meta.json from %s" % meta_path)
    return None
def get_cuda_stream(require=False):
    """Create a new CUDA stream if cupy's Stream class could be imported.

    require (bool): Currently unused.
    RETURNS: A new CudaStream instance, or None if CudaStream is unavailable.
    """
    # TODO: Error and tell to install chainer if not found
    # Requires GPU
    if CudaStream is None:
        return None
    return CudaStream()
def get_async(stream, numpy_array):
    """Pass an array through cupy.array on the given stream, if possible.

    stream: Forwarded to cupy.array as the `stream` keyword argument.
    numpy_array: Array to convert.
    RETURNS: Result of cupy.array, or the input unchanged if cupy is missing.
    """
    if cupy is not None:
        return cupy.array(numpy_array, stream=stream)
    return numpy_array
2016-09-24 21:26:17 +03:00
def read_regex(path):
    """Compile a regex from a file of literal strings, one per line.

    Every non-blank line is escaped, anchored to the start of the string
    with '^' and joined to the others with '|'.

    path (unicode or Path): Path to the file with the entries.
    RETURNS: Compiled regular expression.
    """
    path = ensure_path(path)
    with path.open() as file_:
        lines = file_.read().split('\n')
    anchored = ['^' + re.escape(line) for line in lines if line.strip()]
    return re.compile('|'.join(anchored))
def compile_prefix_regex(entries):
    """Compile the prefix patterns into one regex anchored at string start.

    Blank entries are skipped. Each remaining entry is prefixed with '^' and
    the alternatives are joined with '|'.

    entries: Iterable of prefix patterns.
    RETURNS: Compiled regular expression.
    """
    # Deprecated data lists literal strings rather than patterns; a bare '('
    # entry is the marker for it, and such entries have to be escaped.
    is_deprecated_data = '(' in entries
    pieces = (piece for piece in entries if piece.strip())
    if is_deprecated_data:
        pieces = (re.escape(piece) for piece in pieces)
    return re.compile('|'.join('^' + piece for piece in pieces))
2016-09-24 21:26:17 +03:00
def compile_suffix_regex(entries):
    """Compile the suffix patterns into one regex anchored at string end.

    Blank entries are skipped. Each remaining entry gets a trailing '$' and
    the alternatives are joined with '|'.

    entries: Iterable of suffix patterns.
    RETURNS: Compiled regular expression.
    """
    anchored = (piece + '$' for piece in entries if piece.strip())
    return re.compile('|'.join(anchored))
def compile_infix_regex(entries):
    """Compile the infix patterns into one unanchored alternation regex.

    Blank entries are skipped; the remaining ones are joined with '|'.

    entries: Iterable of infix patterns.
    RETURNS: Compiled regular expression.
    """
    kept = (piece for piece in entries if piece.strip())
    return re.compile('|'.join(kept))
def update_exc(base_exceptions, *addition_dicts):
    """Update and validate tokenizer exceptions. Will overwrite exceptions.

    Each addition is validated before it is merged: all ORTH values must be
    unicode and, joined together, must spell the exception's key. The merged
    dict is then expanded with typographic-apostrophe variants of every key
    that contains a straight apostrophe.

    base_exceptions (dict): Base exceptions.
    *addition_dicts (dict): Exceptions to add to the base dict, in order.
    RETURNS (dict): Combined tokenizer exceptions.
    RAISES (ValueError): If an ORTH value is not unicode, or the ORTH values
        of an exception don't add up to its key.
    """
    exc = dict(base_exceptions)
    for additions in addition_dicts:
        for orth, token_attrs in additions.items():
            if not all(isinstance(attr[ORTH], unicode_) for attr in token_attrs):
                msg = "Invalid value for ORTH in exception: key='%s', orths='%s'"
                raise ValueError(msg % (orth, token_attrs))
            described_orth = ''.join(attr[ORTH] for attr in token_attrs)
            if orth != described_orth:
                raise ValueError("Invalid tokenizer exception: ORTH values "
                                 "combined don't match original string. "
                                 "key='%s', orths='%s'" % (orth, described_orth))
        # Overlapping keys are allowed: later dicts overwrite earlier ones.
        exc.update(additions)
    # Duplicate apostrophe entries with the typographic apostrophe (U+2019).
    # Replacing with an empty string instead would register real words as
    # exceptions (e.g. "we're" -> "were"). The escape keeps the literal
    # safe from encoding damage.
    exc = expand_exc(exc, "'", "\u2019")
    return exc
def expand_exc(excs, search, replace):
    """Add a copy of every exception whose key contains the search string.

    In each copy the search string is replaced in the key and in every
    token's ORTH value. For example, to add additional versions with
    typographic apostrophes. The input dict is not modified.

    excs (dict): Tokenizer exceptions.
    search (unicode): String to find and replace.
    replace (unicode): Replacement.
    RETURNS (dict): Combined tokenizer exceptions.
    """
    def _swap_orth(token_attrs):
        # Copy so the original token dict stays untouched.
        updated = dict(token_attrs)
        updated[ORTH] = updated[ORTH].replace(search, replace)
        return updated

    variants = {
        key.replace(search, replace): [_swap_orth(tok) for tok in tokens]
        for key, tokens in excs.items()
        if search in key
    }
    expanded = dict(excs)
    # Variants win over existing entries that share the same key.
    expanded.update(variants)
    return expanded
def normalize_slice(length, start, stop, step=None):
    """Clamp slice bounds to a sequence of the given length.

    None bounds default to the full range, negative bounds count from the
    end, and the result always satisfies 0 <= start <= stop <= length.

    length (int): Length of the sequence being sliced.
    start (int or None): Slice start.
    stop (int or None): Slice stop.
    step (int or None): Slice step; only None and 1 are supported.
    RETURNS (tuple): Normalized (start, stop) pair.
    RAISES (ValueError): If a step other than None or 1 is given.
    """
    if not (step is None or step == 1):
        # The space after the first sentence was missing, so the two
        # literals ran together as "objects.Try:".
        raise ValueError("Stepped slices not supported in Span objects. "
                         "Try: list(tokens)[start:stop:step] instead.")
    if start is None:
        start = 0
    elif start < 0:
        start += length
    start = min(length, max(0, start))
    if stop is None:
        stop = length
    elif stop < 0:
        stop += length
    # Clamping against start keeps an inverted slice empty, not negative.
    stop = min(length, max(start, stop))
    assert 0 <= start <= stop <= length
    return start, stop
def check_renamed_kwargs(renamed, kwargs):
    """Reject keyword arguments that are passed under an outdated name.

    renamed (dict): Maps old argument names to their new names.
    kwargs (dict): Keyword arguments to check.
    RAISES (TypeError): For the first old name found in kwargs.
    """
    for old_name in renamed:
        if old_name not in kwargs:
            continue
        new_name = renamed[old_name]
        raise TypeError("Keyword argument %s now renamed to %s" % (old_name, new_name))
def read_json(location):
    """Read a UTF-8 encoded JSON file and return the parsed content.

    location (Path): Path to JSON file.
    RETURNS (dict): Loaded JSON content.
    """
    with location.open('r', encoding='utf8') as file_:
        content = ujson.load(file_)
    return content
def get_raw_input(description, default=False):
    """Prompt the user on the command line and return what was typed.

    The default is only shown in the prompt; it is not substituted when the
    user enters nothing.

    description (unicode): Text to display before prompt.
    default (unicode or False/None): Default value to display with prompt.
    RETURNS (unicode): User input.
    """
    if default:
        suffix = ' (default: %s)' % default
    else:
        suffix = ''
    return input_(' %s%s: ' % (description, suffix))
2017-05-08 00:25:29 +03:00
def print_table(data, title=None):
    """Print data in table format.

    Each value is left-aligned in a column of 15 characters. Empty data
    prints an empty table instead of raising an IndexError.

    data (dict or list of tuples): Label/value pairs.
    title (unicode or None): Title, will be printed above.
    """
    if isinstance(data, dict):
        data = list(data.items())
    # The column count comes from the first row, so only look at it when
    # there is one.
    tpl_row = ' {:<15}' * len(data[0]) if data else ''
    table = '\n'.join([tpl_row.format(label, value) for label, value in data])
    if title:
        print('\n \033[93m{}\033[0m'.format(title))
    print('\n{}\n'.format(table))
2017-05-08 00:25:29 +03:00
def print_markdown(data, title=None):
    """Print data in GitHub-flavoured Markdown format for issues etc.

    Values that name an existing filesystem path are left out, because
    paths may contain personal info.

    data (dict or list of tuples): Label/value pairs.
    title (unicode or None): Title, will be rendered as headline 2.
    """
    def excl_value(value):
        # Values that can't be interpreted as a path (ints, None, ...) or
        # that the filesystem rejects are not paths, so they are kept.
        try:
            return Path(value).exists()  # contains path (personal info)
        except (TypeError, ValueError, OSError):
            return False

    if isinstance(data, dict):
        data = list(data.items())
    markdown = ["* **{}:** {}".format(label, value)
                for label, value in data if not excl_value(value)]
    if title:
        print("\n## {}".format(title))
    print('\n{}\n'.format('\n'.join(markdown)))
2017-05-08 03:00:37 +03:00
def prints(*texts, **kwargs):
    """Print formatted message (manual ANSI escape sequences to avoid dependency)

    Every text is wrapped via _wrap and separated from the next one by a
    blank line.

    *texts (unicode): Texts to print. Each argument is rendered as paragraph.
    **kwargs: 'title' becomes coloured headline. 'exits'=True performs sys exit.
    """
    headline = kwargs.get('title', None)
    should_exit = kwargs.get('exits', False)
    if headline:
        header = '\033[93m{}\033[0m\n'.format(_wrap(headline))
    else:
        header = ''
    body = '\n\n'.join(_wrap(text) for text in texts)
    print('\n{}{}\n'.format(header, body))
    if should_exit:
        sys.exit(0)
2017-05-08 00:25:29 +03:00
def _wrap(text, wrap_max=80, indent=4):
"""Wrap text at given width using textwrap module.
2017-05-14 02:30:29 +03:00
text (unicode): Text to wrap. If it's a Path, it's converted to string.
wrap_max (int): Maximum line length (indent is deducted).
indent (int): Number of spaces for indentation.
RETURNS (unicode): Wrapped text.
2017-04-16 14:42:34 +03:00
"""
2017-05-08 00:25:29 +03:00
indent = indent * ' '
wrap_width = wrap_max - len(indent)
2017-05-08 00:25:29 +03:00
if isinstance(text, Path):
text = path2str(text)
return textwrap.fill(text, width=wrap_width, initial_indent=indent,
2017-05-08 00:25:29 +03:00
subsequent_indent=indent, break_long_words=False,
break_on_hyphens=False)
2017-05-14 18:50:23 +03:00
def minify_html(html):
    """Perform a template-specific, rudimentary HTML minification for displaCy.
    Disclaimer: NOT a general-purpose solution, only removes indentation/newlines.

    Leading whitespace is stripped from every line and the lines are joined
    without newlines. Spaces inside a line are kept: removing every space
    would glue tag names to their attributes and merge words in text.

    html (unicode): Markup to minify.
    RETURNS (unicode): "Minified" HTML.
    """
    lines = html.strip().split('\n')
    return ''.join(line.lstrip() for line in lines)