mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Set cupy.random seed in fix_random_seed helper
This commit is contained in:
parent
ffdd5e964f
commit
11a29af751
|
@ -15,6 +15,11 @@ import itertools
|
|||
import numpy.random
|
||||
import srsly
|
||||
|
||||
try:
|
||||
import cupy.random
|
||||
except ImportError:
|
||||
cupy = None
|
||||
|
||||
from .symbols import ORTH
|
||||
from .compat import cupy, CudaStream, path2str, basestring_, unicode_
|
||||
from .compat import import_file
|
||||
|
@ -598,6 +603,8 @@ def use_gpu(gpu_id):
|
|||
def fix_random_seed(seed=0):
|
||||
random.seed(seed)
|
||||
numpy.random.seed(seed)
|
||||
if cupy is not None:
|
||||
cupy.random.seed(seed)
|
||||
|
||||
|
||||
class SimpleFrozenDict(dict):
|
||||
|
|
Loading…
Reference in New Issue
Block a user