mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
feat: add spacy.WandbLogger.v3
with optional run_name
and entity
parameters (#9202)
* feat: add `spacy.WandbLogger.v3` with optional `run_name` and `entity` parameters * update versioning in docs Co-authored-by: svlandeg <sofie.vanlandeghem@gmail.com>
This commit is contained in:
parent
1d57d78758
commit
865cfbc903
|
@ -177,3 +177,87 @@ def wandb_logger(
|
||||||
return log_step, finalize
|
return log_step, finalize
|
||||||
|
|
||||||
return setup_logger
|
return setup_logger
|
||||||
|
|
||||||
|
|
||||||
|
@registry.loggers("spacy.WandbLogger.v3")
def wandb_logger(
    project_name: str,
    remove_config_values: List[str] = [],
    model_log_interval: Optional[int] = None,
    log_dataset_dir: Optional[str] = None,
    entity: Optional[str] = None,
    run_name: Optional[str] = None,
):
    """Create a logger setup function that reports training to Weights & Biases.

    Every step is also forwarded to the console logger (without progress bar).

    project_name (str): Passed to `wandb.init` as `project`.
    remove_config_values (List[str]): Dot-notation config keys that are deleted
        from the interpolated config before it is passed to `wandb.init`.
        A key that is not present in the config raises a KeyError.
    model_log_interval (Optional[int]): If set, the pipeline directory given by
        `info["output_path"]` is logged as a "checkpoint" artifact every time
        the step is a non-zero multiple of this value.
    log_dataset_dir (Optional[str]): If set, this directory is logged once as
        a "dataset" artifact when the logger is set up.
    entity (Optional[str]): Passed to `wandb.init` as `entity`.
    run_name (Optional[str]): If truthy, assigned to `wandb.run.name` after the
        run has been initialised.
    RETURNS (Callable): The `setup_logger` function, which takes the `nlp`
        object plus stdout/stderr and returns `(log_step, finalize)`.
    RAISES (ImportError): If `wandb` (or its `init`, `log`, `join` functions)
        cannot be imported.
    """
    # NOTE: `remove_config_values` keeps its original list default so that the
    # registered signature is unchanged; it is only iterated, never mutated.
    try:
        import wandb

        # test that these are available
        from wandb import init, log, join  # noqa: F401
    except ImportError:
        raise ImportError(Errors.E880)

    console = console_logger(progress_bar=False)

    def setup_logger(
        nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
    ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]:
        """Start the W&B run and build the step/finalize callbacks."""
        # Strip the requested keys from the config before it is uploaded.
        config = nlp.config.interpolate()
        config_dot = util.dict_to_dot(config)
        for field in remove_config_values:
            del config_dot[field]
        config = util.dot_to_dict(config_dot)
        run = wandb.init(
            project=project_name, config=config, entity=entity, reinit=True
        )

        if run_name:
            wandb.run.name = run_name

        console_log_step, console_finalize = console(nlp, stdout, stderr)

        def log_dir_artifact(
            path: str,
            name: str,
            artifact_type: str,
            metadata: Optional[Dict[str, Any]] = None,
            aliases: Optional[List[str]] = None,
        ):
            """Log the directory at `path` as a W&B artifact."""
            # Fresh containers per call instead of shared mutable defaults.
            metadata = {} if metadata is None else metadata
            aliases = [] if aliases is None else aliases
            dataset_artifact = wandb.Artifact(
                name, type=artifact_type, metadata=metadata
            )
            dataset_artifact.add_dir(path, name=name)
            wandb.log_artifact(dataset_artifact, aliases=aliases)

        if log_dataset_dir:
            log_dir_artifact(
                path=log_dataset_dir, name="dataset", artifact_type="dataset"
            )

        def log_step(info: Optional[Dict[str, Any]]):
            """Log one training step to the console and, if given, to W&B."""
            console_log_step(info)
            if info is None:
                return
            score = info["score"]
            other_scores = info["other_scores"]
            losses = info["losses"]
            wandb.log({"score": score})
            if losses:
                wandb.log({f"loss_{k}": v for k, v in losses.items()})
            if isinstance(other_scores, dict):
                wandb.log(other_scores)
            if not (model_log_interval and info.get("output_path")):
                return
            step = info["step"]
            if step % model_log_interval == 0 and step != 0:
                aliases = [f"epoch {info['epoch']} step {step}", "latest"]
                # Only tag "best" when this score equals the first element of
                # the maximum checkpoint entry (presumably (score, step)
                # tuples - confirm against the training loop). Previously an
                # empty-string alias was sent when this was not the case.
                if info["score"] == max(info["checkpoints"])[0]:
                    aliases.append("best")
                log_dir_artifact(
                    path=info["output_path"],
                    name="pipeline_" + run.id,
                    artifact_type="checkpoint",
                    metadata=info,
                    aliases=aliases,
                )

        def finalize() -> None:
            """Finalize the console logger, then call `wandb.join()`."""
            console_finalize()
            wandb.join()

        return log_step, finalize

    return setup_logger
|
||||||
|
|
|
@ -462,7 +462,7 @@ start decreasing across epochs.
|
||||||
|
|
||||||
</Accordion>
|
</Accordion>
|
||||||
|
|
||||||
#### spacy.WandbLogger.v2 {#WandbLogger tag="registered function"}
|
#### spacy.WandbLogger.v3 {#WandbLogger tag="registered function"}
|
||||||
|
|
||||||
> #### Installation
|
> #### Installation
|
||||||
>
|
>
|
||||||
|
@ -494,7 +494,7 @@ remain in the config file stored on your local system.
|
||||||
>
|
>
|
||||||
> ```ini
|
> ```ini
|
||||||
> [training.logger]
|
> [training.logger]
|
||||||
> @loggers = "spacy.WandbLogger.v2"
|
> @loggers = "spacy.WandbLogger.v3"
|
||||||
> project_name = "monitor_spacy_training"
|
> project_name = "monitor_spacy_training"
|
||||||
> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
|
> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
|
||||||
> log_dataset_dir = "corpus"
|
> log_dataset_dir = "corpus"
|
||||||
|
@ -502,11 +502,13 @@ remain in the config file stored on your local system.
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------- |
|
| ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `project_name` | The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. ~~str~~ |
|
| `project_name` | The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. ~~str~~ |
|
||||||
| `remove_config_values` | A list of values to exclude from the config before it is uploaded to W&B (default: empty). ~~List[str]~~ |
|
| `remove_config_values` | A list of values to exclude from the config before it is uploaded to W&B (default: empty). ~~List[str]~~ |
|
||||||
| `model_log_interval` | Steps to wait between logging model checkpoints to W&B dashboard (default: None). ~~Optional[int]~~ |
|
| `model_log_interval` | Steps to wait between logging model checkpoints to W&B dashboard (default: None). ~~Optional[int]~~ |
|
||||||
| `log_dataset_dir` | Directory containing dataset to be logged and versioned as W&B artifact (default: None). ~~Optional[str]~~ |
|
| `log_dataset_dir` | Directory containing dataset to be logged and versioned as W&B artifact (default: None). ~~Optional[str]~~ |
|
||||||
|
| `run_name` | The name of the run. If you don't specify a `run_name`, the name will be created by the `wandb` library (default: None). ~~Optional[str]~~ |
|
||||||
|
| `entity` | An entity is a username or team name where you're sending runs. If you don't specify an entity, the run will be sent to your default entity, which is usually your username. (default: None). ~~Optional[str]~~ |
|
||||||
|
|
||||||
<Project id="integrations/wandb">
|
<Project id="integrations/wandb">
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user