mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Fix use_pytorch_for_gpu_memory
This commit is contained in:
parent
9130094199
commit
ec660e3131
|
@ -77,6 +77,9 @@ def train(
|
||||||
)
|
)
|
||||||
if config.get("training", {}).get("seed") is not None:
|
if config.get("training", {}).get("seed") is not None:
|
||||||
fix_random_seed(config["training"]["seed"])
|
fix_random_seed(config["training"]["seed"])
|
||||||
|
if config.get("system", {}).get("use_pytorch_for_gpu_memory"):
|
||||||
|
# It feels kind of weird to not have a default for this.
|
||||||
|
use_pytorch_for_gpu_memory()
|
||||||
# Use original config here before it's resolved to functions
|
# Use original config here before it's resolved to functions
|
||||||
sourced_components = get_sourced_components(config)
|
sourced_components = get_sourced_components(config)
|
||||||
with show_validation_error(config_path):
|
with show_validation_error(config_path):
|
||||||
|
@ -85,9 +88,6 @@ def train(
|
||||||
util.load_vectors_into_model(nlp, config["training"]["vectors"])
|
util.load_vectors_into_model(nlp, config["training"]["vectors"])
|
||||||
verify_config(nlp)
|
verify_config(nlp)
|
||||||
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
|
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
|
||||||
if config.get("system", {}).get("use_pytorch_for_gpu_memory"):
|
|
||||||
# It feels kind of weird to not have a default for this.
|
|
||||||
use_pytorch_for_gpu_memory()
|
|
||||||
T_cfg = config["training"]
|
T_cfg = config["training"]
|
||||||
optimizer = T_cfg["optimizer"]
|
optimizer = T_cfg["optimizer"]
|
||||||
train_corpus = T_cfg["train_corpus"]
|
train_corpus = T_cfg["train_corpus"]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user