mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Tidy up util and helpers
This commit is contained in:
parent
d941fc3667
commit
ea4a41c8fb
|
@ -87,15 +87,15 @@ def symlink_to(orig, dest):
|
|||
|
||||
|
||||
def is_config(python2=None, python3=None, windows=None, linux=None, osx=None):
|
||||
return ((python2 == None or python2 == is_python2) and
|
||||
(python3 == None or python3 == is_python3) and
|
||||
(windows == None or windows == is_windows) and
|
||||
(linux == None or linux == is_linux) and
|
||||
(osx == None or osx == is_osx))
|
||||
return ((python2 is None or python2 == is_python2) and
|
||||
(python3 is None or python3 == is_python3) and
|
||||
(windows is None or windows == is_windows) and
|
||||
(linux is None or linux == is_linux) and
|
||||
(osx is None or osx == is_osx))
|
||||
|
||||
|
||||
def normalize_string_keys(old):
|
||||
'''Given a dictionary, make sure keys are unicode strings, not bytes.'''
|
||||
"""Given a dictionary, make sure keys are unicode strings, not bytes."""
|
||||
new = {}
|
||||
for key, value in old.items():
|
||||
if isinstance(key, bytes_):
|
||||
|
|
|
@ -24,7 +24,7 @@ def depr_model_download(lang):
|
|||
|
||||
|
||||
def resolve_load_name(name, **overrides):
|
||||
"""Resolve model loading if deprecated path kwarg is specified in overrides.
|
||||
"""Resolve model loading if deprecated path kwarg in overrides.
|
||||
|
||||
name (unicode): Name of model to load.
|
||||
**overrides: Overrides specified in spacy.load().
|
||||
|
@ -32,8 +32,9 @@ def resolve_load_name(name, **overrides):
|
|||
"""
|
||||
if overrides.get('path') not in (None, False, True):
|
||||
name = overrides.get('path')
|
||||
prints("To load a model from a path, you can now use the first argument. "
|
||||
"The model meta is used to load the required Language class.",
|
||||
"OLD: spacy.load('en', path='/some/path')", "NEW: spacy.load('/some/path')",
|
||||
prints("To load a model from a path, you can now use the first "
|
||||
"argument. The model meta is used to load the Language class.",
|
||||
"OLD: spacy.load('en', path='/some/path')",
|
||||
"NEW: spacy.load('/some/path')",
|
||||
title="Warning: deprecated argument 'path'")
|
||||
return name
|
||||
|
|
|
@ -264,7 +264,6 @@ GLOSSARY = {
|
|||
'nk': 'noun kernel element',
|
||||
'nmc': 'numerical component',
|
||||
'oa': 'accusative object',
|
||||
'oa': 'second accusative object',
|
||||
'oc': 'clausal object',
|
||||
'og': 'genitive object',
|
||||
'op': 'prepositional object',
|
||||
|
|
|
@ -43,8 +43,8 @@ def POS_tree(root, light=False, flat=False):
|
|||
|
||||
|
||||
def parse_tree(doc, light=False, flat=False):
|
||||
"""Makes a copy of the doc, then construct a syntactic parse tree, similar to
|
||||
the one used in displaCy. Generates the POS tree for all sentences in a doc.
|
||||
"""Make a copy of the doc and construct a syntactic parse tree similar to
|
||||
displaCy. Generates the POS tree for all sentences in a doc.
|
||||
|
||||
doc (Doc): The doc for parsing.
|
||||
RETURNS (dict): The parse tree.
|
||||
|
@ -70,4 +70,5 @@ def parse_tree(doc, light=False, flat=False):
|
|||
doc_clone.from_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE],
|
||||
doc.to_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE]))
|
||||
merge_ents(doc_clone) # merge the entities into single tokens first
|
||||
return [POS_tree(sent.root, light=light, flat=flat) for sent in doc_clone.sents]
|
||||
return [POS_tree(sent.root, light=light, flat=flat)
|
||||
for sent in doc_clone.sents]
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import functools
|
||||
|
||||
|
||||
class Underscore(object):
|
||||
doc_extensions = {}
|
||||
span_extensions = {}
|
||||
|
|
|
@ -10,25 +10,27 @@ from pathlib import Path
|
|||
import sys
|
||||
import textwrap
|
||||
import random
|
||||
import numpy
|
||||
import io
|
||||
import dill
|
||||
from collections import OrderedDict
|
||||
from thinc.neural._classes.model import Model
|
||||
import functools
|
||||
|
||||
from .symbols import ORTH
|
||||
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
|
||||
from .compat import import_file
|
||||
|
||||
import msgpack
|
||||
import msgpack_numpy
|
||||
msgpack_numpy.patch()
|
||||
import ujson
|
||||
|
||||
from .symbols import ORTH
|
||||
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
|
||||
from .compat import copy_array, normalize_string_keys, getattr_, import_file
|
||||
|
||||
|
||||
LANGUAGES = {}
|
||||
_data_path = Path(__file__).parent / 'data'
|
||||
_PRINT_ENV = False
|
||||
|
||||
|
||||
def set_env_log(value):
|
||||
global _PRINT_ENV
|
||||
_PRINT_ENV = value
|
||||
|
||||
|
||||
def get_lang_class(lang):
|
||||
|
@ -38,11 +40,12 @@ def get_lang_class(lang):
|
|||
RETURNS (Language): Language class.
|
||||
"""
|
||||
global LANGUAGES
|
||||
if not lang in LANGUAGES:
|
||||
if lang not in LANGUAGES:
|
||||
try:
|
||||
module = importlib.import_module('.lang.%s' % lang, 'spacy')
|
||||
except ImportError:
|
||||
raise ImportError("Can't import language %s from spacy.lang." %lang)
|
||||
msg = "Can't import language %s from spacy.lang."
|
||||
raise ImportError(msg % lang)
|
||||
LANGUAGES[lang] = getattr(module, module.__all__[0])
|
||||
return LANGUAGES[lang]
|
||||
|
||||
|
@ -100,8 +103,8 @@ def load_model(name, **overrides):
|
|||
data_path = get_data_path()
|
||||
if not data_path or not data_path.exists():
|
||||
raise IOError("Can't find spaCy data path: %s" % path2str(data_path))
|
||||
if isinstance(name, basestring_):
|
||||
if name in set([d.name for d in data_path.iterdir()]): # in data dir / shortcut
|
||||
if isinstance(name, basestring_): # in data dir / shortcut
|
||||
if name in set([d.name for d in data_path.iterdir()]):
|
||||
return load_model_from_link(name, **overrides)
|
||||
if is_package(name): # installed as package
|
||||
return load_model_from_package(name, **overrides)
|
||||
|
@ -120,7 +123,7 @@ def load_model_from_link(name, **overrides):
|
|||
except AttributeError:
|
||||
raise IOError(
|
||||
"Cant' load '%s'. If you're using a shortcut link, make sure it "
|
||||
"points to a valid model package (not just a data directory)." % name)
|
||||
"points to a valid package (not just a data directory)." % name)
|
||||
return cls.load(**overrides)
|
||||
|
||||
|
||||
|
@ -164,7 +167,8 @@ def load_model_from_init_py(init_file, **overrides):
|
|||
data_dir = '%s_%s-%s' % (meta['lang'], meta['name'], meta['version'])
|
||||
data_path = model_path / data_dir
|
||||
if not model_path.exists():
|
||||
raise ValueError("Can't find model directory: %s" % path2str(data_path))
|
||||
msg = "Can't find model directory: %s"
|
||||
raise ValueError(msg % path2str(data_path))
|
||||
return load_model_from_path(data_path, meta, **overrides)
|
||||
|
||||
|
||||
|
@ -176,14 +180,16 @@ def get_model_meta(path):
|
|||
"""
|
||||
model_path = ensure_path(path)
|
||||
if not model_path.exists():
|
||||
raise ValueError("Can't find model directory: %s" % path2str(model_path))
|
||||
msg = "Can't find model directory: %s"
|
||||
raise ValueError(msg % path2str(model_path))
|
||||
meta_path = model_path / 'meta.json'
|
||||
if not meta_path.is_file():
|
||||
raise IOError("Could not read meta.json from %s" % meta_path)
|
||||
meta = read_json(meta_path)
|
||||
for setting in ['lang', 'name', 'version']:
|
||||
if setting not in meta or not meta[setting]:
|
||||
raise ValueError("No valid '%s' setting found in model meta.json" % setting)
|
||||
msg = "No valid '%s' setting found in model meta.json"
|
||||
raise ValueError(msg % setting)
|
||||
return meta
|
||||
|
||||
|
||||
|
@ -274,12 +280,6 @@ def itershuffle(iterable, bufsize=1000):
|
|||
raise StopIteration
|
||||
|
||||
|
||||
_PRINT_ENV = False
|
||||
def set_env_log(value):
|
||||
global _PRINT_ENV
|
||||
_PRINT_ENV = value
|
||||
|
||||
|
||||
def env_opt(name, default=None):
|
||||
if type(default) is float:
|
||||
type_convert = float
|
||||
|
@ -305,17 +305,20 @@ def read_regex(path):
|
|||
path = ensure_path(path)
|
||||
with path.open() as file_:
|
||||
entries = file_.read().split('\n')
|
||||
expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
|
||||
expression = '|'.join(['^' + re.escape(piece)
|
||||
for piece in entries if piece.strip()])
|
||||
return re.compile(expression)
|
||||
|
||||
|
||||
def compile_prefix_regex(entries):
|
||||
if '(' in entries:
|
||||
# Handle deprecated data
|
||||
expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
|
||||
expression = '|'.join(['^' + re.escape(piece)
|
||||
for piece in entries if piece.strip()])
|
||||
return re.compile(expression)
|
||||
else:
|
||||
expression = '|'.join(['^' + piece for piece in entries if piece.strip()])
|
||||
expression = '|'.join(['^' + piece
|
||||
for piece in entries if piece.strip()])
|
||||
return re.compile(expression)
|
||||
|
||||
|
||||
|
@ -359,16 +362,15 @@ def update_exc(base_exceptions, *addition_dicts):
|
|||
exc = dict(base_exceptions)
|
||||
for additions in addition_dicts:
|
||||
for orth, token_attrs in additions.items():
|
||||
if not all(isinstance(attr[ORTH], unicode_) for attr in token_attrs):
|
||||
msg = "Invalid value for ORTH in exception: key='%s', orths='%s'"
|
||||
if not all(isinstance(attr[ORTH], unicode_)
|
||||
for attr in token_attrs):
|
||||
msg = "Invalid ORTH value in exception: key='%s', orths='%s'"
|
||||
raise ValueError(msg % (orth, token_attrs))
|
||||
described_orth = ''.join(attr[ORTH] for attr in token_attrs)
|
||||
if orth != described_orth:
|
||||
raise ValueError("Invalid tokenizer exception: ORTH values "
|
||||
"combined don't match original string. "
|
||||
"key='%s', orths='%s'" % (orth, described_orth))
|
||||
# overlap = set(exc.keys()).intersection(set(additions))
|
||||
# assert not overlap, overlap
|
||||
msg = ("Invalid tokenizer exception: ORTH values combined "
|
||||
"don't match original string. key='%s', orths='%s'")
|
||||
raise ValueError(msg % (orth, described_orth))
|
||||
exc.update(additions)
|
||||
exc = expand_exc(exc, "'", "’")
|
||||
return exc
|
||||
|
@ -405,13 +407,11 @@ def normalize_slice(length, start, stop, step=None):
|
|||
elif start < 0:
|
||||
start += length
|
||||
start = min(length, max(0, start))
|
||||
|
||||
if stop is None:
|
||||
stop = length
|
||||
elif stop < 0:
|
||||
stop += length
|
||||
stop = min(length, max(start, stop))
|
||||
|
||||
assert 0 <= start <= stop <= length
|
||||
return start, stop
|
||||
|
||||
|
@ -428,7 +428,7 @@ def compounding(start, stop, compound):
|
|||
>>> assert next(sizes) == 1.5 * 1.5
|
||||
"""
|
||||
def clip(value):
|
||||
return max(value, stop) if (start>stop) else min(value, stop)
|
||||
return max(value, stop) if (start > stop) else min(value, stop)
|
||||
curr = float(start)
|
||||
while True:
|
||||
yield clip(curr)
|
||||
|
@ -438,7 +438,7 @@ def compounding(start, stop, compound):
|
|||
def decaying(start, stop, decay):
|
||||
"""Yield an infinite series of linearly decaying values."""
|
||||
def clip(value):
|
||||
return max(value, stop) if (start>stop) else min(value, stop)
|
||||
return max(value, stop) if (start > stop) else min(value, stop)
|
||||
nr_upd = 1.
|
||||
while True:
|
||||
yield clip(start * 1./(1. + decay * nr_upd))
|
||||
|
@ -530,17 +530,19 @@ def print_markdown(data, title=None):
|
|||
|
||||
if isinstance(data, dict):
|
||||
data = list(data.items())
|
||||
markdown = ["* **{}:** {}".format(l, unicode_(v)) for l, v in data if not excl_value(v)]
|
||||
markdown = ["* **{}:** {}".format(l, unicode_(v))
|
||||
for l, v in data if not excl_value(v)]
|
||||
if title:
|
||||
print("\n## {}".format(title))
|
||||
print('\n{}\n'.format('\n'.join(markdown)))
|
||||
|
||||
|
||||
def prints(*texts, **kwargs):
|
||||
"""Print formatted message (manual ANSI escape sequences to avoid dependency)
|
||||
"""Print formatted message (manual ANSI escape sequences to avoid
|
||||
dependency)
|
||||
|
||||
*texts (unicode): Texts to print. Each argument is rendered as paragraph.
|
||||
**kwargs: 'title' becomes coloured headline. 'exits'=True performs sys exit.
|
||||
**kwargs: 'title' becomes coloured headline. exits=True performs sys exit.
|
||||
"""
|
||||
exits = kwargs.get('exits', None)
|
||||
title = kwargs.get('title', None)
|
||||
|
@ -570,7 +572,8 @@ def _wrap(text, wrap_max=80, indent=4):
|
|||
|
||||
def minify_html(html):
|
||||
"""Perform a template-specific, rudimentary HTML minification for displaCy.
|
||||
Disclaimer: NOT a general-purpose solution, only removes indentation/newlines.
|
||||
Disclaimer: NOT a general-purpose solution, only removes indentation and
|
||||
newlines.
|
||||
|
||||
html (unicode): Markup to minify.
|
||||
RETURNS (unicode): "Minified" HTML.
|
||||
|
|
Loading…
Reference in New Issue
Block a user