spaCy/spacy/compat.py

146 lines
3.5 KiB
Python
Raw Normal View History

# coding: utf8
from __future__ import unicode_literals
import os
import sys
2017-04-15 14:05:15 +03:00
import ujson
2017-07-25 19:57:59 +03:00
import itertools
2017-10-12 23:22:04 +03:00
import locale
2017-05-31 16:25:21 +03:00
from thinc.neural.util import copy_array
try:
import cPickle as pickle
except ImportError:
import pickle
try:
import copy_reg
except ImportError:
import copyreg as copy_reg
2017-05-18 15:12:45 +03:00
try:
from cupy.cuda.stream import Stream as CudaStream
except ImportError:
CudaStream = None
try:
import cupy
except ImportError:
cupy = None
try:
2017-10-27 15:33:42 +03:00
from thinc.neural.optimizers import Optimizer
except ImportError:
2017-10-27 15:33:42 +03:00
from thinc.neural.optimizers import Adam as Optimizer
2017-05-18 15:12:45 +03:00
# Self-assignments of names imported at the top of the module. Presumably
# these mark the imports as deliberate re-exports of this compat layer
# (and keep linters from reporting them as unused) -- confirm with callers.
pickle = pickle
copy_reg = copy_reg
CudaStream = CudaStream
cupy = cupy
copy_array = copy_array

# Use itertools.izip where it exists; otherwise fall back to the builtin zip.
izip = getattr(itertools, "izip", zip)

# Interpreter version flags.
# See: https://github.com/benjaminp/six/blob/master/six.py
is_python2 = sys.version_info[0] == 2
is_python3 = sys.version_info[0] == 3
# True for every Python 2 release and for Python 3.0 - 3.4.
is_python_pre_3_5 = sys.version_info[:2] < (3, 5)

# Operating system flags, derived from sys.platform.
is_windows = sys.platform.startswith("win")
is_linux = sys.platform.startswith("linux")
is_osx = sys.platform == "darwin"
# Version-dependent aliases: each branch binds the same set of names
# (bytes_, unicode_, basestring_, input_, json_dumps, path2str) so the rest
# of the code base can use them without checking the interpreter version.
if is_python2:
    bytes_ = str
    unicode_ = unicode  # noqa: F821
    basestring_ = basestring  # noqa: F821
    input_ = raw_input  # noqa: F821

    def json_dumps(data):
        """Serialize `data` to an indented JSON unicode string.

        The ujson output is decoded as UTF-8 so a unicode object is returned.
        """
        return ujson.dumps(
            data, indent=2, escape_forward_slashes=False
        ).decode("utf8")

    def path2str(path):
        """Convert `path` to a unicode string via str() and a UTF-8 decode."""
        return str(path).decode("utf8")


elif is_python3:
    bytes_ = bytes
    unicode_ = str
    basestring_ = str
    input_ = input

    def json_dumps(data):
        """Serialize `data` to an indented JSON string using ujson."""
        return ujson.dumps(data, indent=2, escape_forward_slashes=False)

    def path2str(path):
        """Convert `path` to a string via str()."""
        return str(path)
def b_to_str(b_str):
    """Convert a byte string to the interpreter's native str type.

    b_str: the byte string to convert.
    RETURNS: `b_str` unchanged on Python 2, otherwise `b_str` decoded as UTF-8.
    """
    if not is_python2:
        # important: the encoding has to be passed explicitly, otherwise the
        # result is the literal text "b'...'" instead of the decoded contents
        return str(b_str, encoding="utf8")
    return b_str
2017-05-31 23:21:44 +03:00
def getattr_(obj, name, *default):
    """getattr() wrapper that accepts a bytes attribute name on Python 3.

    obj: object to read the attribute from.
    name: attribute name; a bytes value is decoded as UTF-8 on Python 3.
    *default: optional fallback value, forwarded to getattr() unchanged.
    RETURNS: whatever the builtin getattr() returns for the resolved name.
    """
    needs_decoding = is_python3 and isinstance(name, bytes)
    attr_name = name.decode("utf8") if needs_decoding else name
    return getattr(obj, attr_name, *default)
def symlink_to(orig, dest):
    """Make `orig` a symbolic link that points to `dest`.

    orig: location of the link to create. On non-Windows platforms it must
        provide a `symlink_to` method (e.g. a pathlib.Path).
    dest: the location the link should point to.
    RAISES: subprocess.CalledProcessError on Windows when `mklink` exits with
        a non-zero status; on other platforms, whatever `orig.symlink_to`
        raises.
    """
    if is_windows:
        import subprocess

        # `mklink /d <link> <target>` creates a directory symlink; it is run
        # through the shell (shell=True). check_call is used instead of call
        # so that a failed link creation is reported to the caller instead of
        # being silently ignored, matching the other branch, which raises.
        subprocess.check_call(
            ["mklink", "/d", path2str(orig), path2str(dest)], shell=True
        )
    else:
        orig.symlink_to(dest)
2018-11-26 15:27:41 +03:00
def symlink_remove(link):
    """Delete the symlink located at `link`.

    link: path of the symlink to remove; converted with path2str().
    """
    link_path = path2str(link)
    # https://stackoverflow.com/q/26554135/6400719
    is_windows_dir_link = os.path.isdir(link_path) and is_windows
    if not is_windows_dir_link:
        os.unlink(link_path)
        return
    # this should only be on Py2.7 and windows
    os.rmdir(link_path)
2018-11-26 15:27:41 +03:00
def is_config(python2=None, python3=None, windows=None, linux=None, osx=None):
    """Check whether the current interpreter and platform match the given flags.

    Every argument left as None is ignored. Any other value has to equal the
    corresponding module-level flag (is_python2, is_python3, is_windows,
    is_linux, is_osx).

    RETURNS (bool): True if all supplied arguments match, otherwise False.
    """
    requested_and_actual = (
        (python2, is_python2),
        (python3, is_python3),
        (windows, is_windows),
        (linux, is_linux),
        (osx, is_osx),
    )
    return all(
        requested in (None, actual) for requested, actual in requested_and_actual
    )
def normalize_string_keys(old):
    """Build a new dict from `old` in which bytes keys are decoded as UTF-8.

    old (dict): the dictionary to normalize; it is not modified.
    RETURNS (dict): a new dictionary with the same values, where every key
        that was an instance of `bytes_` has been decoded and all other keys
        are kept as they are.
    """
    return {
        (key.decode("utf8") if isinstance(key, bytes_) else key): value
        for key, value in old.items()
    }
2017-08-18 22:56:47 +03:00
def import_file(name, loc):
    """Load the Python source file at `loc` as a module called `name`.

    name: the name to give the loaded module.
    loc: location of the source file; converted to a string with str().
    RETURNS: the loaded module object.
    """
    loc = str(loc)
    if not is_python_pre_3_5:
        import importlib.util

        spec = importlib.util.spec_from_file_location(name, loc)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    # Interpreters flagged as pre-3.5 load the file through the imp module.
    import imp

    return imp.load_source(name, loc)
2017-10-12 23:22:04 +03:00
2018-11-26 15:27:41 +03:00
def locale_escape(string, errors="replace"):
    """
    Mangle characters that the locale's preferred encoding cannot represent,
    so the text can be shown on terminals with a limited charset (e.g. ASCII).

    string: the unicode text to make safe for output.
    errors: error handler passed to `encode` for unsupported characters
        (the default "replace" substitutes "?").
    RETURNS: the text after a round trip through the preferred encoding.
    """
    encoding = locale.getpreferredencoding()
    # Decode with the same codec that was used for encoding. Decoding the
    # bytes as UTF-8 regardless of the locale raises UnicodeDecodeError (or
    # produces wrong characters) on non-UTF-8 locales such as cp1252.
    return string.encode(encoding, errors).decode(encoding)