From 865cfbc903a53df15f74c767c14bcf29847d0848 Mon Sep 17 00:00:00 2001 From: Jozef Harag <32jojo32@gmail.com> Date: Thu, 16 Sep 2021 12:26:41 +0200 Subject: [PATCH] feat: add `spacy.WandbLogger.v3` with optional `run_name` and `entity` parameters (#9202) * feat: add `spacy.WandbLogger.v3` with optional `run_name` and `entity` parameters * update versioning in docs Co-authored-by: svlandeg --- spacy/training/loggers.py | 84 +++++++++++++++++++++++++++++++++++ website/docs/api/top-level.md | 18 ++++---- 2 files changed, 94 insertions(+), 8 deletions(-) diff --git a/spacy/training/loggers.py b/spacy/training/loggers.py index 5cf2db6b3..524b422a5 100644 --- a/spacy/training/loggers.py +++ b/spacy/training/loggers.py @@ -177,3 +177,87 @@ def wandb_logger( return log_step, finalize return setup_logger + + +@registry.loggers("spacy.WandbLogger.v3") +def wandb_logger( + project_name: str, + remove_config_values: List[str] = [], + model_log_interval: Optional[int] = None, + log_dataset_dir: Optional[str] = None, + entity: Optional[str] = None, + run_name: Optional[str] = None, +): + try: + import wandb + + # test that these are available + from wandb import init, log, join # noqa: F401 + except ImportError: + raise ImportError(Errors.E880) + + console = console_logger(progress_bar=False) + + def setup_logger( + nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr + ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]: + config = nlp.config.interpolate() + config_dot = util.dict_to_dot(config) + for field in remove_config_values: + del config_dot[field] + config = util.dot_to_dict(config_dot) + run = wandb.init(project=project_name, config=config, entity=entity, reinit=True) + + if run_name: + wandb.run.name = run_name + + console_log_step, console_finalize = console(nlp, stdout, stderr) + + def log_dir_artifact( + path: str, + name: str, + type: str, + metadata: Optional[Dict[str, Any]] = {}, + aliases: Optional[List[str]] = [], + ): + dataset_artifact = wandb.Artifact(name, type=type, metadata=metadata) + dataset_artifact.add_dir(path, name=name) + wandb.log_artifact(dataset_artifact, aliases=aliases) + + if log_dataset_dir: + log_dir_artifact(path=log_dataset_dir, name="dataset", type="dataset") + + def log_step(info: Optional[Dict[str, Any]]): + console_log_step(info) + if info is not None: + score = info["score"] + other_scores = info["other_scores"] + losses = info["losses"] + wandb.log({"score": score}) + if losses: + wandb.log({f"loss_{k}": v for k, v in losses.items()}) + if isinstance(other_scores, dict): + wandb.log(other_scores) + if model_log_interval and info.get("output_path"): + if info["step"] % model_log_interval == 0 and info["step"] != 0: + log_dir_artifact( + path=info["output_path"], + name="pipeline_" + run.id, + type="checkpoint", + metadata=info, + aliases=[ + f"epoch {info['epoch']} step {info['step']}", + "latest", + "best" + if info["score"] == max(info["checkpoints"])[0] + else "", + ], + ) + + def finalize() -> None: + console_finalize() + wandb.join() + + return log_step, finalize + + return setup_logger diff --git a/website/docs/api/top-level.md b/website/docs/api/top-level.md index 8190d9f78..3cf81ae93 100644 --- a/website/docs/api/top-level.md +++ b/website/docs/api/top-level.md @@ -462,7 +462,7 @@ start decreasing across epochs. -#### spacy.WandbLogger.v2 {#WandbLogger tag="registered function"} +#### spacy.WandbLogger.v3 {#WandbLogger tag="registered function"} > #### Installation > @@ -494,19 +494,21 @@ remain in the config file stored on your local system. > > ```ini > [training.logger] -> @loggers = "spacy.WandbLogger.v2" +> @loggers = "spacy.WandbLogger.v3" > project_name = "monitor_spacy_training" > remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"] > log_dataset_dir = "corpus" > model_log_interval = 1000 > ``` -| Name | Description | -| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------- | -| `project_name` | The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. ~~str~~ | -| `remove_config_values` | A list of values to include from the config before it is uploaded to W&B (default: empty). ~~List[str]~~ | -| `model_log_interval` | Steps to wait between logging model checkpoints to W&B dasboard (default: None). ~~Optional[int]~~ | -| `log_dataset_dir` | Directory containing dataset to be logged and versioned as W&B artifact (default: None). ~~Optional[str]~~ | +| Name | Description | +| ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `project_name` | The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. ~~str~~ | +| `remove_config_values` | A list of values to include from the config before it is uploaded to W&B (default: empty). ~~List[str]~~ | +| `model_log_interval` | Steps to wait between logging model checkpoints to W&B dasboard (default: None). ~~Optional[int]~~ | +| `log_dataset_dir` | Directory containing dataset to be logged and versioned as W&B artifact (default: None). ~~Optional[str]~~ | +| `run_name` | The name of the run. If you don't specify a run_name, the name will be created by wandb library. (default: None ). ~~Optional[str]~~ | +| `entity` | An entity is a username or team name where you're sending runs. If you don't specify an entity, the run will be sent to your default entity, which is usually your username. (default: None). ~~Optional[str]~~ |