Fix use_pytorch_for_gpu_memory

Matthew Honnibal 2020-09-01 00:41:38 +02:00
parent 9130094199
commit ec660e3131

@@ -77,6 +77,9 @@ def train(
     )
     if config.get("training", {}).get("seed") is not None:
         fix_random_seed(config["training"]["seed"])
+    if config.get("system", {}).get("use_pytorch_for_gpu_memory"):
+        # It feels kind of weird to not have a default for this.
+        use_pytorch_for_gpu_memory()
     # Use original config here before it's resolved to functions
     sourced_components = get_sourced_components(config)
     with show_validation_error(config_path):
@@ -85,9 +88,6 @@ def train(
         util.load_vectors_into_model(nlp, config["training"]["vectors"])
     verify_config(nlp)
     raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
-    if config.get("system", {}).get("use_pytorch_for_gpu_memory"):
-        # It feels kind of weird to not have a default for this.
-        use_pytorch_for_gpu_memory()
     T_cfg = config["training"]
     optimizer = T_cfg["optimizer"]
     train_corpus = T_cfg["train_corpus"]
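For context: use_pytorch_for_gpu_memory() tells Thinc to route cupy's GPU allocations through PyTorch's allocator, so the two libraries draw from a single memory pool instead of each caching blocks the other cannot see. Swapping the allocator only affects allocations made after the call, which is presumably why the commit moves it ahead of pipeline construction: by the time it ran after load_from_paths(config), the nlp object had already been created and cupy's default pool had already claimed device memory. Below is a minimal sketch of the underlying mechanism, not Thinc's actual implementation; the pytorch_allocator name is hypothetical, and only the cupy/torch calls shown are real APIs.

# Sketch: replace cupy's allocator with one that obtains device
# memory from PyTorch's caching allocator (hypothetical helper,
# illustrating the technique rather than Thinc's real code).
import cupy
import torch


def pytorch_allocator(size_in_bytes: int) -> cupy.cuda.MemoryPointer:
    # Ask PyTorch for the memory (float32 = 4 bytes per element,
    # rounded up so cupy never gets fewer bytes than it asked for).
    n_floats = (size_in_bytes + 3) // 4
    tensor = torch.zeros((n_floats,), device="cuda", requires_grad=False)
    # Wrap the raw device pointer for cupy without transferring
    # ownership; passing `tensor` as the owner keeps the PyTorch
    # allocation alive for as long as cupy holds the pointer.
    memory = cupy.cuda.UnownedMemory(tensor.data_ptr(), size_in_bytes, tensor)
    return cupy.cuda.MemoryPointer(memory, 0)


# Only affects allocations made *after* this call, which is why the
# commit installs the allocator before the pipeline is constructed.
cupy.cuda.set_allocator(pytorch_allocator)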