Tmp refactor train func for ray

Matthew Honnibal 2020-07-16 03:39:42 +02:00
parent 3bccf8b954
commit 874b34e5d4


@@ -156,20 +156,35 @@ def train_cli(
         with init_tok2vec.open("rb") as file_:
             weights_data = file_.read()
-    train_args = dict(
-        config_path=config_path,
-        data_paths={"train": train_path, "dev": dev_path},
-        output_path=output_path,
-        raw_text=raw_text,
-        tag_map=tag_map,
-        weights_data=weights_data,
-        omit_extra_lookups=omit_extra_lookups
-    )
     if num_workers and num_workers > 1:
-        from spacy_ray import distributed_setup_and_train
-        distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args)
+        from _ray_async_utils import distributed_setup_and_train
+        distributed_setup_and_train(
+            use_gpu,
+            num_workers,
+            strategy,
+            ray_address,
+            {
+                "config": config_path,
+                "train": train_path,
+                "dev": dev_path,
+                "output": output_path
+            }
+        )
     else:
+        nlp, config = load_nlp_and_config(config_path)
+        corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])
+        train_args = dict(
+            nlp=nlp,
+            config=config,
+            corpus=corpus,
+            output_path=output_path,
+            raw_text=raw_text,
+            tag_map=tag_map,
+            weights_data=weights_data,
+            omit_extra_lookups=omit_extra_lookups
+        )
         if use_gpu >= 0:
             msg.info(f"Using GPU: {use_gpu}")
             require_gpu(use_gpu)
@@ -177,20 +192,8 @@ def train_cli(
             msg.info("Using CPU")
         train(**train_args)
-def train(
-    config_path: Path,
-    data_paths: Dict[str, Path],
-    raw_text: Optional[Path] = None,
-    output_path: Optional[Path] = None,
-    tag_map: Optional[Path] = None,
-    weights_data: Optional[bytes] = None,
-    omit_extra_lookups: bool = False,
-    disable_tqdm: bool = False,
-    remote_optimizer = None,
-    randomization_index: int = 0
-) -> None:
-    msg.info(f"Loading config from: {config_path}")
-    # Read the config first without creating objects, to get to the original nlp_config
+def load_nlp_and_config(config_path):
     config = util.load_config(config_path, create_objects=False)
     if config["training"].get("seed"):
         fix_random_seed(config["training"]["seed"])
@@ -199,16 +202,28 @@ def train(
         use_pytorch_for_gpu_memory()
     nlp_config = config["nlp"]
     config = util.load_config(config_path, create_objects=True)
+    nlp = util.load_model_from_config(nlp_config)
+    return nlp, config
+def train(
+    nlp,
+    config,
+    corpus,
+    raw_text: Optional[Path] = None,
+    output_path: Optional[Path] = None,
+    tag_map: Optional[Path] = None,
+    weights_data: Optional[bytes] = None,
+    omit_extra_lookups: bool = False,
+    disable_tqdm: bool = False,
+    randomization_index: int = 0,
+    num_workers=1
+) -> None:
+    msg.info(f"Loading config from: {config_path}")
+    # Read the config first without creating objects, to get to the original nlp_config
     training = config["training"]
-    msg.info("Creating nlp from config")
-    nlp = util.load_model_from_config(nlp_config)
     optimizer = training["optimizer"]
-    # TODO: is there a cleaner way of doing this, instead of creating
-    # the optimizer twice? are there any problems when doing this?
-    if remote_optimizer:
-        optimizer = remote_optimizer
-    limit = training["limit"]
-    corpus = Corpus(data_paths["train"], data_paths["dev"], limit=limit)
     if "textcat" in nlp_config["pipeline"]:
         verify_textcat_config(nlp, nlp_config)
     if training.get("resume", False):
@@ -275,7 +290,6 @@ def train(
     )
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
     print_row = setup_printer(training, nlp)
     tqdm_args = dict(total=training["eval_frequency"], leave=False, disable=disable_tqdm)
     try:
         progress = tqdm.tqdm(**tqdm_args)
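
For orientation, here is a rough sketch of how the refactored local (non-Ray) path appears to fit together after this commit, using only names that occur in the diff above; the file paths and output directory are placeholder assumptions, not values from the commit:

    # Hypothetical local usage of the refactored helpers; all paths are placeholders.
    from pathlib import Path

    config_path = Path("config.cfg")   # training config file (placeholder)
    train_path = Path("train.spacy")   # training corpus (placeholder)
    dev_path = Path("dev.spacy")       # development corpus (placeholder)

    # load_nlp_and_config() now builds the nlp object and the resolved config in
    # one step, presumably so the same setup can also be run on a Ray worker.
    nlp, config = load_nlp_and_config(config_path)
    corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])
    train(
        nlp=nlp,
        config=config,
        corpus=corpus,
        output_path=Path("output"),    # placeholder output directory
    )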