mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Merge branch 'develop' of https://github.com/explosion/spaCy into develop
This commit is contained in:
commit
6c51cd77b4
|
@ -6,6 +6,8 @@ import ftfy
|
|||
import sys
|
||||
import ujson
|
||||
|
||||
import thinc.neural.util
|
||||
|
||||
try:
|
||||
import cPickle as pickle
|
||||
except ImportError:
|
||||
|
@ -32,6 +34,7 @@ copy_reg = copy_reg
|
|||
CudaStream = CudaStream
|
||||
cupy = cupy
|
||||
fix_text = ftfy.fix_text
|
||||
copy_array = thinc.neural.util.copy_array
|
||||
|
||||
is_python2 = six.PY2
|
||||
is_python3 = six.PY3
|
||||
|
@ -71,3 +74,16 @@ def is_config(python2=None, python3=None, windows=None, linux=None, osx=None):
|
|||
(windows == None or windows == is_windows) and
|
||||
(linux == None or linux == is_linux) and
|
||||
(osx == None or osx == is_osx))
|
||||
|
||||
|
||||
def normalize_string_keys(old):
|
||||
'''Given a dictionary, make sure keys are unicode strings, not bytes.'''
|
||||
new = {}
|
||||
for key, value in old:
|
||||
if isinstance(key, bytes_):
|
||||
new[key.decode('utf8')] = value
|
||||
else:
|
||||
new[key] = value
|
||||
return new
|
||||
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import ujson
|
|||
|
||||
from .symbols import ORTH
|
||||
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
|
||||
from .compat import copy_array, normalize_string_keys
|
||||
|
||||
|
||||
LANGUAGES = {}
|
||||
|
@ -243,6 +244,12 @@ def itershuffle(iterable, bufsize=1000):
|
|||
raise StopIteration
|
||||
|
||||
|
||||
_PRINT_ENV = False
|
||||
def set_env_log(value):
|
||||
global _PRINT_ENV
|
||||
_PRINT_ENV = value
|
||||
|
||||
|
||||
def env_opt(name, default=None):
|
||||
if type(default) is float:
|
||||
type_convert = float
|
||||
|
@ -250,14 +257,17 @@ def env_opt(name, default=None):
|
|||
type_convert = int
|
||||
if 'SPACY_' + name.upper() in os.environ:
|
||||
value = type_convert(os.environ['SPACY_' + name.upper()])
|
||||
print(name, "=", repr(value), "via", "$SPACY_" + name.upper())
|
||||
if _PRINT_ENV:
|
||||
print(name, "=", repr(value), "via", "$SPACY_" + name.upper())
|
||||
return value
|
||||
elif name in os.environ:
|
||||
value = type_convert(os.environ[name])
|
||||
print(name, "=", repr(value), "via", '$' + name)
|
||||
if _PRINT_ENV:
|
||||
print(name, "=", repr(value), "via", '$' + name)
|
||||
return value
|
||||
else:
|
||||
print(name, '=', repr(default), "by default")
|
||||
if _PRINT_ENV:
|
||||
print(name, '=', repr(default), "by default")
|
||||
return default
|
||||
|
||||
|
||||
|
@ -451,7 +461,9 @@ def model_to_bytes(model):
|
|||
i = 0
|
||||
for layer in queue:
|
||||
if hasattr(layer, '_mem'):
|
||||
weights.append({'dims': dict(getattr(layer, '_dims', {})), 'params': []})
|
||||
weights.append({
|
||||
'dims': normalize_string_keys(getattr(layer, '_dims', {})),
|
||||
'params': []})
|
||||
if hasattr(layer, 'seed'):
|
||||
weights[-1]['seed'] = layer.seed
|
||||
|
||||
|
@ -488,7 +500,7 @@ def model_from_bytes(model, bytes_data):
|
|||
setattr(layer, dim, value)
|
||||
for param in weights[i]['params']:
|
||||
dest = getattr(layer, param['name'])
|
||||
dest[:] = param['value']
|
||||
copy_array(dest, param['value'])
|
||||
i += 1
|
||||
if hasattr(layer, '_layers'):
|
||||
queue.extend(layer._layers)
|
||||
|
|
Loading…
Reference in New Issue
Block a user