Type annotations

This commit is contained in:
Daniël de Kok 2023-04-18 17:43:33 +02:00
parent a425808bd4
commit aa0783cf7f

View File

@ -94,21 +94,6 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
return nlp
def _set_gpu_allocator_from_config(config, use_gpu):
if "gpu_allocator" not in config["training"]:
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
def _set_seed_from_config(config):
if "seed" not in config["training"]:
raise ValueError(Errors.E1015.format(value="[training] seed"))
if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"])
def init_nlp_distill(
config: Config, teacher: "Language", *, use_gpu: int = -1
) -> "Language":
@ -430,3 +415,18 @@ def ensure_shape(vectors_loc):
yield from lines2
lines2.close()
lines.close()
def _set_gpu_allocator_from_config(config: Config, use_gpu: int):
if "gpu_allocator" not in config["training"]:
raise ValueError(Errors.E1015.format(value="[training] gpu_allocator"))
allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
def _set_seed_from_config(config: Config):
if "seed" not in config["training"]:
raise ValueError(Errors.E1015.format(value="[training] seed"))
if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"])