mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Update train.py
This commit is contained in:
parent
796f6c52d1
commit
fa3c98f8b3
|
@ -88,7 +88,7 @@ def train(
|
||||||
require_gpu(use_gpu)
|
require_gpu(use_gpu)
|
||||||
else:
|
else:
|
||||||
msg.info("Using CPU")
|
msg.info("Using CPU")
|
||||||
raw_text, tag_map, weights_data = load_from_paths(config)
|
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
|
||||||
if config["training"]["seed"] is not None:
|
if config["training"]["seed"] is not None:
|
||||||
fix_random_seed(config["training"]["seed"])
|
fix_random_seed(config["training"]["seed"])
|
||||||
if config["training"].get("use_pytorch_for_gpu_memory"):
|
if config["training"].get("use_pytorch_for_gpu_memory"):
|
||||||
|
@ -498,7 +498,7 @@ def load_from_paths(config):
|
||||||
msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
|
msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
|
||||||
with init_tok2vec.open("rb") as file_:
|
with init_tok2vec.open("rb") as file_:
|
||||||
weights_data = file_.read()
|
weights_data = file_.read()
|
||||||
return raw_text, tag_map, weights_data
|
return raw_text, tag_map, morph_rules, weights_data
|
||||||
|
|
||||||
|
|
||||||
def verify_cli_args(
|
def verify_cli_args(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user