Mirror of https://github.com/explosion/spaCy.git

	Tidy up util and helpers
commit ea4a41c8fb
parent d941fc3667

@@ -87,15 +87,15 @@ def symlink_to(orig, dest):
 def is_config(python2=None, python3=None, windows=None, linux=None, osx=None):
-    return ((python2 == None or python2 == is_python2) and
-            (python3 == None or python3 == is_python3) and
-            (windows == None or windows == is_windows) and
-            (linux == None or linux == is_linux) and
-            (osx == None or osx == is_osx))
+    return ((python2 is None or python2 == is_python2) and
+            (python3 is None or python3 == is_python3) and
+            (windows is None or windows == is_windows) and
+            (linux is None or linux == is_linux) and
+            (osx is None or osx == is_osx))
 
 
 def normalize_string_keys(old):
-    '''Given a dictionary, make sure keys are unicode strings, not bytes.'''
+    """Given a dictionary, make sure keys are unicode strings, not bytes."""
     new = {}
     for key, value in old.items():
         if isinstance(key, bytes_):
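
Review note: the change from `== None` to `is None` is more than style. `==` dispatches to a type's `__eq__`, which can be overridden, while `is` tests object identity and cannot be fooled. A minimal standalone sketch (the class is hypothetical, not part of this diff):

    class AlwaysEqual(object):
        def __eq__(self, other):
            return True    # hijacks every == comparison

    x = AlwaysEqual()
    assert x == None       # passes, because __eq__ returns True
    assert x is not None   # the identity check still gives the right answer
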
@@ -24,7 +24,7 @@ def depr_model_download(lang):
 
 
 def resolve_load_name(name, **overrides):
-    """Resolve model loading if deprecated path kwarg is specified in overrides.
+    """Resolve model loading if deprecated path kwarg in overrides.
 
     name (unicode): Name of model to load.
     **overrides: Overrides specified in spacy.load().
@@ -32,8 +32,9 @@ def resolve_load_name(name, **overrides):
     """
     if overrides.get('path') not in (None, False, True):
         name = overrides.get('path')
-        prints("To load a model from a path, you can now use the first argument. "
-               "The model meta is used to load the required Language class.",
-               "OLD: spacy.load('en', path='/some/path')", "NEW: spacy.load('/some/path')",
+        prints("To load a model from a path, you can now use the first "
+               "argument. The model meta is used to load the Language class.",
+               "OLD: spacy.load('en', path='/some/path')",
+               "NEW: spacy.load('/some/path')",
                title="Warning: deprecated argument 'path'")
     return name
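
Review note: the reflowed warning documents the migration from the deprecated `path` kwarg to the first argument. A standalone sketch of the resolution logic (warning output omitted), runnable outside spaCy:

    def resolve_load_name(name, **overrides):
        # The deprecated 'path' kwarg wins over the name, as in the diff above.
        if overrides.get('path') not in (None, False, True):
            name = overrides.get('path')
        return name

    assert resolve_load_name('en') == 'en'
    assert resolve_load_name('en', path='/some/path') == '/some/path'
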
@@ -264,7 +264,6 @@ GLOSSARY = {
     'nk':           'noun kernel element',
     'nmc':          'numerical component',
     'oa':           'accusative object',
-    'oa':           'second accusative object',
     'oc':           'clausal object',
     'og':           'genitive object',
     'op':           'prepositional object',
@@ -43,8 +43,8 @@ def POS_tree(root, light=False, flat=False):
 
 
 def parse_tree(doc, light=False, flat=False):
-    """Makes a copy of the doc, then construct a syntactic parse tree, similar to
-    the one used in displaCy. Generates the POS tree for all sentences in a doc.
+    """Make a copy of the doc and construct a syntactic parse tree similar to
+    displaCy. Generates the POS tree for all sentences in a doc.
 
     doc (Doc): The doc for parsing.
     RETURNS (dict): The parse tree.
@@ -66,8 +66,9 @@ def parse_tree(doc, light=False, flat=False):
             'NE': '', 'word': 'ate', 'arc': 'ROOT', 'POS_coarse': 'VERB',
             'POS_fine': 'VBD', 'lemma': 'eat'}
     """
-    doc_clone  = Doc(doc.vocab, words=[w.text for w in doc])
+    doc_clone = Doc(doc.vocab, words=[w.text for w in doc])
     doc_clone.from_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE],
                          doc.to_array([HEAD, TAG, DEP, ENT_IOB, ENT_TYPE]))
     merge_ents(doc_clone)  # merge the entities into single tokens first
-    return [POS_tree(sent.root, light=light, flat=flat) for sent in doc_clone.sents]
+    return [POS_tree(sent.root, light=light, flat=flat)
+            for sent in doc_clone.sents]
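
Review note: the clone-via-array pattern above round-trips selected token attributes through a single NumPy array. An illustrative look at the array half of that round trip (assumes an installed English model such as 'en_core_web_sm'):

    import spacy
    from spacy.attrs import HEAD, TAG, DEP

    nlp = spacy.load('en_core_web_sm')
    doc = nlp(u'The pasta was eaten quickly.')
    arr = doc.to_array([HEAD, TAG, DEP])  # one row per token, one column per attribute
    print(arr.shape)                      # (6, 3) for this six-token sentence
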
@@ -1,5 +1,9 @@
 # coding: utf8
 from __future__ import unicode_literals
 
+import functools
+
+
 class Underscore(object):
     doc_extensions = {}
     span_extensions = {}
@@ -10,25 +10,27 @@ from pathlib import Path
 import sys
 import textwrap
 import random
 import numpy
 import io
 import dill
 from collections import OrderedDict
 from thinc.neural._classes.model import Model
 import functools
 
-from .symbols import ORTH
-from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
-from .compat import import_file
-
 import msgpack
 import msgpack_numpy
 msgpack_numpy.patch()
 import ujson
 
+from .symbols import ORTH
+from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
+from .compat import copy_array, normalize_string_keys, getattr_, import_file
+
 
 LANGUAGES = {}
 _data_path = Path(__file__).parent / 'data'
+_PRINT_ENV = False
+
+
+def set_env_log(value):
+    global _PRINT_ENV
+    _PRINT_ENV = value
 
 
 def get_lang_class(lang):
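
Review note: `set_env_log` moves up next to the `_PRINT_ENV` flag it guards (its removal from the old position appears further down). The helper only toggles the module-level flag:

    set_env_log(True)    # enable environment logging
    set_env_log(False)   # turn it off again
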
@@ -38,11 +40,12 @@ def get_lang_class(lang):
     RETURNS (Language): Language class.
     """
     global LANGUAGES
-    if not lang in LANGUAGES:
+    if lang not in LANGUAGES:
         try:
             module = importlib.import_module('.lang.%s' % lang, 'spacy')
         except ImportError:
-            raise ImportError("Can't import language %s from spacy.lang." %lang)
+            msg = "Can't import language %s from spacy.lang."
+            raise ImportError(msg % lang)
         LANGUAGES[lang] = getattr(module, module.__all__[0])
     return LANGUAGES[lang]
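
Review note: usage is unchanged by the refactor; the loader still imports lazily and caches (assumes spaCy is installed):

    from spacy.util import get_lang_class

    English = get_lang_class('en')           # imports spacy.lang.en on first call
    assert get_lang_class('en') is English   # later calls hit the LANGUAGES cache
    nlp = English()                          # blank English pipeline
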
@@ -100,14 +103,14 @@ def load_model(name, **overrides):
     data_path = get_data_path()
     if not data_path or not data_path.exists():
         raise IOError("Can't find spaCy data path: %s" % path2str(data_path))
-    if isinstance(name, basestring_):
-        if name in set([d.name for d in data_path.iterdir()]): # in data dir / shortcut
+    if isinstance(name, basestring_):  # in data dir / shortcut
+        if name in set([d.name for d in data_path.iterdir()]):
             return load_model_from_link(name, **overrides)
-        if is_package(name): # installed as package
+        if is_package(name):  # installed as package
             return load_model_from_package(name, **overrides)
-        if Path(name).exists(): # path to model data directory
+        if Path(name).exists():  # path to model data directory
             return load_model_from_path(Path(name), **overrides)
-    elif hasattr(name, 'exists'): # Path or Path-like to model data
+    elif hasattr(name, 'exists'):  # Path or Path-like to model data
         return load_model_from_path(name, **overrides)
     raise IOError("Can't find model '%s'" % name)
@@ -120,7 +123,7 @@ def load_model_from_link(name, **overrides):
     except AttributeError:
         raise IOError(
             "Cant' load '%s'. If you're using a shortcut link, make sure it "
-            "points to a valid model package (not just a data directory)." % name)
+            "points to a valid package (not just a data directory)." % name)
     return cls.load(**overrides)
@@ -164,7 +167,8 @@ def load_model_from_init_py(init_file, **overrides):
     data_dir = '%s_%s-%s' % (meta['lang'], meta['name'], meta['version'])
     data_path = model_path / data_dir
     if not model_path.exists():
-        raise ValueError("Can't find model directory: %s" % path2str(data_path))
+        msg = "Can't find model directory: %s"
+        raise ValueError(msg % path2str(data_path))
     return load_model_from_path(data_path, meta, **overrides)
@@ -176,14 +180,16 @@ def get_model_meta(path):
     """
     model_path = ensure_path(path)
     if not model_path.exists():
-        raise ValueError("Can't find model directory: %s" % path2str(model_path))
+        msg = "Can't find model directory: %s"
+        raise ValueError(msg % path2str(model_path))
     meta_path = model_path / 'meta.json'
     if not meta_path.is_file():
         raise IOError("Could not read meta.json from %s" % meta_path)
     meta = read_json(meta_path)
     for setting in ['lang', 'name', 'version']:
         if setting not in meta or not meta[setting]:
-            raise ValueError("No valid '%s' setting found in model meta.json" % setting)
+            msg = "No valid '%s' setting found in model meta.json"
+            raise ValueError(msg % setting)
     return meta
@@ -240,7 +246,7 @@ def get_async(stream, numpy_array):
         return numpy_array
     else:
         array = cupy.ndarray(numpy_array.shape, order='C',
-                           dtype=numpy_array.dtype)
+                             dtype=numpy_array.dtype)
         array.set(numpy_array, stream=stream)
         return array
@@ -274,12 +280,6 @@ def itershuffle(iterable, bufsize=1000):
         raise StopIteration
 
 
-_PRINT_ENV = False
-def set_env_log(value):
-    global _PRINT_ENV
-    _PRINT_ENV = value
-
-
 def env_opt(name, default=None):
     if type(default) is float:
         type_convert = float
@@ -305,17 +305,20 @@ def read_regex(path):
     path = ensure_path(path)
     with path.open() as file_:
         entries = file_.read().split('\n')
-    expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
+    expression = '|'.join(['^' + re.escape(piece)
+                           for piece in entries if piece.strip()])
     return re.compile(expression)
 
 
 def compile_prefix_regex(entries):
     if '(' in entries:
         # Handle deprecated data
-        expression = '|'.join(['^' + re.escape(piece) for piece in entries if piece.strip()])
+        expression = '|'.join(['^' + re.escape(piece)
+                               for piece in entries if piece.strip()])
         return re.compile(expression)
     else:
-        expression = '|'.join(['^' + piece for piece in entries if piece.strip()])
+        expression = '|'.join(['^' + piece
+                               for piece in entries if piece.strip()])
         return re.compile(expression)
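
Review note: all branches build the same shape of pattern, an alternation of start-anchored pieces. A runnable sketch of the escaping branch (toy entries, not spaCy's real prefix data):

    import re

    entries = ['(', '[', '"', '']   # the empty entry is dropped by the strip() filter
    expression = '|'.join(['^' + re.escape(piece)
                           for piece in entries if piece.strip()])
    prefix_re = re.compile(expression)
    assert prefix_re.match('"Hello')         # a listed prefix matches at position 0
    assert prefix_re.match('Hello') is None  # anything else does not
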
@@ -359,16 +362,15 @@ def update_exc(base_exceptions, *addition_dicts):
     exc = dict(base_exceptions)
     for additions in addition_dicts:
         for orth, token_attrs in additions.items():
-            if not all(isinstance(attr[ORTH], unicode_) for attr in token_attrs):
-                msg = "Invalid value for ORTH in exception: key='%s', orths='%s'"
+            if not all(isinstance(attr[ORTH], unicode_)
+                       for attr in token_attrs):
+                msg = "Invalid ORTH value in exception: key='%s', orths='%s'"
                 raise ValueError(msg % (orth, token_attrs))
             described_orth = ''.join(attr[ORTH] for attr in token_attrs)
             if orth != described_orth:
-                raise ValueError("Invalid tokenizer exception: ORTH values "
-                                 "combined don't match original string. "
-                                 "key='%s', orths='%s'" % (orth, described_orth))
-        # overlap = set(exc.keys()).intersection(set(additions))
-        # assert not overlap, overlap
+                msg = ("Invalid tokenizer exception: ORTH values combined "
+                       "don't match original string. key='%s', orths='%s'")
+                raise ValueError(msg % (orth, described_orth))
         exc.update(additions)
     exc = expand_exc(exc, "'", "’")
     return exc
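
Review note: the validation enforces that the ORTH values of an exception, concatenated, reproduce the dictionary key exactly. Illustrative entries (ORTH as imported from spacy.symbols above):

    good = {"don't": [{ORTH: "do"}, {ORTH: "n't"}]}  # 'do' + "n't" == "don't": accepted
    bad = {"don't": [{ORTH: "do"}, {ORTH: "not"}]}   # 'do' + 'not' != "don't": ValueError
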
@@ -401,17 +403,15 @@ def normalize_slice(length, start, stop, step=None):
         raise ValueError("Stepped slices not supported in Span objects."
                          "Try: list(tokens)[start:stop:step] instead.")
     if start is None:
-       start = 0
+        start = 0
     elif start < 0:
-       start += length
+        start += length
     start = min(length, max(0, start))
-
     if stop is None:
-       stop = length
+        stop = length
     elif stop < 0:
-       stop += length
+        stop += length
     stop = min(length, max(start, stop))
-
     assert 0 <= start <= stop <= length
     return start, stop
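
Review note: the indentation fix doesn't change behavior; the function still clamps like Python's own slice semantics, e.g. for a five-token span:

    assert normalize_slice(5, None, None) == (0, 5)  # full span
    assert normalize_slice(5, -2, None) == (3, 5)    # negative start wraps around
    assert normalize_slice(5, 3, 99) == (3, 5)       # stop is clamped to length
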
@@ -428,7 +428,7 @@ def compounding(start, stop, compound):
       >>> assert next(sizes) == 1.5 * 1.5
     """
     def clip(value):
-        return max(value, stop) if (start>stop) else min(value, stop)
+        return max(value, stop) if (start > stop) else min(value, stop)
     curr = float(start)
     while True:
         yield clip(curr)
@@ -438,7 +438,7 @@ def compounding(start, stop, compound):
 def decaying(start, stop, decay):
     """Yield an infinite series of linearly decaying values."""
     def clip(value):
-        return max(value, stop) if (start>stop) else min(value, stop)
+        return max(value, stop) if (start > stop) else min(value, stop)
     nr_upd = 1.
     while True:
         yield clip(start * 1./(1. + decay * nr_upd))
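
Review note: `decaying` mirrors `compounding` but moves toward `stop` instead of away from `start`; e.g. for a learning-rate-style schedule:

    rates = decaying(0.5, 1e-4, 1e-2)
    assert next(rates) == 0.5 * 1. / (1. + 1e-2 * 1.)  # first value, just under start
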
@@ -530,17 +530,19 @@ def print_markdown(data, title=None):
 
     if isinstance(data, dict):
         data = list(data.items())
-    markdown = ["* **{}:** {}".format(l, unicode_(v)) for l, v in data if not excl_value(v)]
+    markdown = ["* **{}:** {}".format(l, unicode_(v))
+                for l, v in data if not excl_value(v)]
     if title:
         print("\n## {}".format(title))
     print('\n{}\n'.format('\n'.join(markdown)))
 
 
 def prints(*texts, **kwargs):
-    """Print formatted message (manual ANSI escape sequences to avoid dependency)
+    """Print formatted message (manual ANSI escape sequences to avoid
+    dependency)
 
     *texts (unicode): Texts to print. Each argument is rendered as paragraph.
-    **kwargs: 'title' becomes coloured headline. 'exits'=True performs sys exit.
+    **kwargs: 'title' becomes coloured headline. exits=True performs sys exit.
     """
     exits = kwargs.get('exits', None)
     title = kwargs.get('title', None)
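
Review note: a call shape matching the documented kwargs (exact output depends on the helper's ANSI template):

    prints("Loaded model from a path.", "Each argument becomes its own paragraph.",
           title="Loading successful")
    # exits=True would additionally call sys.exit() after printing:
    # prints("Something went wrong.", title="Error", exits=True)
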
@@ -570,7 +572,8 @@ def _wrap(text, wrap_max=80, indent=4):
 
 def minify_html(html):
     """Perform a template-specific, rudimentary HTML minification for displaCy.
-    Disclaimer: NOT a general-purpose solution, only removes indentation/newlines.
+    Disclaimer: NOT a general-purpose solution, only removes indentation and
+    newlines.
 
     html (unicode): Markup to minify.
     RETURNS (unicode): "Minified" HTML.
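
Review note: illustrative input/output, assuming the minifier only strips newlines and the template's indentation (as the disclaimer states):

    html = minify_html('<div>\n    <span>Hello</span>\n</div>')
    # -> '<div><span>Hello</span></div>'
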