feat: add spacy.WandbLogger.v3 with optional run_name and entity parameters (#9202)

* feat: add `spacy.WandbLogger.v3` with optional `run_name` and `entity` parameters

* update versioning in docs

Co-authored-by: svlandeg <sofie.vanlandeghem@gmail.com>
This commit is contained in:
Jozef Harag 2021-09-16 12:26:41 +02:00 committed by GitHub
parent 1d57d78758
commit 865cfbc903
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 94 additions and 8 deletions

View File

@ -177,3 +177,87 @@ def wandb_logger(
return log_step, finalize
return setup_logger
@registry.loggers("spacy.WandbLogger.v3")
def wandb_logger(
project_name: str,
remove_config_values: List[str] = [],
model_log_interval: Optional[int] = None,
log_dataset_dir: Optional[str] = None,
entity: Optional[str] = None,
run_name: Optional[str] = None,
):
try:
import wandb
# test that these are available
from wandb import init, log, join # noqa: F401
except ImportError:
raise ImportError(Errors.E880)
console = console_logger(progress_bar=False)
def setup_logger(
nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]:
config = nlp.config.interpolate()
config_dot = util.dict_to_dot(config)
for field in remove_config_values:
del config_dot[field]
config = util.dot_to_dict(config_dot)
run = wandb.init(project=project_name, config=config, entity=entity, reinit=True)
if run_name:
wandb.run.name = run_name
console_log_step, console_finalize = console(nlp, stdout, stderr)
def log_dir_artifact(
path: str,
name: str,
type: str,
metadata: Optional[Dict[str, Any]] = {},
aliases: Optional[List[str]] = [],
):
dataset_artifact = wandb.Artifact(name, type=type, metadata=metadata)
dataset_artifact.add_dir(path, name=name)
wandb.log_artifact(dataset_artifact, aliases=aliases)
if log_dataset_dir:
log_dir_artifact(path=log_dataset_dir, name="dataset", type="dataset")
def log_step(info: Optional[Dict[str, Any]]):
console_log_step(info)
if info is not None:
score = info["score"]
other_scores = info["other_scores"]
losses = info["losses"]
wandb.log({"score": score})
if losses:
wandb.log({f"loss_{k}": v for k, v in losses.items()})
if isinstance(other_scores, dict):
wandb.log(other_scores)
if model_log_interval and info.get("output_path"):
if info["step"] % model_log_interval == 0 and info["step"] != 0:
log_dir_artifact(
path=info["output_path"],
name="pipeline_" + run.id,
type="checkpoint",
metadata=info,
aliases=[
f"epoch {info['epoch']} step {info['step']}",
"latest",
"best"
if info["score"] == max(info["checkpoints"])[0]
else "",
],
)
def finalize() -> None:
console_finalize()
wandb.join()
return log_step, finalize
return setup_logger

View File

@ -462,7 +462,7 @@ start decreasing across epochs.
</Accordion>
#### spacy.WandbLogger.v2 {#WandbLogger tag="registered function"}
#### spacy.WandbLogger.v3 {#WandbLogger tag="registered function"}
> #### Installation
>
@ -494,7 +494,7 @@ remain in the config file stored on your local system.
>
> ```ini
> [training.logger]
> @loggers = "spacy.WandbLogger.v2"
> @loggers = "spacy.WandbLogger.v3"
> project_name = "monitor_spacy_training"
> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
> log_dataset_dir = "corpus"
@ -502,11 +502,13 @@ remain in the config file stored on your local system.
> ```
| Name | Description |
| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------- |
| ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `project_name` | The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. ~~str~~ |
| `remove_config_values` | A list of values to include from the config before it is uploaded to W&B (default: empty). ~~List[str]~~ |
| `model_log_interval` | Steps to wait between logging model checkpoints to W&B dasboard (default: None). ~~Optional[int]~~ |
| `log_dataset_dir` | Directory containing dataset to be logged and versioned as W&B artifact (default: None). ~~Optional[str]~~ |
| `run_name` | The name of the run. If you don't specify a run_name, the name will be created by wandb library. (default: None ). ~~Optional[str]~~ |
| `entity` | An entity is a username or team name where you're sending runs. If you don't specify an entity, the run will be sent to your default entity, which is usually your username. (default: None). ~~Optional[str]~~ |
<Project id="integrations/wandb">