Mirror of https://github.com/explosion/spaCy.git (synced 2025-03-15 23:52:30 +03:00)

Commit 874b34e5d4 (parent 3bccf8b954): "Tmp refactor train func for ray"
@@ -156,20 +156,35 @@ def train_cli(
         with init_tok2vec.open("rb") as file_:
             weights_data = file_.read()
 
-    train_args = dict(
-        config_path=config_path,
-        data_paths={"train": train_path, "dev": dev_path},
-        output_path=output_path,
-        raw_text=raw_text,
-        tag_map=tag_map,
-        weights_data=weights_data,
-        omit_extra_lookups=omit_extra_lookups
-    )
-
     if num_workers and num_workers > 1:
-        from spacy_ray import distributed_setup_and_train
-        distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args)
+        from _ray_async_utils import distributed_setup_and_train
+        distributed_setup_and_train(
+            use_gpu,
+            num_workers,
+            strategy,
+            ray_address,
+            {
+                "config": config_path,
+                "train": train_path,
+                "dev": dev_path,
+                "output": output_path
+            }
+        )
     else:
+        nlp, config = load_nlp_and_config(config_path)
+        corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])
+
+        train_args = dict(
+            nlp=nlp,
+            config=config,
+            corpus=corpus,
+            output_path=output_path,
+            raw_text=raw_text,
+            tag_map=tag_map,
+            weights_data=weights_data,
+            omit_extra_lookups=omit_extra_lookups
+        )
+
     if use_gpu >= 0:
         msg.info(f"Using GPU: {use_gpu}")
         require_gpu(use_gpu)
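For orientation, the dispatch this hunk introduces in train_cli looks roughly as follows once applied. This is a sketch assembled from the added lines, not the verbatim file; load_nlp_and_config is the helper defined further down in this diff.

# Sketch of the refactored dispatch in train_cli, assembled from this hunk.
if num_workers and num_workers > 1:
    # Ray path: only paths are shipped; presumably each worker rebuilds
    # nlp/config/corpus on its own side rather than receiving live objects.
    from _ray_async_utils import distributed_setup_and_train
    distributed_setup_and_train(
        use_gpu, num_workers, strategy, ray_address,
        {"config": config_path, "train": train_path,
         "dev": dev_path, "output": output_path},
    )
else:
    # Local path: build the heavyweight objects once, then hand them to train().
    nlp, config = load_nlp_and_config(config_path)
    corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])
    train(nlp=nlp, config=config, corpus=corpus, output_path=output_path)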
@@ -177,20 +192,8 @@ def train_cli(
         msg.info("Using CPU")
     train(**train_args)
 
-def train(
-    config_path: Path,
-    data_paths: Dict[str, Path],
-    raw_text: Optional[Path] = None,
-    output_path: Optional[Path] = None,
-    tag_map: Optional[Path] = None,
-    weights_data: Optional[bytes] = None,
-    omit_extra_lookups: bool = False,
-    disable_tqdm: bool = False,
-    remote_optimizer = None,
-    randomization_index: int = 0
-) -> None:
-    msg.info(f"Loading config from: {config_path}")
-    # Read the config first without creating objects, to get to the original nlp_config
+
+def load_nlp_and_config(config_path):
     config = util.load_config(config_path, create_objects=False)
     if config["training"].get("seed"):
         fix_random_seed(config["training"]["seed"])
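The new helper keeps spaCy's existing two-pass config load, visible in the context lines here and in the next hunk. Schematically (util is spacy.util, as used throughout the surrounding file):

# Pass 1: load the config as a plain dict (create_objects=False), so the raw
# [nlp] section can be read before anything is instantiated.
config = util.load_config(config_path, create_objects=False)
nlp_config = config["nlp"]
# Pass 2: load again with create_objects=True, resolving registered functions
# (such as the optimizer read from config["training"]["optimizer"]) into objects.
config = util.load_config(config_path, create_objects=True)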
@@ -199,16 +202,28 @@ def train(
         use_pytorch_for_gpu_memory()
     nlp_config = config["nlp"]
     config = util.load_config(config_path, create_objects=True)
+    nlp = util.load_model_from_config(nlp_config)
+    return nlp, config
+
+
+def train(
+    nlp,
+    config,
+    corpus,
+    raw_text: Optional[Path] = None,
+    output_path: Optional[Path] = None,
+    tag_map: Optional[Path] = None,
+    weights_data: Optional[bytes] = None,
+    omit_extra_lookups: bool = False,
+    disable_tqdm: bool = False,
+    randomization_index: int = 0,
+    num_workers=1
+) -> None:
+    msg.info(f"Loading config from: {config_path}")
+    # Read the config first without creating objects, to get to the original nlp_config
     training = config["training"]
     msg.info("Creating nlp from config")
-    nlp = util.load_model_from_config(nlp_config)
     optimizer = training["optimizer"]
-    # TODO: is there a cleaner way of doing this, instead of creating
-    # the optimizer twice? are there any problems when doing this?
-    if remote_optimizer:
-        optimizer = remote_optimizer
-    limit = training["limit"]
-    corpus = Corpus(data_paths["train"], data_paths["dev"], limit=limit)
     if "textcat" in nlp_config["pipeline"]:
         verify_textcat_config(nlp, nlp_config)
     if training.get("resume", False):
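Stitching this hunk's additions onto the previous one, the complete new helper reads approximately as follows; the elided middle is the unchanged seed and GPU-memory handling shown as context above.

def load_nlp_and_config(config_path):
    # Read the config without creating objects first, to get the raw nlp section.
    config = util.load_config(config_path, create_objects=False)
    if config["training"].get("seed"):
        fix_random_seed(config["training"]["seed"])
    # ... unchanged GPU-memory context lines (use_pytorch_for_gpu_memory) ...
    nlp_config = config["nlp"]
    config = util.load_config(config_path, create_objects=True)
    nlp = util.load_model_from_config(nlp_config)
    return nlp, config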
@@ -275,7 +290,6 @@ def train(
         )
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
     print_row = setup_printer(training, nlp)
-
     tqdm_args = dict(total=training["eval_frequency"], leave=False, disable=disable_tqdm)
     try:
         progress = tqdm.tqdm(**tqdm_args)
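Putting the pieces together, a single-process caller of the refactored API would look roughly like this. A sketch only: it assumes the CLI variables (config_path, train_path, dev_path, output_path) are in scope, and since this is a temporary commit, the new train() body still references config_path in its msg.info line even though it no longer receives it.

# Hypothetical local invocation of the refactored pieces from this diff.
nlp, config = load_nlp_and_config(config_path)
corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])
train(
    nlp=nlp,
    config=config,
    corpus=corpus,
    output_path=output_path,
    disable_tqdm=False,     # a Ray worker would likely pass True per worker
    randomization_index=0,  # presumably a per-worker shuffling offset (not shown in this diff)
)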