Tidy up util and helpers

Commit ea4a41c8fb (parent d941fc3667), mirror of
https://github.com/explosion/spaCy.git (synced 2024-12-25 17:36:30 +03:00).
@@ -87,15 +87,15 @@ def symlink_to(orig, dest):
 
 
 def is_config(python2=None, python3=None, windows=None, linux=None, osx=None):
-    return ((python2 == None or python2 == is_python2) and
-            (python3 == None or python3 == is_python3) and
-            (windows == None or windows == is_windows) and
-            (linux == None or linux == is_linux) and
-            (osx == None or osx == is_osx))
+    return ((python2 is None or python2 == is_python2) and
+            (python3 is None or python3 == is_python3) and
+            (windows is None or windows == is_windows) and
+            (linux is None or linux == is_linux) and
+            (osx is None or osx == is_osx))
 
 
 def normalize_string_keys(old):
-    '''Given a dictionary, make sure keys are unicode strings, not bytes.'''
+    """Given a dictionary, make sure keys are unicode strings, not bytes."""
     new = {}
     for key, value in old.items():
         if isinstance(key, bytes_):
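Note on the change above: PEP 8 prefers `is None` because `None` is a singleton,
and an identity check cannot be fooled by a custom `__eq__`. A minimal
illustration (hypothetical class, not from this commit):

    class AlwaysEqual(object):
        def __eq__(self, other):
            return True  # a perverse __eq__ makes == unreliable

    obj = AlwaysEqual()
    print(obj == None)  # True, although obj is clearly not None
    print(obj is None)  # False -- identity is checked, not equality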
@@ -24,7 +24,7 @@ def depr_model_download(lang):
 
 
 def resolve_load_name(name, **overrides):
-    """Resolve model loading if deprecated path kwarg is specified in overrides.
+    """Resolve model loading if deprecated path kwarg in overrides.
 
     name (unicode): Name of model to load.
     **overrides: Overrides specified in spacy.load().
@@ -32,8 +32,9 @@ def resolve_load_name(name, **overrides):
     """
     if overrides.get('path') not in (None, False, True):
         name = overrides.get('path')
-        prints("To load a model from a path, you can now use the first argument. "
-               "The model meta is used to load the required Language class.",
-               "OLD: spacy.load('en', path='/some/path')", "NEW: spacy.load('/some/path')",
+        prints("To load a model from a path, you can now use the first "
+               "argument. The model meta is used to load the Language class.",
+               "OLD: spacy.load('en', path='/some/path')",
+               "NEW: spacy.load('/some/path')",
               title="Warning: deprecated argument 'path'")
     return name
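The reflowed warning documents the migration away from the `path` kwarg. In
user code the change looks like this (paths illustrative):

    import spacy

    # OLD, deprecated: pass the data directory via the path kwarg
    nlp = spacy.load('en', path='/some/path')
    # NEW: pass the path as the first argument; the model's meta.json
    # determines which Language class gets loaded
    nlp = spacy.load('/some/path')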
@@ -264,7 +264,6 @@ GLOSSARY = {
     'nk': 'noun kernel element',
     'nmc': 'numerical component',
     'oa': 'accusative object',
-    'oa': 'second accusative object',
     'oc': 'clausal object',
     'og': 'genitive object',
     'op': 'prepositional object',
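The deleted line fixes a duplicate dictionary key: in a Python dict literal
the last occurrence of a key silently wins, so 'oa' actually mapped to
'second accusative object' and the first definition was dead code. For example:

    glossary = {'oa': 'accusative object',
                'oa': 'second accusative object'}
    print(glossary['oa'])  # 'second accusative object' -- the last value wins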
@@ -43,8 +43,8 @@ def POS_tree(root, light=False, flat=False):
 
 
 def parse_tree(doc, light=False, flat=False):
-    """Makes a copy of the doc, then construct a syntactic parse tree, similar to
-    the one used in displaCy. Generates the POS tree for all sentences in a doc.
+    """Make a copy of the doc and construct a syntactic parse tree similar to
+    displaCy. Generates the POS tree for all sentences in a doc.
 
     doc (Doc): The doc for parsing.
     RETURNS (dict): The parse tree.
@@ -66,8 +66,9 @@ def parse_tree(doc, light=False, flat=False):
         'NE': '', 'word': 'ate', 'arc': 'ROOT', 'POS_coarse': 'VERB',
         'POS_fine': 'VBD', 'lemma': 'eat'}
     """
     doc_clone = Doc(doc.vocab, words=[w.text for w in doc])
     doc_clone.from_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE],
                          doc.to_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE]))
     merge_ents(doc_clone)  # merge the entities into single tokens first
-    return [POS_tree(sent.root, light=light, flat=flat) for sent in doc_clone.sents]
+    return [POS_tree(sent.root, light=light, flat=flat)
+            for sent in doc_clone.sents]
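A usage sketch for the function above, assuming an installed English model
(the import path for parse_tree is an assumption based on the other helpers
in this file):

    import spacy
    from spacy.deprecated import parse_tree  # import path assumed

    nlp = spacy.load('en')
    doc = nlp(u'Bob ate the pizza.')
    trees = parse_tree(doc)  # one POS tree dict per sentence, as in the docstring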
@@ -1,5 +1,9 @@
+# coding: utf8
+from __future__ import unicode_literals
+
 import functools
 
 
 class Underscore(object):
     doc_extensions = {}
     span_extensions = {}
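For context, the added `__future__` import makes Python 2 treat bare string
literals as unicode, matching Python 3, so the module's string constants need
no `u''` prefixes:

    from __future__ import unicode_literals
    print(type('hello'))  # unicode under Python 2, str under Python 3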
@@ -10,25 +10,27 @@ from pathlib import Path
 import sys
 import textwrap
 import random
-import numpy
-import io
-import dill
 from collections import OrderedDict
 from thinc.neural._classes.model import Model
 import functools
 
+from .symbols import ORTH
+from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
+from .compat import import_file
+
 import msgpack
 import msgpack_numpy
 msgpack_numpy.patch()
-import ujson
-
-from .symbols import ORTH
-from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
-from .compat import copy_array, normalize_string_keys, getattr_, import_file
 
 
 LANGUAGES = {}
 _data_path = Path(__file__).parent / 'data'
+_PRINT_ENV = False
+
+
+def set_env_log(value):
+    global _PRINT_ENV
+    _PRINT_ENV = value
 
 
 def get_lang_class(lang):
@@ -38,11 +40,12 @@ def get_lang_class(lang):
     RETURNS (Language): Language class.
     """
     global LANGUAGES
-    if not lang in LANGUAGES:
+    if lang not in LANGUAGES:
         try:
             module = importlib.import_module('.lang.%s' % lang, 'spacy')
         except ImportError:
-            raise ImportError("Can't import language %s from spacy.lang." %lang)
+            msg = "Can't import language %s from spacy.lang."
+            raise ImportError(msg % lang)
         LANGUAGES[lang] = getattr(module, module.__all__[0])
     return LANGUAGES[lang]
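get_lang_class is how spaCy resolves a language code to its Language subclass;
a brief usage sketch:

    from spacy.util import get_lang_class

    English = get_lang_class('en')  # lazily imports spacy.lang.en
    nlp = English()                 # blank English pipeline, no model data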
@@ -100,14 +103,14 @@ def load_model(name, **overrides):
     data_path = get_data_path()
     if not data_path or not data_path.exists():
         raise IOError("Can't find spaCy data path: %s" % path2str(data_path))
-    if isinstance(name, basestring_):
-        if name in set([d.name for d in data_path.iterdir()]): # in data dir / shortcut
+    if isinstance(name, basestring_):  # in data dir / shortcut
+        if name in set([d.name for d in data_path.iterdir()]):
             return load_model_from_link(name, **overrides)
         if is_package(name):  # installed as package
             return load_model_from_package(name, **overrides)
         if Path(name).exists():  # path to model data directory
             return load_model_from_path(Path(name), **overrides)
     elif hasattr(name, 'exists'):  # Path or Path-like to model data
         return load_model_from_path(name, **overrides)
     raise IOError("Can't find model '%s'" % name)
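The branches above correspond to the three ways a model can be named;
illustrative calls:

    import spacy

    nlp = spacy.load('en')              # shortcut link in the data directory
    nlp = spacy.load('en_core_web_sm')  # model installed as a Python package
    nlp = spacy.load('/some/path')      # path to a model data directory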
@@ -120,7 +123,7 @@ def load_model_from_link(name, **overrides):
     except AttributeError:
         raise IOError(
             "Cant' load '%s'. If you're using a shortcut link, make sure it "
-            "points to a valid model package (not just a data directory)." % name)
+            "points to a valid package (not just a data directory)." % name)
     return cls.load(**overrides)
@@ -164,7 +167,8 @@ def load_model_from_init_py(init_file, **overrides):
     data_dir = '%s_%s-%s' % (meta['lang'], meta['name'], meta['version'])
     data_path = model_path / data_dir
     if not model_path.exists():
-        raise ValueError("Can't find model directory: %s" % path2str(data_path))
+        msg = "Can't find model directory: %s"
+        raise ValueError(msg % path2str(data_path))
     return load_model_from_path(data_path, meta, **overrides)
@@ -176,14 +180,16 @@ def get_model_meta(path):
     """
     model_path = ensure_path(path)
     if not model_path.exists():
-        raise ValueError("Can't find model directory: %s" % path2str(model_path))
+        msg = "Can't find model directory: %s"
+        raise ValueError(msg % path2str(model_path))
     meta_path = model_path / 'meta.json'
     if not meta_path.is_file():
         raise IOError("Could not read meta.json from %s" % meta_path)
     meta = read_json(meta_path)
     for setting in ['lang', 'name', 'version']:
         if setting not in meta or not meta[setting]:
-            raise ValueError("No valid '%s' setting found in model meta.json" % setting)
+            msg = "No valid '%s' setting found in model meta.json"
+            raise ValueError(msg % setting)
     return meta
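get_model_meta requires non-empty 'lang', 'name' and 'version' settings; a
minimal meta that passes validation (values illustrative):

    meta = {
        'lang': 'en',
        'name': 'core_web_sm',
        'version': '2.0.0',
    }
    # per load_model_from_init_py above, the data directory is then
    # 'en_core_web_sm-2.0.0', i.e. '%s_%s-%s' % (lang, name, version)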
@@ -240,7 +246,7 @@ def get_async(stream, numpy_array):
         return numpy_array
     else:
         array = cupy.ndarray(numpy_array.shape, order='C',
-                            dtype=numpy_array.dtype)
+                             dtype=numpy_array.dtype)
         array.set(numpy_array, stream=stream)
         return array
@@ -274,12 +280,6 @@ def itershuffle(iterable, bufsize=1000):
         raise StopIteration
 
 
-_PRINT_ENV = False
-
-
-def set_env_log(value):
-    global _PRINT_ENV
-    _PRINT_ENV = value
-
-
 def env_opt(name, default=None):
     if type(default) is float:
         type_convert = float
@@ -305,17 +305,20 @@ def read_regex(path):
     path = ensure_path(path)
     with path.open() as file_:
         entries = file_.read().split('\n')
-    expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
+    expression = '|'.join(['^' + re.escape(piece)
+                           for piece in entries if piece.strip()])
     return re.compile(expression)
 
 
 def compile_prefix_regex(entries):
     if '(' in entries:
         # Handle deprecated data
-        expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
+        expression = '|'.join(['^' + re.escape(piece)
+                               for piece in entries if piece.strip()])
         return re.compile(expression)
     else:
-        expression = '|'.join(['^' + piece for piece in entries if piece.strip()])
+        expression = '|'.join(['^' + piece
+                               for piece in entries if piece.strip()])
         return re.compile(expression)
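Both helpers build one big anchored alternation; a standalone illustration of
the pattern:

    import re

    prefixes = ['(', '"', '$']
    expression = '|'.join(['^' + re.escape(piece)
                           for piece in prefixes if piece.strip()])
    prefix_re = re.compile(expression)       # equivalent to ^\(|^"|^\$
    print(bool(prefix_re.match('"Hello')))   # True: text starts with a prefix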
@@ -359,16 +362,15 @@ def update_exc(base_exceptions, *addition_dicts):
     exc = dict(base_exceptions)
     for additions in addition_dicts:
         for orth, token_attrs in additions.items():
-            if not all(isinstance(attr[ORTH], unicode_) for attr in token_attrs):
-                msg = "Invalid value for ORTH in exception: key='%s', orths='%s'"
+            if not all(isinstance(attr[ORTH], unicode_)
+                       for attr in token_attrs):
+                msg = "Invalid ORTH value in exception: key='%s', orths='%s'"
                 raise ValueError(msg % (orth, token_attrs))
             described_orth = ''.join(attr[ORTH] for attr in token_attrs)
             if orth != described_orth:
-                raise ValueError("Invalid tokenizer exception: ORTH values "
-                                 "combined don't match original string. "
-                                 "key='%s', orths='%s'" % (orth, described_orth))
+                msg = ("Invalid tokenizer exception: ORTH values combined "
+                       "don't match original string. key='%s', orths='%s'")
+                raise ValueError(msg % (orth, described_orth))
-        # overlap = set(exc.keys()).intersection(set(additions))
-        # assert not overlap, overlap
         exc.update(additions)
     exc = expand_exc(exc, "'", "’")
     return exc
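The two validations enforce the tokenizer-exception contract: every ORTH value
must be a unicode string, and the ORTH pieces must concatenate back to the
key. An illustrative exception dict:

    from spacy.symbols import ORTH

    good = {"don't": [{ORTH: "do"}, {ORTH: "n't"}]}  # "do" + "n't" == "don't"
    bad = {"don't": [{ORTH: "do"}, {ORTH: "not"}]}   # raises: "donot" != "don't"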
@@ -401,17 +403,15 @@ def normalize_slice(length, start, stop, step=None):
         raise ValueError("Stepped slices not supported in Span objects."
                          "Try: list(tokens)[start:stop:step] instead.")
     if start is None:
         start = 0
     elif start < 0:
         start += length
     start = min(length, max(0, start))
 
     if stop is None:
         stop = length
     elif stop < 0:
         stop += length
     stop = min(length, max(start, stop))
 
     assert 0 <= start <= stop <= length
     return start, stop
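Worked examples of the normalization, following the code above (length 5,
i.e. a five-token Span):

    normalize_slice(5, -3, None)  # -> (2, 5): negative start counts from the end
    normalize_slice(5, None, 99)  # -> (0, 5): missing start is 0, stop is clipped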
@@ -428,7 +428,7 @@ def compounding(start, stop, compound):
     >>> assert next(sizes) == 1.5 * 1.5
     """
     def clip(value):
-        return max(value, stop) if (start>stop) else min(value, stop)
+        return max(value, stop) if (start > stop) else min(value, stop)
     curr = float(start)
     while True:
         yield clip(curr)
@@ -438,7 +438,7 @@
 def decaying(start, stop, decay):
     """Yield an infinite series of linearly decaying values."""
     def clip(value):
-        return max(value, stop) if (start>stop) else min(value, stop)
+        return max(value, stop) if (start > stop) else min(value, stop)
     nr_upd = 1.
     while True:
         yield clip(start * 1./(1. + decay * nr_upd))
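Both generators drive training schedules: compounding multiplies toward stop,
decaying shrinks toward it. A usage sketch (hyperparameter values illustrative):

    from spacy.util import compounding, decaying

    sizes = compounding(1., 32., 1.001)    # 1.0, 1.001, ~1.002, ... capped at 32
    rates = decaying(0.005, 0.0001, 1e-4)  # 0.005/(1 + 1e-4), ... floored at 0.0001
    batch_size = next(sizes)
    learn_rate = next(rates)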
@@ -530,17 +530,19 @@ def print_markdown(data, title=None):
 
     if isinstance(data, dict):
         data = list(data.items())
-    markdown = ["* **{}:** {}".format(l, unicode_(v)) for l, v in data if not excl_value(v)]
+    markdown = ["* **{}:** {}".format(l, unicode_(v))
+                for l, v in data if not excl_value(v)]
     if title:
         print("\n## {}".format(title))
     print('\n{}\n'.format('\n'.join(markdown)))
 
 
 def prints(*texts, **kwargs):
-    """Print formatted message (manual ANSI escape sequences to avoid dependency)
+    """Print formatted message (manual ANSI escape sequences to avoid
+    dependency)
 
     *texts (unicode): Texts to print. Each argument is rendered as paragraph.
-    **kwargs: 'title' becomes coloured headline. 'exits'=True performs sys exit.
+    **kwargs: 'title' becomes coloured headline. exits=True performs sys exit.
     """
     exits = kwargs.get('exits', None)
     title = kwargs.get('title', None)
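The list comprehension renders each key/value pair as a Markdown bullet; for
example:

    print_markdown({'lang': 'en', 'version': '2.0.0'}, title='Model meta')
    # ## Model meta
    # * **lang:** en
    # * **version:** 2.0.0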
@@ -570,7 +572,8 @@ def _wrap(text, wrap_max=80, indent=4):
 
 
 def minify_html(html):
     """Perform a template-specific, rudimentary HTML minification for displaCy.
-    Disclaimer: NOT a general-purpose solution, only removes indentation/newlines.
+    Disclaimer: NOT a general-purpose solution, only removes indentation and
+    newlines.
 
     html (unicode): Markup to minify.
     RETURNS (unicode): "Minified" HTML.
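Such a minifier can be a single substitution; a sketch of the approach (an
assumption about the implementation, not code from this commit):

    import re

    def minify_html_sketch(html):
        # collapse each newline plus any indentation that follows it
        return re.sub(r'\n\s*', '', html)

    print(minify_html_sketch('<div>\n    <span>hi</span>\n</div>'))
    # <div><span>hi</span></div>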