Add docstrings, error messages and fix consistency

This commit is contained in:
ines 2017-05-13 21:22:49 +02:00
parent ee7dcf65c9
commit 1694c24e52

View File

@ -32,11 +32,25 @@ def get_lang_class(name):
def load_lang_class(lang):
    """Import a spaCy language module and return its Language class.

    The module is looked up as 'spacy.lang.<lang>' and the class returned is
    the first name listed in that module's __all__.

    Args:
        lang (unicode): Two-letter language code, e.g. 'en'.
    Returns:
        Language: Language class.
    """
    lang_module = importlib.import_module('.lang.%s' % lang, 'spacy')
    # By convention the first exported name is the Language subclass.
    class_name = lang_module.__all__[0]
    return getattr(lang_module, class_name)
def get_data_path(require_exists=True):
"""Get path to spaCy data directory.
Args:
require_exists (bool): Only return path if it exists, otherwise None.
Returns:
Path or None: Data path or None.
"""
if not require_exists:
return _data_path
else:
@ -44,6 +58,11 @@ def get_data_path(require_exists=True):
def set_data_path(path):
    """Point spaCy at a new data directory.

    The given value is passed through ensure_path() and stored in the
    module-level _data_path.

    Args:
        path (unicode or Path): Path to new data directory.
    """
    global _data_path
    normalized = ensure_path(path)
    _data_path = normalized
@ -56,6 +75,13 @@ def ensure_path(path):
def resolve_model_path(name):
"""Resolve a model name or string to a model path.
Args:
name (unicode): Package name, shortcut link or model path.
Returns:
Path: Path to model data directory.
"""
data_path = get_data_path()
if not data_path or not data_path.exists():
raise IOError("Can't find spaCy data path: %s" % path2str(data_path))
@ -71,18 +97,30 @@ def resolve_model_path(name):
raise IOError("Can't find model '%s'" % name)
def is_package(name):
    """Check if string maps to a package installed via pip.

    Installed distributions are enumerated via importlib.metadata where
    available, falling back to pkg_resources on older interpreters. The
    private pip.get_installed_distributions() helper is not used because it
    was removed in pip 10.

    Args:
        name (unicode): Name of package.
    Returns:
        bool: True if installed package, False if not.
    """
    try:
        from importlib.metadata import distributions
        project_names = (dist.metadata['Name'] for dist in distributions())
    except ImportError:
        import pkg_resources
        project_names = (dist.project_name
                         for dist in pkg_resources.working_set)
    for project_name in project_names:
        # Distribution names may use hyphens where the importable package
        # name uses underscores; a distribution without a name is skipped.
        if project_name and project_name.replace('-', '_') == name:
            return True
    return False
def get_model_package_path(package_name):
"""Get path to a model package installed via pip.
Args:
package_name (unicode): Name of installed package.
Returns:
Path: Path to model data directory.
"""
# Here we're importing the module just to find it. This is worryingly
# indirect, but it's otherwise very difficult to find the package.
# Python's installation and import rules are very complicated.
@ -94,9 +132,13 @@ def get_model_package_path(package_name):
def parse_package_meta(package_path, require=True):
"""
Check if a meta.json exists in a package and return its contents as a
dictionary. If require is set to True, raise an error if no meta.json found.
"""Check if a meta.json exists in a package and return its contents.
Args:
package_path (Path): Path to model package directory.
require (bool): If True, raise error if no meta.json is found.
Returns:
dict or None: Model meta.json data or None.
"""
location = package_path / 'meta.json'
if location.is_file():
@ -136,6 +178,14 @@ def compile_infix_regex(entries):
def update_exc(base_exceptions, *addition_dicts):
"""Update and validate tokenizer exceptions. Will overwrite exceptions.
Args:
base_exceptions (dict): Base exceptions.
*addition_dicts (dict): Exceptions to add to the base dict, in order.
Returns:
dict: Combined tokenizer exceptions.
"""
exc = dict(base_exceptions)
for additions in addition_dicts:
for orth, token_attrs in additions.items():
@ -144,9 +194,9 @@ def update_exc(base_exceptions, *addition_dicts):
raise ValueError(msg % (orth, token_attrs))
described_orth = ''.join(attr[ORTH] for attr in token_attrs)
if orth != described_orth:
# TODO: Better error
msg = "Invalid tokenizer exception: key='%s', orths='%s'"
raise ValueError(msg % (orth, described_orth))
raise ValueError("Invalid tokenizer exception: ORTH values "
"combined don't match original string. "
"key='%s', orths='%s'" % (orth, described_orth))
# overlap = set(exc.keys()).intersection(set(additions))
# assert not overlap, overlap
exc.update(additions)
@ -155,6 +205,16 @@ def update_exc(base_exceptions, *addition_dicts):
def expand_exc(excs, search, replace):
"""Find string in tokenizer exceptions, duplicate entry and replace string.
For example, to add additional versions with typographic apostrophes.
Args:
excs (dict): Tokenizer exceptions.
search (unicode): String to find and replace.
replace (unicode): Replacement.
Returns:
dict: Combined tokenizer exceptions.
"""
def _fix_token(token, search, replace):
fixed = dict(token)
fixed[ORTH] = fixed[ORTH].replace(search, replace)
@ -195,14 +255,25 @@ def check_renamed_kwargs(renamed, kwargs):
def read_json(location):
    """Read a JSON file and return the parsed content.

    The file is opened in text mode as UTF-8 and parsed with ujson.

    Args:
        location (Path): Path to JSON file.
    Returns:
        dict: Loaded JSON content.
    """
    with location.open('r', encoding='utf8') as file_:
        content = ujson.load(file_)
    return content
def get_raw_input(description, default=False):
"""
Get user input via raw_input / input and return input value. Takes a
description, and an optional default value to display with the prompt.
"""Get user input from the command line via raw_input / input.
Args:
description (unicode): Text to display before prompt.
default (unicode or False/None): Default value to display with prompt.
Returns:
unicode: User input.
"""
additional = ' (default: %s)' % default if default else ''
prompt = ' %s%s: ' % (description, additional)
@ -211,11 +282,13 @@ def get_raw_input(description, default=False):
def print_table(data, title=None):
"""Print data in table format.
Args:
data (dict or list of tuples): Label/value pairs.
title (unicode or None): Title, will be printed above.
"""
Print data in table format. Can either take a list of tuples or a
dictionary, which will be converted to a list of tuples.
"""
if type(data) == dict:
if isinstance(data, dict):
data = list(data.items())
tpl_row = ' {:<15}' * len(data[0])
table = '\n'.join([tpl_row.format(l, v) for l, v in data])
@ -225,14 +298,16 @@ def print_table(data, title=None):
def print_markdown(data, title=None):
"""
Print listed data in GitHub-flavoured Markdown format so it can be
copy-pasted into issues. Can either take a list of tuples or a dictionary.
"""Print data in GitHub-flavoured Markdown format for issues etc.
Args:
data (dict or list of tuples): Label/value pairs.
title (unicode or None): Title, will be rendered as headline 2.
"""
def excl_value(value):
return Path(value).exists() # contains path (personal info)
if type(data) == dict:
if isinstance(data, dict):
data = list(data.items())
markdown = ["* **{}:** {}".format(l, v) for l, v in data if not excl_value(v)]
if title:
@ -241,10 +316,12 @@ def print_markdown(data, title=None):
def prints(*texts, **kwargs):
"""
Print formatted message. Each positional argument is rendered as newline-
separated paragraph. An optional highlighted title is printed above the text
(using ANSI escape sequences manually to avoid unnecessary dependency).
"""Print formatted message (manual ANSI escape sequences to avoid dependency).
Args:
*texts (unicode): Texts to print. Each argument is rendered as paragraph.
**kwargs: 'title' is rendered as coloured headline. 'exits'=True performs
system exit after printing.
"""
exits = kwargs.get('exits', False)
title = kwargs.get('title', None)
@ -256,9 +333,14 @@ def prints(*texts, **kwargs):
def _wrap(text, wrap_max=80, indent=4):
"""
Wrap text at given width using textwrap module. Indent should consist of
spaces. Its length is deducted from wrap width to ensure exact wrapping.
"""Wrap text at given width using textwrap module.
Args:
text (unicode): Text to wrap. If it's a Path, it's converted to string.
wrap_max (int): Maximum line length (indent is deducted).
indent (int): Number of spaces for indentation.
Returns:
unicode: Wrapped text.
"""
indent = indent * ' '
wrap_width = wrap_max - len(indent)