Don't fix random seeds on import

This commit is contained in:
Johannes Dollinger 2018-02-13 12:42:23 +01:00
parent c63e99da8a
commit bf94c13382
3 changed files with 9 additions and 11 deletions

View File

@ -3,8 +3,6 @@ from __future__ import unicode_literals, division, print_function
import plac
from timeit import default_timer as timer
import random
import numpy.random
from ..gold import GoldCorpus
from ..util import prints
@ -12,10 +10,6 @@ from .. import util
from .. import displacy
random.seed(0)
numpy.random.seed(0)
@plac.annotations(
model=("model name or path", "positional", None, str),
data_path=("location of JSON-formatted evaluation data", "positional",
@ -31,6 +25,8 @@ def evaluate(model, data_path, gpu_id=-1, gold_preproc=False, displacy_path=None
Evaluate a model. To render a sample of parses in a HTML file, set an
output directory as the displacy_path argument.
"""
util.fix_random_seed()
if gpu_id >= 0:
util.use_gpu(gpu_id)
util.set_env_log(False)

View File

@ -6,8 +6,6 @@ from pathlib import Path
import tqdm
from thinc.neural._classes.model import Model
from timeit import default_timer as timer
import random
import numpy.random
from ..gold import GoldCorpus, minibatch
from ..util import prints
@ -16,9 +14,6 @@ from .. import about
from .. import displacy
from ..compat import json_dumps
random.seed(0)
numpy.random.seed(0)
@plac.annotations(
lang=("model language", "positional", None, str),
@ -45,6 +40,7 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
"""
Train a model. Expects data in spaCy's JSON format.
"""
util.fix_random_seed()
util.set_env_log(True)
n_sents = n_sents or None
output_path = util.ensure_path(output_dir)

View File

@ -17,6 +17,7 @@ from thinc.neural._classes.model import Model
import functools
import cytoolz
import itertools
import numpy as np
from .symbols import ORTH
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
@ -623,3 +624,8 @@ def use_gpu(gpu_id):
Model.ops = CupyOps()
Model.Ops = CupyOps
return device
def fix_random_seed(seed=0):
random.seed(0)
np.random.seed(0)