mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
Silence env_opt, and fix serialization for GPU
This commit is contained in:
parent
53a3824334
commit
92f9e5cc9a
|
@ -21,6 +21,7 @@ import ujson
|
||||||
|
|
||||||
from .symbols import ORTH
|
from .symbols import ORTH
|
||||||
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
|
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
|
||||||
|
from .compat import copy_array, normalize_string_keys
|
||||||
|
|
||||||
|
|
||||||
LANGUAGES = {}
|
LANGUAGES = {}
|
||||||
|
@ -242,6 +243,12 @@ def itershuffle(iterable, bufsize=1000):
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
|
|
||||||
|
|
||||||
|
_PRINT_ENV = False
|
||||||
|
def set_env_log(value):
|
||||||
|
global _PRINT_ENV
|
||||||
|
_PRINT_ENV = value
|
||||||
|
|
||||||
|
|
||||||
def env_opt(name, default=None):
|
def env_opt(name, default=None):
|
||||||
if type(default) is float:
|
if type(default) is float:
|
||||||
type_convert = float
|
type_convert = float
|
||||||
|
@ -249,13 +256,16 @@ def env_opt(name, default=None):
|
||||||
type_convert = int
|
type_convert = int
|
||||||
if 'SPACY_' + name.upper() in os.environ:
|
if 'SPACY_' + name.upper() in os.environ:
|
||||||
value = type_convert(os.environ['SPACY_' + name.upper()])
|
value = type_convert(os.environ['SPACY_' + name.upper()])
|
||||||
|
if _PRINT_ENV:
|
||||||
print(name, "=", repr(value), "via", "$SPACY_" + name.upper())
|
print(name, "=", repr(value), "via", "$SPACY_" + name.upper())
|
||||||
return value
|
return value
|
||||||
elif name in os.environ:
|
elif name in os.environ:
|
||||||
value = type_convert(os.environ[name])
|
value = type_convert(os.environ[name])
|
||||||
|
if _PRINT_ENV:
|
||||||
print(name, "=", repr(value), "via", '$' + name)
|
print(name, "=", repr(value), "via", '$' + name)
|
||||||
return value
|
return value
|
||||||
else:
|
else:
|
||||||
|
if _PRINT_ENV:
|
||||||
print(name, '=', repr(default), "by default")
|
print(name, '=', repr(default), "by default")
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
@ -432,7 +442,9 @@ def model_to_bytes(model):
|
||||||
i = 0
|
i = 0
|
||||||
for layer in queue:
|
for layer in queue:
|
||||||
if hasattr(layer, '_mem'):
|
if hasattr(layer, '_mem'):
|
||||||
weights.append({'dims': dict(getattr(layer, '_dims', {})), 'params': []})
|
weights.append({
|
||||||
|
'dims': normalize_string_keys(getattr(layer, '_dims', {})),
|
||||||
|
'params': []})
|
||||||
if hasattr(layer, 'seed'):
|
if hasattr(layer, 'seed'):
|
||||||
weights[-1]['seed'] = layer.seed
|
weights[-1]['seed'] = layer.seed
|
||||||
|
|
||||||
|
@ -469,7 +481,7 @@ def model_from_bytes(model, bytes_data):
|
||||||
setattr(layer, dim, value)
|
setattr(layer, dim, value)
|
||||||
for param in weights[i]['params']:
|
for param in weights[i]['params']:
|
||||||
dest = getattr(layer, param['name'])
|
dest = getattr(layer, param['name'])
|
||||||
dest[:] = param['value']
|
copy_array(dest, param['value'])
|
||||||
i += 1
|
i += 1
|
||||||
if hasattr(layer, '_layers'):
|
if hasattr(layer, '_layers'):
|
||||||
queue.extend(layer._layers)
|
queue.extend(layer._layers)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user