This commit is contained in:
Richard Liaw 2020-07-06 16:15:31 -07:00
parent 3141ea0931
commit 3941367742
2 changed files with 1 additions and 40 deletions

View File

@ -171,7 +171,7 @@ def train_cli(
else:
if use_gpu >= 0:
msg.info(f"Using GPU: {str(use_gpu)}")
util.use_gpu(use_gpu)
require_gpu(use_gpu)
else:
msg.info("Using CPU")
train(**train_args)

View File

@ -906,45 +906,6 @@ def escape_html(text):
return text
<<<<<<< HEAD
def use_gpu(gpu_id):
return require_gpu(gpu_id)
def gpu_is_available():
try:
cupy.cuda.runtime.getDeviceCount()
return True
except cupy.cuda.runtime.CUDARuntimeError:
return False
def fix_random_seed(seed=0):
random.seed(seed)
numpy.random.seed(seed)
if cupy is not None and gpu_is_available():
cupy.random.seed(seed)
def get_serialization_exclude(serializers, exclude, kwargs):
"""Helper function to validate serialization args and manage transition from
keyword arguments (pre v2.1) to exclude argument.
"""
exclude = list(exclude)
# Split to support file names like meta.json
options = [name.split(".")[0] for name in serializers]
for key, value in kwargs.items():
if key in ("vocab",) and value is False:
warnings.warn(Warnings.W015.format(arg=key), DeprecationWarning)
exclude.append(key)
elif key.split(".")[0] in options:
raise ValueError(Errors.E128.format(arg=key))
# TODO: user warning?
return exclude
=======
>>>>>>> 19d42f42de30ba57e17427798ea2562cdab2c9f8
def get_words_and_spaces(words, text):
if "".join("".join(words).split()) != "".join(text.split()):
raise ValueError(Errors.E194.format(text=text, words=words))