mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Silence env_opt, and fix serialization for GPU
This commit is contained in:
parent
53a3824334
commit
92f9e5cc9a
|
@ -21,6 +21,7 @@ import ujson
|
|||
|
||||
from .symbols import ORTH
|
||||
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
|
||||
from .compat import copy_array, normalize_string_keys
|
||||
|
||||
|
||||
LANGUAGES = {}
|
||||
|
@ -242,6 +243,12 @@ def itershuffle(iterable, bufsize=1000):
|
|||
raise StopIteration
|
||||
|
||||
|
||||
_PRINT_ENV = False
|
||||
def set_env_log(value):
|
||||
global _PRINT_ENV
|
||||
_PRINT_ENV = value
|
||||
|
||||
|
||||
def env_opt(name, default=None):
|
||||
if type(default) is float:
|
||||
type_convert = float
|
||||
|
@ -249,13 +256,16 @@ def env_opt(name, default=None):
|
|||
type_convert = int
|
||||
if 'SPACY_' + name.upper() in os.environ:
|
||||
value = type_convert(os.environ['SPACY_' + name.upper()])
|
||||
if _PRINT_ENV:
|
||||
print(name, "=", repr(value), "via", "$SPACY_" + name.upper())
|
||||
return value
|
||||
elif name in os.environ:
|
||||
value = type_convert(os.environ[name])
|
||||
if _PRINT_ENV:
|
||||
print(name, "=", repr(value), "via", '$' + name)
|
||||
return value
|
||||
else:
|
||||
if _PRINT_ENV:
|
||||
print(name, '=', repr(default), "by default")
|
||||
return default
|
||||
|
||||
|
@ -432,7 +442,9 @@ def model_to_bytes(model):
|
|||
i = 0
|
||||
for layer in queue:
|
||||
if hasattr(layer, '_mem'):
|
||||
weights.append({'dims': dict(getattr(layer, '_dims', {})), 'params': []})
|
||||
weights.append({
|
||||
'dims': normalize_string_keys(getattr(layer, '_dims', {})),
|
||||
'params': []})
|
||||
if hasattr(layer, 'seed'):
|
||||
weights[-1]['seed'] = layer.seed
|
||||
|
||||
|
@ -469,7 +481,7 @@ def model_from_bytes(model, bytes_data):
|
|||
setattr(layer, dim, value)
|
||||
for param in weights[i]['params']:
|
||||
dest = getattr(layer, param['name'])
|
||||
dest[:] = param['value']
|
||||
copy_array(dest, param['value'])
|
||||
i += 1
|
||||
if hasattr(layer, '_layers'):
|
||||
queue.extend(layer._layers)
|
||||
|
|
Loading…
Reference in New Issue
Block a user