mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 13:11:03 +03:00 
			
		
		
		
	feat: add spacy.WandbLogger.v3 with optional run_name and entity parameters (#9202)
				
					
				
			* feat: add `spacy.WandbLogger.v3` with optional `run_name` and `entity` parameters * update versioning in docs Co-authored-by: svlandeg <sofie.vanlandeghem@gmail.com>
This commit is contained in:
		
							parent
							
								
									1d57d78758
								
							
						
					
					
						commit
						865cfbc903
					
				|  | @ -177,3 +177,87 @@ def wandb_logger( | ||||||
|         return log_step, finalize |         return log_step, finalize | ||||||
| 
 | 
 | ||||||
|     return setup_logger |     return setup_logger | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @registry.loggers("spacy.WandbLogger.v3") | ||||||
|  | def wandb_logger( | ||||||
|  |     project_name: str, | ||||||
|  |     remove_config_values: List[str] = [], | ||||||
|  |     model_log_interval: Optional[int] = None, | ||||||
|  |     log_dataset_dir: Optional[str] = None, | ||||||
|  |     entity: Optional[str] = None, | ||||||
|  |     run_name: Optional[str] = None, | ||||||
|  | ): | ||||||
|  |     try: | ||||||
|  |         import wandb | ||||||
|  | 
 | ||||||
|  |         # test that these are available | ||||||
|  |         from wandb import init, log, join  # noqa: F401 | ||||||
|  |     except ImportError: | ||||||
|  |         raise ImportError(Errors.E880) | ||||||
|  | 
 | ||||||
|  |     console = console_logger(progress_bar=False) | ||||||
|  | 
 | ||||||
|  |     def setup_logger( | ||||||
|  |         nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr | ||||||
|  |     ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]: | ||||||
|  |         config = nlp.config.interpolate() | ||||||
|  |         config_dot = util.dict_to_dot(config) | ||||||
|  |         for field in remove_config_values: | ||||||
|  |             del config_dot[field] | ||||||
|  |         config = util.dot_to_dict(config_dot) | ||||||
|  |         run = wandb.init(project=project_name, config=config, entity=entity, reinit=True) | ||||||
|  | 
 | ||||||
|  |         if run_name: | ||||||
|  |             wandb.run.name = run_name | ||||||
|  | 
 | ||||||
|  |         console_log_step, console_finalize = console(nlp, stdout, stderr) | ||||||
|  | 
 | ||||||
|  |         def log_dir_artifact( | ||||||
|  |             path: str, | ||||||
|  |             name: str, | ||||||
|  |             type: str, | ||||||
|  |             metadata: Optional[Dict[str, Any]] = {}, | ||||||
|  |             aliases: Optional[List[str]] = [], | ||||||
|  |         ): | ||||||
|  |             dataset_artifact = wandb.Artifact(name, type=type, metadata=metadata) | ||||||
|  |             dataset_artifact.add_dir(path, name=name) | ||||||
|  |             wandb.log_artifact(dataset_artifact, aliases=aliases) | ||||||
|  | 
 | ||||||
|  |         if log_dataset_dir: | ||||||
|  |             log_dir_artifact(path=log_dataset_dir, name="dataset", type="dataset") | ||||||
|  | 
 | ||||||
|  |         def log_step(info: Optional[Dict[str, Any]]): | ||||||
|  |             console_log_step(info) | ||||||
|  |             if info is not None: | ||||||
|  |                 score = info["score"] | ||||||
|  |                 other_scores = info["other_scores"] | ||||||
|  |                 losses = info["losses"] | ||||||
|  |                 wandb.log({"score": score}) | ||||||
|  |                 if losses: | ||||||
|  |                     wandb.log({f"loss_{k}": v for k, v in losses.items()}) | ||||||
|  |                 if isinstance(other_scores, dict): | ||||||
|  |                     wandb.log(other_scores) | ||||||
|  |                 if model_log_interval and info.get("output_path"): | ||||||
|  |                     if info["step"] % model_log_interval == 0 and info["step"] != 0: | ||||||
|  |                         log_dir_artifact( | ||||||
|  |                             path=info["output_path"], | ||||||
|  |                             name="pipeline_" + run.id, | ||||||
|  |                             type="checkpoint", | ||||||
|  |                             metadata=info, | ||||||
|  |                             aliases=[ | ||||||
|  |                                 f"epoch {info['epoch']} step {info['step']}", | ||||||
|  |                                 "latest", | ||||||
|  |                                 "best" | ||||||
|  |                                 if info["score"] == max(info["checkpoints"])[0] | ||||||
|  |                                 else "", | ||||||
|  |                             ], | ||||||
|  |                         ) | ||||||
|  | 
 | ||||||
|  |         def finalize() -> None: | ||||||
|  |             console_finalize() | ||||||
|  |             wandb.join() | ||||||
|  | 
 | ||||||
|  |         return log_step, finalize | ||||||
|  | 
 | ||||||
|  |     return setup_logger | ||||||
|  |  | ||||||
|  | @ -462,7 +462,7 @@ start decreasing across epochs. | ||||||
| 
 | 
 | ||||||
|  </Accordion> |  </Accordion> | ||||||
| 
 | 
 | ||||||
| #### spacy.WandbLogger.v2 {#WandbLogger tag="registered function"} | #### spacy.WandbLogger.v3 {#WandbLogger tag="registered function"} | ||||||
| 
 | 
 | ||||||
| > #### Installation | > #### Installation | ||||||
| > | > | ||||||
|  | @ -494,7 +494,7 @@ remain in the config file stored on your local system. | ||||||
| > | > | ||||||
| > ```ini | > ```ini | ||||||
| > [training.logger] | > [training.logger] | ||||||
| > @loggers = "spacy.WandbLogger.v2" | > @loggers = "spacy.WandbLogger.v3" | ||||||
| > project_name = "monitor_spacy_training" | > project_name = "monitor_spacy_training" | ||||||
| > remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"] | > remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"] | ||||||
| > log_dataset_dir = "corpus" | > log_dataset_dir = "corpus" | ||||||
|  | @ -502,11 +502,13 @@ remain in the config file stored on your local system. | ||||||
| > ``` | > ``` | ||||||
| 
 | 
 | ||||||
| | Name                   | Description                                                                                                                                                                                                     | | | Name                   | Description                                                                                                                                                                                                     | | ||||||
| | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------- | | | ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||||||
| | `project_name`         | The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. ~~str~~                                                                           | | | `project_name`         | The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. ~~str~~                                                                           | | ||||||
| | `remove_config_values` | A list of values to include from the config before it is uploaded to W&B (default: empty). ~~List[str]~~                                                                                                        | | | `remove_config_values` | A list of values to include from the config before it is uploaded to W&B (default: empty). ~~List[str]~~                                                                                                        | | ||||||
| | `model_log_interval`   | Steps to wait between logging model checkpoints to W&B dasboard (default: None). ~~Optional[int]~~                                                                                                              | | | `model_log_interval`   | Steps to wait between logging model checkpoints to W&B dasboard (default: None). ~~Optional[int]~~                                                                                                              | | ||||||
| | `log_dataset_dir`      | Directory containing dataset to be logged and versioned as W&B artifact (default: None). ~~Optional[str]~~                                                                                                      | | | `log_dataset_dir`      | Directory containing dataset to be logged and versioned as W&B artifact (default: None). ~~Optional[str]~~                                                                                                      | | ||||||
|  | | `run_name`             | The name of the run. If you don't specify a run_name, the name will be created by wandb library. (default: None ). ~~Optional[str]~~                                                                            | | ||||||
|  | | `entity`               | An entity is a username or team name where you're sending runs. If you don't specify an entity, the run will be sent to your default entity, which is usually your username. (default: None). ~~Optional[str]~~ | | ||||||
| 
 | 
 | ||||||
| <Project id="integrations/wandb"> | <Project id="integrations/wandb"> | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user