Mirror of https://github.com/explosion/spaCy.git, synced 2025-03-14 15:12:15 +03:00
Tmp refactor train func for ray

parent 3bccf8b954
commit 874b34e5d4
@@ -156,20 +156,35 @@ def train_cli(

```python
        with init_tok2vec.open("rb") as file_:
            weights_data = file_.read()

    # Removed: train_args used to pack the raw config and data paths,
    # matching the old train(config_path, data_paths, ...) signature.
    train_args = dict(
        config_path=config_path,
        data_paths={"train": train_path, "dev": dev_path},
        output_path=output_path,
        raw_text=raw_text,
        tag_map=tag_map,
        weights_data=weights_data,
        omit_extra_lookups=omit_extra_lookups
    )

    if num_workers and num_workers > 1:
        # Removed: the previous hand-off to the spacy_ray plugin...
        from spacy_ray import distributed_setup_and_train
        distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args)
        # Added: ...is replaced by a temporary helper module that receives
        # plain paths, so each Ray worker can build its own objects.
        from _ray_async_utils import distributed_setup_and_train
        distributed_setup_and_train(
            use_gpu,
            num_workers,
            strategy,
            ray_address,
            {
                "config": config_path,
                "train": train_path,
                "dev": dev_path,
                "output": output_path
            }
        )
    else:
        # Added: the local path now builds nlp, config, and corpus up front.
        nlp, config = load_nlp_and_config(config_path)
        corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])

        train_args = dict(
            nlp=nlp,
            config=config,
            corpus=corpus,
            output_path=output_path,
            raw_text=raw_text,
            tag_map=tag_map,
            weights_data=weights_data,
            omit_extra_lookups=omit_extra_lookups
        )

    if use_gpu >= 0:
        msg.info(f"Using GPU: {use_gpu}")
        require_gpu(use_gpu)
```
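The diff only imports `distributed_setup_and_train` from the temporary `_ray_async_utils` module without showing its body; importing it inside the `if` also keeps Ray an optional dependency that single-worker runs never touch. Below is a minimal sketch of what such an entry point could look like, assuming nothing beyond the public Ray API (`ray.init`, `@ray.remote`, `ray.get`); the worker body is a stub, not spaCy's actual implementation:

```python
import ray

def distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, args):
    # Connect to an existing cluster when an address is given, otherwise
    # start a local Ray instance.
    if ray_address:
        ray.init(address=ray_address)
    else:
        ray.init()

    @ray.remote
    def run_worker(worker_id: int) -> str:
        # A real worker would rebuild nlp/config/corpus from the paths in
        # `args` and call train(); this stub only echoes its inputs.
        return f"worker {worker_id}: gpu={use_gpu}, strategy={strategy}, args={sorted(args)}"

    # Fan out one task per worker and block until all of them finish.
    for result in ray.get([run_worker.remote(i) for i in range(num_workers)]):
        print(result)
```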
@@ -177,20 +192,8 @@ def train_cli(

```python
        msg.info("Using CPU")
    train(**train_args)


# Removed: the old train() took raw paths and loaded everything itself.
def train(
    config_path: Path,
    data_paths: Dict[str, Path],
    raw_text: Optional[Path] = None,
    output_path: Optional[Path] = None,
    tag_map: Optional[Path] = None,
    weights_data: Optional[bytes] = None,
    omit_extra_lookups: bool = False,
    disable_tqdm: bool = False,
    remote_optimizer=None,
    randomization_index: int = 0
) -> None:
    msg.info(f"Loading config from: {config_path}")
    # Read the config first without creating objects, to get to the original nlp_config

# Added: config loading and seeding move into a reusable helper.
def load_nlp_and_config(config_path):
    config = util.load_config(config_path, create_objects=False)
    if config["training"].get("seed"):
        fix_random_seed(config["training"]["seed"])
```
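The extracted `load_nlp_and_config` gives the CLI and any future Ray worker a single place to read the config, seed the RNG, and build the pipeline. A hypothetical caller, assuming the variables from `train_cli` above (`config_path`, `train_path`, `dev_path`, `output_path`) are in scope:

```python
# Hypothetical usage; relies on the surrounding train_cli variables.
nlp, config = load_nlp_and_config(config_path)
corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])
train(nlp=nlp, config=config, corpus=corpus, output_path=output_path)
```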
@@ -199,16 +202,28 @@ def train(

```python
        use_pytorch_for_gpu_memory()
    nlp_config = config["nlp"]
    config = util.load_config(config_path, create_objects=True)
    nlp = util.load_model_from_config(nlp_config)
    return nlp, config


# Added: train() now receives ready-made nlp/config/corpus objects instead
# of paths, plus a num_workers argument for the Ray case.
def train(
    nlp,
    config,
    corpus,
    raw_text: Optional[Path] = None,
    output_path: Optional[Path] = None,
    tag_map: Optional[Path] = None,
    weights_data: Optional[bytes] = None,
    omit_extra_lookups: bool = False,
    disable_tqdm: bool = False,
    randomization_index: int = 0,
    num_workers=1
) -> None:
    # Note: as committed (this is a temporary work-in-progress change), the
    # body below still references names from the old signature, such as
    # config_path, nlp_config, data_paths, and remote_optimizer.
    msg.info(f"Loading config from: {config_path}")
    # Read the config first without creating objects, to get to the original nlp_config
    training = config["training"]
    msg.info("Creating nlp from config")
    nlp = util.load_model_from_config(nlp_config)
    optimizer = training["optimizer"]
    # TODO: is there a cleaner way of doing this, instead of creating
    # the optimizer twice? are there any problems when doing this?
    if remote_optimizer:
        optimizer = remote_optimizer
    limit = training["limit"]
    corpus = Corpus(data_paths["train"], data_paths["dev"], limit=limit)
    if "textcat" in nlp_config["pipeline"]:
        verify_textcat_config(nlp, nlp_config)
    if training.get("resume", False):
```
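The `remote_optimizer` branch and the TODO about creating the optimizer twice hint at a parameter-server-style design in which one shared optimizer serves every worker. A toy sketch of that idea as a Ray actor (an illustration under that assumption, not spacy-ray's actual code):

```python
import ray

@ray.remote
class SharedOptimizer:
    """One actor owns the optimizer state; every worker steps through it."""

    def __init__(self, learn_rate: float):
        self.learn_rate = learn_rate
        self.step_count = 0

    def step(self, grads: dict) -> dict:
        # Plain SGD: scale each gradient by the shared learning rate.
        self.step_count += 1
        return {name: -self.learn_rate * g for name, g in grads.items()}

ray.init()
opt = SharedOptimizer.remote(learn_rate=0.001)
# Two "workers" applying updates through the same optimizer state:
updates = ray.get([opt.step.remote({"w": 0.5}), opt.step.remote({"w": -0.25})])
print(updates)  # [{'w': -0.0005}, {'w': 0.00025}]
```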
@@ -275,7 +290,6 @@ def train(

```python
    )
    msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
    print_row = setup_printer(training, nlp)

    # The tqdm keyword arguments are gathered in one dict so the progress
    # bar can be switched off on headless Ray workers via disable_tqdm.
    tqdm_args = dict(total=training["eval_frequency"], leave=False, disable=disable_tqdm)
    try:
        progress = tqdm.tqdm(**tqdm_args)
```
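Collecting the keyword arguments into `tqdm_args` means a headless worker only has to flip `disable_tqdm`, since `disable=True` turns tqdm into a no-op. A self-contained illustration:

```python
import tqdm

# disable=True silences the bar entirely, which is what a Ray worker
# without a terminal would want; a local run keeps disable=False.
tqdm_args = dict(total=100, leave=False, disable=True)
progress = tqdm.tqdm(**tqdm_args)
for _ in range(100):
    progress.update(1)
progress.close()
```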