mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
bede11b67c
This patch does a few smallish things that tighten up the training workflow a little, and allows memory use during training to be reduced by letting the GoldCorpus stream data properly.

Previously, the parser and entity recognizer read and saved labels as lists, with extra labels noted separately. Lists were used because ordering is very important: it ensures that the label-to-class mapping is stable. We now manage labels as nested dictionaries, first keyed by the action and then keyed by the label. Values are frequencies.

The trick is, how do we save new labels? We need to make sure we iterate over them in the same order they're added; otherwise we'll get different class IDs, and the model's predictions won't make sense. To allow stable sorting, we map the new labels to negative values. If we have two new labels, they'll be noted as having "frequency" -1 and -2. The next new label will then have "frequency" -3. When we sort by (frequency, label), we then get a stable sort.

Storing frequencies also enables the next improvement. Previously we had to iterate over the whole training set to pre-process it for the deprojectivisation. This meant storing the whole training set in memory, which accounted for most of the memory required during training. To prevent this, we now store the frequencies as we stream in the data, and deprojectivise as we go. Once we've built the frequencies, we can apply a frequency cut-off when we decide how many classes to make.

Finally, to allow proper data streaming, we also need some way of shuffling the iterator. This is awkward if the training files have multiple documents in them. To solve this, the GoldCorpus class now writes the training data to disk in msgpack files, one per document. We can then shuffle the data by shuffling the paths.

This is a squash merge, as I made a lot of very small commits. Individual commit messages below.

* Simplify label management for TransitionSystem and its subclasses
* Fix serialization for new label handling format in parser
* Simplify and improve GoldCorpus class. Reduce memory use, write to temp dir
* Set actions in transition system
* Require thinc 6.11.1.dev4
* Fix error in parser init
* Add unicode declaration
* Fix unicode declaration
* Update textcat test
* Try to get model training on less memory
* Print json loc for now
* Try rapidjson to reduce memory use
* Remove rapidjson requirement
* Try rapidjson for reduced mem usage
* Handle None heads when projectivising
* Stream json docs
* Fix train script
* Handle projectivity in GoldParse
* Fix projectivity handling
* Add minibatch_by_words util from ud_train
* Minibatch by number of words in spacy.cli.train
* Move minibatch_by_words util to spacy.util
* Fix label handling
* More hacking at label management in parser
* Fix encoding in msgpack serialization in GoldParse
* Adjust batch sizes in parser training
* Fix minibatch_by_words
* Add merge_subtokens function to pipeline.pyx
* Register merge_subtokens factory
* Restore use of msgpack tmp directory
* Use minibatch-by-words in train
* Handle retokenization in scorer
* Change back-off approach for missing labels. Use 'dep' label
* Update NER for new label management
* Set NER tags for over-segmented words
* Fix label alignment in gold
* Fix label back-off for infrequent labels
* Fix int type in labels dict key
* Fix int type in labels dict key
* Update feature definition for 8 feature set
* Update ud-train script for new label stuff
* Fix json streamer
* Print the line number if conll eval fails
* Update children and sentence boundaries after deprojectivisation
* Export set_children_from_heads from doc.pxd
* Render parses during UD training
* Remove print statement
* Require thinc 6.11.1.dev6. Try adding wheel as install_requires
* Set different dev version, to flush pip cache
* Update thinc version
* Update GoldCorpus docs
* Remove print statements
* Fix formatting and links [ci skip]
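The ordering trick is easier to see in isolation. Below is a minimal sketch of the idea, with hypothetical names rather than the actual TransitionSystem code: labels counted from the data carry their real frequencies, new labels get successively more negative placeholder values, and sorting by descending frequency reproduces insertion order for the new labels.

    freqs = {}   # label -> frequency, for a single action
    n_new = 0

    def see_label(label, freq):
        # Labels counted while streaming the training data.
        freqs[label] = freqs.get(label, 0) + freq

    def add_new_label(label):
        # Labels added after counting get -1, -2, -3, ..., so their
        # insertion order is recoverable from the stored "frequency".
        global n_new
        if label not in freqs:
            n_new += 1
            freqs[label] = -n_new

    def class_order():
        # Descending frequency puts counted labels first and new labels
        # after them, in the order they were added: class IDs stay stable.
        return sorted(freqs, key=lambda label: (-freqs[label], label))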
655 lines
20 KiB
Python
# coding: utf8
from __future__ import unicode_literals, print_function

import os
import ujson
import pkg_resources
import importlib
import regex as re
from pathlib import Path
import sys
import textwrap
import random
from collections import OrderedDict
import inspect
import warnings
from thinc.neural._classes.model import Model
import functools
import cytoolz
import itertools
import numpy.random

from .symbols import ORTH
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
from .compat import import_file

import msgpack
import msgpack_numpy
msgpack_numpy.patch()


LANGUAGES = {}
_data_path = Path(__file__).parent / 'data'
_PRINT_ENV = False


def set_env_log(value):
    global _PRINT_ENV
    _PRINT_ENV = value


def get_lang_class(lang):
    """Import and load a Language class.

    lang (unicode): Two-letter language code, e.g. 'en'.
    RETURNS (Language): Language class.
    """
    global LANGUAGES
    if lang not in LANGUAGES:
        try:
            module = importlib.import_module('.lang.%s' % lang, 'spacy')
        except ImportError:
            msg = "Can't import language %s from spacy.lang."
            raise ImportError(msg % lang)
        LANGUAGES[lang] = getattr(module, module.__all__[0])
    return LANGUAGES[lang]


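# get_lang_class usage sketch (hypothetical call): get_lang_class('en')
# imports spacy.lang.en and returns its Language subclass, so
# get_lang_class('en')() constructs a blank English pipeline.

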
def set_lang_class(name, cls):
    """Set a custom Language class name that can be loaded via get_lang_class.

    name (unicode): Name of Language class.
    cls (Language): Language class.
    """
    global LANGUAGES
    LANGUAGES[name] = cls


def get_data_path(require_exists=True):
    """Get path to spaCy data directory.

    require_exists (bool): Only return path if it exists, otherwise None.
    RETURNS (Path or None): Data path or None.
    """
    if not require_exists:
        return _data_path
    else:
        return _data_path if _data_path.exists() else None


def set_data_path(path):
    """Set path to spaCy data directory.

    path (unicode or Path): Path to new data directory.
    """
    global _data_path
    _data_path = ensure_path(path)


def ensure_path(path):
    """Ensure string is converted to a Path.

    path: Anything. If string, it's converted to Path.
    RETURNS: Path or original argument.
    """
    if isinstance(path, basestring_):
        return Path(path)
    else:
        return path


def load_model(name, **overrides):
    """Load a model from a shortcut link, package or data path.

    name (unicode): Package name, shortcut link or model path.
    **overrides: Specific overrides, like pipeline components to disable.
    RETURNS (Language): `Language` class with the loaded model.
    """
    data_path = get_data_path()
    if not data_path or not data_path.exists():
        raise IOError("Can't find spaCy data path: %s" % path2str(data_path))
    if isinstance(name, basestring_):  # in data dir / shortcut
        if name in set([d.name for d in data_path.iterdir()]):
            return load_model_from_link(name, **overrides)
        if is_package(name):  # installed as package
            return load_model_from_package(name, **overrides)
        if Path(name).exists():  # path to model data directory
            return load_model_from_path(Path(name), **overrides)
    elif hasattr(name, 'exists'):  # Path or Path-like to model data
        return load_model_from_path(name, **overrides)
    raise IOError("Can't find model '%s'" % name)


def load_model_from_link(name, **overrides):
    """Load a model from a shortcut link, or directory in spaCy data path."""
    path = get_data_path() / name / '__init__.py'
    try:
        cls = import_file(name, path)
    except AttributeError:
        raise IOError(
            "Can't load '%s'. If you're using a shortcut link, make sure it "
            "points to a valid package (not just a data directory)." % name)
    return cls.load(**overrides)


def load_model_from_package(name, **overrides):
    """Load a model from an installed package."""
    cls = importlib.import_module(name)
    return cls.load(**overrides)


def load_model_from_path(model_path, meta=False, **overrides):
    """Load a model from a data directory path. Creates Language class with
    pipeline from meta.json and then calls from_disk() with path."""
    if not meta:
        meta = get_model_meta(model_path)
    cls = get_lang_class(meta['lang'])
    nlp = cls(meta=meta, **overrides)
    pipeline = meta.get('pipeline', [])
    disable = overrides.get('disable', [])
    if pipeline is True:
        pipeline = nlp.Defaults.pipe_names
    elif pipeline in (False, None):
        pipeline = []
    for name in pipeline:
        if name not in disable:
            config = meta.get('pipeline_args', {}).get(name, {})
            component = nlp.create_pipe(name, config=config)
            nlp.add_pipe(component, name=name)
    return nlp.from_disk(model_path)


def load_model_from_init_py(init_file, **overrides):
    """Helper function to use in the `load()` method of a model package's
    __init__.py.

    init_file (unicode): Path to model's __init__.py, i.e. `__file__`.
    **overrides: Specific overrides, like pipeline components to disable.
    RETURNS (Language): `Language` class with loaded model.
    """
    model_path = Path(init_file).parent
    meta = get_model_meta(model_path)
    data_dir = '%s_%s-%s' % (meta['lang'], meta['name'], meta['version'])
    data_path = model_path / data_dir
    if not model_path.exists():
        msg = "Can't find model directory: %s"
        raise ValueError(msg % path2str(data_path))
    return load_model_from_path(data_path, meta, **overrides)


def get_model_meta(path):
    """Get model meta.json from a directory path and validate its contents.

    path (unicode or Path): Path to model directory.
    RETURNS (dict): The model's meta data.
    """
    model_path = ensure_path(path)
    if not model_path.exists():
        msg = "Can't find model directory: %s"
        raise ValueError(msg % path2str(model_path))
    meta_path = model_path / 'meta.json'
    if not meta_path.is_file():
        raise IOError("Could not read meta.json from %s" % meta_path)
    meta = read_json(meta_path)
    for setting in ['lang', 'name', 'version']:
        if setting not in meta or not meta[setting]:
            msg = "No valid '%s' setting found in model meta.json"
            raise ValueError(msg % setting)
    return meta


def is_package(name):
    """Check if string maps to a package installed via pip.

    name (unicode): Name of package.
    RETURNS (bool): True if installed package, False if not.
    """
    name = name.lower()  # compare package name against lowercase name
    packages = pkg_resources.working_set.by_key.keys()
    for package in packages:
        if package.lower().replace('-', '_') == name:
            return True
    return False


def get_package_path(name):
    """Get the path to an installed package.

    name (unicode): Package name.
    RETURNS (Path): Path to installed package.
    """
    name = name.lower()  # use lowercase version to be safe
    # Here we're importing the module just to find it. This is worryingly
    # indirect, but it's otherwise very difficult to find the package.
    pkg = importlib.import_module(name)
    return Path(pkg.__file__).parent


def is_in_jupyter():
    """Check if user is running spaCy from a Jupyter notebook by detecting the
    IPython kernel. Mainly used for the displaCy visualizer.

    RETURNS (bool): True if in Jupyter, False if not.
    """
    try:
        cfg = get_ipython().config
        if cfg['IPKernelApp']['parent_appname'] == 'ipython-notebook':
            return True
    except NameError:
        return False
    return False


def get_cuda_stream(require=False):
    return CudaStream() if CudaStream is not None else None


def get_async(stream, numpy_array):
    if cupy is None:
        return numpy_array
    else:
        array = cupy.ndarray(numpy_array.shape, order='C',
                             dtype=numpy_array.dtype)
        array.set(numpy_array, stream=stream)
        return array


def env_opt(name, default=None):
    if type(default) is float:
        type_convert = float
    else:
        type_convert = int
    if 'SPACY_' + name.upper() in os.environ:
        value = type_convert(os.environ['SPACY_' + name.upper()])
        if _PRINT_ENV:
            print(name, "=", repr(value), "via", "$SPACY_" + name.upper())
        return value
    elif name in os.environ:
        value = type_convert(os.environ[name])
        if _PRINT_ENV:
            print(name, "=", repr(value), "via", '$' + name)
        return value
    else:
        if _PRINT_ENV:
            print(name, '=', repr(default), "by default")
        return default


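# env_opt usage sketch (hypothetical setting): with SPACY_BATCH_SIZE=16
# exported, env_opt('batch_size', 32) returns 16; otherwise it returns the
# default. Values are converted with float() if the default is a float,
# else with int().

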
def read_regex(path):
    path = ensure_path(path)
    with path.open() as file_:
        entries = file_.read().split('\n')
        expression = '|'.join(['^' + re.escape(piece)
                               for piece in entries if piece.strip()])
    return re.compile(expression)


def compile_prefix_regex(entries):
    if '(' in entries:
        # Handle deprecated data
        expression = '|'.join(['^' + re.escape(piece)
                               for piece in entries if piece.strip()])
        return re.compile(expression)
    else:
        expression = '|'.join(['^' + piece
                               for piece in entries if piece.strip()])
        return re.compile(expression)


def compile_suffix_regex(entries):
    expression = '|'.join([piece + '$' for piece in entries if piece.strip()])
    return re.compile(expression)


def compile_infix_regex(entries):
    expression = '|'.join([piece for piece in entries if piece.strip()])
    return re.compile(expression)


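# compile_*_regex usage sketch (hypothetical entries): prefix pieces are
# anchored at the start of the string and suffix pieces at the end, e.g.
#     prefix_re = compile_prefix_regex(['\\(', '"'])
#     suffix_re = compile_suffix_regex(['\\)', '"', '%'])
# The tokenizer can then call prefix_re.search / suffix_re.search on each
# substring it considers splitting.

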
def add_lookups(default_func, *lookups):
    """Extend an attribute function with special cases. If a word is in the
    lookups, the value is returned. Otherwise the previous function is used.

    default_func (callable): The default function to execute.
    *lookups (dict): Lookup dictionary mapping string to attribute value.
    RETURNS (callable): Lexical attribute getter.
    """
    # This is implemented as functools.partial instead of a closure, to allow
    # pickle to work.
    return functools.partial(_get_attr_unless_lookup, default_func, lookups)


def _get_attr_unless_lookup(default_func, lookups, string):
    for lookup in lookups:
        if string in lookup:
            return lookup[string]
    return default_func(string)


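# add_lookups usage sketch (hypothetical table): the lookups win over the
# default function, e.g.
#     lower = add_lookups(lambda string: string.lower(), {'SpaCy': 'spaCy'})
#     lower('SpaCy')  # 'spaCy', from the lookup table
#     lower('NLP')    # 'nlp', from the default function

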
def update_exc(base_exceptions, *addition_dicts):
    """Update and validate tokenizer exceptions. Will overwrite exceptions.

    base_exceptions (dict): Base exceptions.
    *addition_dicts (dict): Exceptions to add to the base dict, in order.
    RETURNS (dict): Combined tokenizer exceptions.
    """
    exc = dict(base_exceptions)
    for additions in addition_dicts:
        for orth, token_attrs in additions.items():
            if not all(isinstance(attr[ORTH], unicode_)
                       for attr in token_attrs):
                msg = "Invalid ORTH value in exception: key='%s', orths='%s'"
                raise ValueError(msg % (orth, token_attrs))
            described_orth = ''.join(attr[ORTH] for attr in token_attrs)
            if orth != described_orth:
                msg = ("Invalid tokenizer exception: ORTH values combined "
                       "don't match original string. key='%s', orths='%s'")
                raise ValueError(msg % (orth, described_orth))
        exc.update(additions)
    exc = expand_exc(exc, "'", "’")
    return exc


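# update_exc usage sketch (hypothetical exception): each key must equal the
# joined ORTH values of its tokens, e.g.
#     exc = update_exc(base_exceptions,
#                      {"don't": [{ORTH: "do"}, {ORTH: "n't"}]})
# update_exc() then also duplicates entries with "'" replaced by the
# typographic apostrophe, via expand_exc() below.

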
def expand_exc(excs, search, replace):
    """Find string in tokenizer exceptions, duplicate entry and replace string.
    For example, to add additional versions with typographic apostrophes.

    excs (dict): Tokenizer exceptions.
    search (unicode): String to find and replace.
    replace (unicode): Replacement.
    RETURNS (dict): Combined tokenizer exceptions.
    """
    def _fix_token(token, search, replace):
        fixed = dict(token)
        fixed[ORTH] = fixed[ORTH].replace(search, replace)
        return fixed
    new_excs = dict(excs)
    for token_string, tokens in excs.items():
        if search in token_string:
            new_key = token_string.replace(search, replace)
            new_value = [_fix_token(t, search, replace) for t in tokens]
            new_excs[new_key] = new_value
    return new_excs


def normalize_slice(length, start, stop, step=None):
    if not (step is None or step == 1):
        raise ValueError("Stepped slices not supported in Span objects. "
                         "Try: list(tokens)[start:stop:step] instead.")
    if start is None:
        start = 0
    elif start < 0:
        start += length
    start = min(length, max(0, start))
    if stop is None:
        stop = length
    elif stop < 0:
        stop += length
    stop = min(length, max(start, stop))
    assert 0 <= start <= stop <= length
    return start, stop


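# normalize_slice examples: indices are clipped the way Python slicing
# clips them, e.g. normalize_slice(5, -2, None) == (3, 5) and
# normalize_slice(5, 2, 100) == (2, 5).

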
def minibatch(items, size=8):
    """Iterate over batches of items. `size` may be an iterator,
    so that batch-size can vary on each step.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        size_ = size
    items = iter(items)
    while True:
        batch_size = next(size_)
        batch = list(cytoolz.take(int(batch_size), items))
        if len(batch) == 0:
            break
        yield list(batch)


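# minibatch usage sketch: pairing it with compounding() below yields batches
# that grow over training, e.g.
#     for batch in minibatch(train_data, size=compounding(4., 32., 1.001)):
#         ...

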
def compounding(start, stop, compound):
    """Yield an infinite series of compounding values. Each time the
    generator is called, a value is produced by multiplying the previous
    value by the compound rate.

    EXAMPLE:
        >>> sizes = compounding(1., 10., 1.5)
        >>> assert next(sizes) == 1.
        >>> assert next(sizes) == 1 * 1.5
        >>> assert next(sizes) == 1.5 * 1.5
    """
    def clip(value):
        return max(value, stop) if (start > stop) else min(value, stop)
    curr = float(start)
    while True:
        yield clip(curr)
        curr *= compound


def decaying(start, stop, decay):
    """Yield an infinite series of decaying values, following
    `start / (1. + decay * t)`, clipped at `stop`."""
    def clip(value):
        return max(value, stop) if (start > stop) else min(value, stop)
    nr_upd = 1.
    while True:
        yield clip(start * 1./(1. + decay * nr_upd))
        nr_upd += 1


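# decaying usage sketch (hypothetical rates): decaying(0.3, 0.01, 1e-4)
# could drive e.g. a dropout rate down from 0.3 towards the 0.01 floor as
# updates accumulate.

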
def minibatch_by_words(items, size, count_words=len):
    """Create minibatches of a given number of words."""
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        size_ = size
    items = iter(items)
    while True:
        batch_size = next(size_)
        batch = []
        while batch_size >= 0:
            try:
                doc, gold = next(items)
            except StopIteration:
                if batch:
                    yield batch
                return
            batch_size -= count_words(doc)
            batch.append((doc, gold))
        if batch:
            yield batch


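# minibatch_by_words usage sketch: items are (doc, gold) pairs and the size
# budget counts words rather than documents, so long and short documents
# balance out, e.g.
#     for batch in minibatch_by_words(train_data, size=5000):
#         ...

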
def itershuffle(iterable, bufsize=1000):
    """Shuffle an iterator. This works by holding `bufsize` items back
    and yielding them sometime later. Obviously, this is not unbiased –
    but should be good enough for batching. Larger bufsize means less bias.
    From https://gist.github.com/andres-erbsen/1307752

    iterable (iterable): Iterator to shuffle.
    bufsize (int): Items to hold back.
    YIELDS (iterable): The shuffled iterator.
    """
    iterable = iter(iterable)
    buf = []
    try:
        while True:
            for i in range(random.randint(1, bufsize-len(buf))):
                buf.append(next(iterable))
            random.shuffle(buf)
            for i in range(random.randint(1, bufsize)):
                if buf:
                    yield buf.pop()
                else:
                    break
    except StopIteration:
        random.shuffle(buf)
        while buf:
            yield buf.pop()
        # Return rather than re-raising StopIteration: under PEP 479
        # (default from Python 3.7), raising StopIteration inside a
        # generator becomes a RuntimeError.
        return


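# itershuffle usage sketch: this is how streamed training data can be
# shuffled without loading it all into memory, e.g. over the per-document
# msgpack paths described in the commit message above:
#     for path in itershuffle(doc_paths, bufsize=1000):
#         ...

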
def read_json(location):
    """Open and load JSON from file.

    location (Path): Path to JSON file.
    RETURNS (dict): Loaded JSON content.
    """
    location = ensure_path(location)
    with location.open('r', encoding='utf8') as f:
        return ujson.load(f)


def get_raw_input(description, default=False):
    """Get user input from the command line via raw_input / input.

    description (unicode): Text to display before prompt.
    default (unicode or False/None): Default value to display with prompt.
    RETURNS (unicode): User input.
    """
    additional = ' (default: %s)' % default if default else ''
    prompt = ' %s%s: ' % (description, additional)
    user_input = input_(prompt)
    return user_input


def to_bytes(getters, exclude):
    serialized = OrderedDict()
    for key, getter in getters.items():
        if key not in exclude:
            serialized[key] = getter()
    return msgpack.dumps(serialized, use_bin_type=True, encoding='utf8')


def from_bytes(bytes_data, setters, exclude):
    msg = msgpack.loads(bytes_data, encoding='utf8')
    for key, setter in setters.items():
        if key not in exclude and key in msg:
            setter(msg[key])
    return msg


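# to_bytes/from_bytes usage sketch (hypothetical getters/setters): both are
# driven by ordered dicts of callbacks keyed by section name, e.g.
#     data = to_bytes(OrderedDict([('vocab', lambda: vocab.to_bytes())]),
#                     exclude=[])
#     from_bytes(data,
#                OrderedDict([('vocab', lambda b: vocab.from_bytes(b))]),
#                exclude=[])

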
def to_disk(path, writers, exclude):
    path = ensure_path(path)
    if not path.exists():
        path.mkdir()
    for key, writer in writers.items():
        if key not in exclude:
            writer(path / key)
    return path


def from_disk(path, readers, exclude):
    path = ensure_path(path)
    for key, reader in readers.items():
        if key not in exclude:
            reader(path / key)
    return path


def deprecated(message, filter='always'):
    """Show a deprecation warning.

    message (unicode): The message to display.
    filter (unicode): Filter value.
    """
    stack = inspect.stack()[-1]
    with warnings.catch_warnings():
        warnings.simplefilter(filter, DeprecationWarning)
        warnings.warn_explicit(message, DeprecationWarning, stack[1], stack[2])


def print_table(data, title=None):
    """Print data in table format.

    data (dict or list of tuples): Label/value pairs.
    title (unicode or None): Title, will be printed above.
    """
    if isinstance(data, dict):
        data = list(data.items())
    tpl_row = ' {:<15}' * len(data[0])
    table = '\n'.join([tpl_row.format(l, unicode_(v)) for l, v in data])
    if title:
        print('\n \033[93m{}\033[0m'.format(title))
    print('\n{}\n'.format(table))


def print_markdown(data, title=None):
    """Print data in GitHub-flavoured Markdown format for issues etc.

    data (dict or list of tuples): Label/value pairs.
    title (unicode or None): Title, will be rendered as headline 2.
    """
    def excl_value(value):
        # contains path, i.e. personal info
        return isinstance(value, basestring_) and Path(value).exists()

    if isinstance(data, dict):
        data = list(data.items())
    markdown = ["* **{}:** {}".format(l, unicode_(v))
                for l, v in data if not excl_value(v)]
    if title:
        print("\n## {}".format(title))
    print('\n{}\n'.format('\n'.join(markdown)))


def prints(*texts, **kwargs):
    """Print formatted message (manual ANSI escape sequences to avoid
    dependency)

    *texts (unicode): Texts to print. Each argument is rendered as paragraph.
    **kwargs: 'title' becomes coloured headline. exits=True performs sys exit.
    """
    exits = kwargs.get('exits', None)
    title = kwargs.get('title', None)
    title = '\033[93m{}\033[0m\n'.format(_wrap(title)) if title else ''
    message = '\n\n'.join([_wrap(text) for text in texts])
    print('\n{}{}\n'.format(title, message))
    if exits is not None:
        sys.exit(exits)


def _wrap(text, wrap_max=80, indent=4):
    """Wrap text at given width using textwrap module.

    text (unicode): Text to wrap. If it's a Path, it's converted to string.
    wrap_max (int): Maximum line length (indent is deducted).
    indent (int): Number of spaces for indentation.
    RETURNS (unicode): Wrapped text.
    """
    indent = indent * ' '
    wrap_width = wrap_max - len(indent)
    if isinstance(text, Path):
        text = path2str(text)
    return textwrap.fill(text, width=wrap_width, initial_indent=indent,
                         subsequent_indent=indent, break_long_words=False,
                         break_on_hyphens=False)


def minify_html(html):
    """Perform a template-specific, rudimentary HTML minification for displaCy.
    Disclaimer: NOT a general-purpose solution, only removes indentation and
    newlines.

    html (unicode): Markup to minify.
    RETURNS (unicode): "Minified" HTML.
    """
    # Per the docstring, only indentation (runs of four spaces) and newlines
    # are stripped, not all whitespace.
    return html.strip().replace('    ', '').replace('\n', '')


def use_gpu(gpu_id):
    try:
        import cupy.cuda.device
    except ImportError:
        return None
    from thinc.neural.ops import CupyOps
    device = cupy.cuda.device.Device(gpu_id)
    device.use()
    Model.ops = CupyOps()
    Model.Ops = CupyOps
    return device


def fix_random_seed(seed=0):
    random.seed(seed)
    numpy.random.seed(seed)