mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Tmp refactor train func for ray
This commit is contained in:
		
							parent
							
								
									3bccf8b954
								
							
						
					
					
						commit
						874b34e5d4
					
				| 
						 | 
					@ -156,9 +156,28 @@ def train_cli(
 | 
				
			||||||
        with init_tok2vec.open("rb") as file_:
 | 
					        with init_tok2vec.open("rb") as file_:
 | 
				
			||||||
            weights_data = file_.read()
 | 
					            weights_data = file_.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if num_workers and num_workers > 1:
 | 
				
			||||||
 | 
					        from _ray_async_utils import distributed_setup_and_train
 | 
				
			||||||
 | 
					        distributed_setup_and_train(
 | 
				
			||||||
 | 
					            use_gpu,
 | 
				
			||||||
 | 
					            num_workers,
 | 
				
			||||||
 | 
					            strategy,
 | 
				
			||||||
 | 
					            ray_address,
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "config": config_path,
 | 
				
			||||||
 | 
					                "train": train_path,
 | 
				
			||||||
 | 
					                "dev": dev_path,
 | 
				
			||||||
 | 
					                "output": output_path
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        nlp, config = load_nlp_and_config(config_path)
 | 
				
			||||||
 | 
					        corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        train_args = dict(
 | 
					        train_args = dict(
 | 
				
			||||||
        config_path=config_path,
 | 
					            nlp=nlp,
 | 
				
			||||||
        data_paths={"train": train_path, "dev": dev_path},
 | 
					            config=config,
 | 
				
			||||||
 | 
					            corpus=corpus,
 | 
				
			||||||
            output_path=output_path,
 | 
					            output_path=output_path,
 | 
				
			||||||
            raw_text=raw_text,
 | 
					            raw_text=raw_text,
 | 
				
			||||||
            tag_map=tag_map,
 | 
					            tag_map=tag_map,
 | 
				
			||||||
| 
						 | 
					@ -166,10 +185,6 @@ def train_cli(
 | 
				
			||||||
            omit_extra_lookups=omit_extra_lookups
 | 
					            omit_extra_lookups=omit_extra_lookups
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if num_workers and num_workers > 1:
 | 
					 | 
				
			||||||
        from spacy_ray import distributed_setup_and_train
 | 
					 | 
				
			||||||
        distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        if use_gpu >= 0:
 | 
					        if use_gpu >= 0:
 | 
				
			||||||
            msg.info(f"Using GPU: {use_gpu}")
 | 
					            msg.info(f"Using GPU: {use_gpu}")
 | 
				
			||||||
            require_gpu(use_gpu)
 | 
					            require_gpu(use_gpu)
 | 
				
			||||||
| 
						 | 
					@ -177,20 +192,8 @@ def train_cli(
 | 
				
			||||||
            msg.info("Using CPU")
 | 
					            msg.info("Using CPU")
 | 
				
			||||||
        train(**train_args)
 | 
					        train(**train_args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def train(
 | 
					
 | 
				
			||||||
    config_path: Path,
 | 
					def load_nlp_and_config(config_path):
 | 
				
			||||||
    data_paths: Dict[str, Path],
 | 
					 | 
				
			||||||
    raw_text: Optional[Path] = None,
 | 
					 | 
				
			||||||
    output_path: Optional[Path] = None,
 | 
					 | 
				
			||||||
    tag_map: Optional[Path] = None,
 | 
					 | 
				
			||||||
    weights_data: Optional[bytes] = None,
 | 
					 | 
				
			||||||
    omit_extra_lookups: bool = False,
 | 
					 | 
				
			||||||
    disable_tqdm: bool = False,
 | 
					 | 
				
			||||||
    remote_optimizer = None,
 | 
					 | 
				
			||||||
    randomization_index: int = 0
 | 
					 | 
				
			||||||
) -> None:
 | 
					 | 
				
			||||||
    msg.info(f"Loading config from: {config_path}")
 | 
					 | 
				
			||||||
    # Read the config first without creating objects, to get to the original nlp_config
 | 
					 | 
				
			||||||
    config = util.load_config(config_path, create_objects=False)
 | 
					    config = util.load_config(config_path, create_objects=False)
 | 
				
			||||||
    if config["training"].get("seed"):
 | 
					    if config["training"].get("seed"):
 | 
				
			||||||
        fix_random_seed(config["training"]["seed"])
 | 
					        fix_random_seed(config["training"]["seed"])
 | 
				
			||||||
| 
						 | 
					@ -199,16 +202,28 @@ def train(
 | 
				
			||||||
        use_pytorch_for_gpu_memory()
 | 
					        use_pytorch_for_gpu_memory()
 | 
				
			||||||
    nlp_config = config["nlp"]
 | 
					    nlp_config = config["nlp"]
 | 
				
			||||||
    config = util.load_config(config_path, create_objects=True)
 | 
					    config = util.load_config(config_path, create_objects=True)
 | 
				
			||||||
 | 
					    nlp = util.load_model_from_config(nlp_config)
 | 
				
			||||||
 | 
					    return nlp, config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def train(
 | 
				
			||||||
 | 
					    nlp,
 | 
				
			||||||
 | 
					    config,
 | 
				
			||||||
 | 
					    corpus,
 | 
				
			||||||
 | 
					    raw_text: Optional[Path] = None,
 | 
				
			||||||
 | 
					    output_path: Optional[Path] = None,
 | 
				
			||||||
 | 
					    tag_map: Optional[Path] = None,
 | 
				
			||||||
 | 
					    weights_data: Optional[bytes] = None,
 | 
				
			||||||
 | 
					    omit_extra_lookups: bool = False,
 | 
				
			||||||
 | 
					    disable_tqdm: bool = False,
 | 
				
			||||||
 | 
					    randomization_index: int = 0,
 | 
				
			||||||
 | 
					    num_workers=1
 | 
				
			||||||
 | 
					) -> None:
 | 
				
			||||||
 | 
					    msg.info(f"Loading config from: {config_path}")
 | 
				
			||||||
 | 
					    # Read the config first without creating objects, to get to the original nlp_config
 | 
				
			||||||
    training = config["training"]
 | 
					    training = config["training"]
 | 
				
			||||||
    msg.info("Creating nlp from config")
 | 
					    msg.info("Creating nlp from config")
 | 
				
			||||||
    nlp = util.load_model_from_config(nlp_config)
 | 
					 | 
				
			||||||
    optimizer = training["optimizer"]
 | 
					    optimizer = training["optimizer"]
 | 
				
			||||||
    # TODO: is there a cleaner way of doing this, instead of creating
 | 
					 | 
				
			||||||
    # the optimizer twice? are there any problems when doing this?
 | 
					 | 
				
			||||||
    if remote_optimizer:
 | 
					 | 
				
			||||||
        optimizer = remote_optimizer
 | 
					 | 
				
			||||||
    limit = training["limit"]
 | 
					 | 
				
			||||||
    corpus = Corpus(data_paths["train"], data_paths["dev"], limit=limit)
 | 
					 | 
				
			||||||
    if "textcat" in nlp_config["pipeline"]:
 | 
					    if "textcat" in nlp_config["pipeline"]:
 | 
				
			||||||
        verify_textcat_config(nlp, nlp_config)
 | 
					        verify_textcat_config(nlp, nlp_config)
 | 
				
			||||||
    if training.get("resume", False):
 | 
					    if training.get("resume", False):
 | 
				
			||||||
| 
						 | 
					@ -275,7 +290,6 @@ def train(
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
 | 
					    msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
 | 
				
			||||||
    print_row = setup_printer(training, nlp)
 | 
					    print_row = setup_printer(training, nlp)
 | 
				
			||||||
 | 
					 | 
				
			||||||
    tqdm_args = dict(total=training["eval_frequency"], leave=False, disable=disable_tqdm)
 | 
					    tqdm_args = dict(total=training["eval_frequency"], leave=False, disable=disable_tqdm)
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        progress = tqdm.tqdm(**tqdm_args)
 | 
					        progress = tqdm.tqdm(**tqdm_args)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user