mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
feat: add spacy.WandbLogger.v3
with optional run_name
and entity
parameters (#9202)
* feat: add `spacy.WandbLogger.v3` with optional `run_name` and `entity` parameters * update versioning in docs Co-authored-by: svlandeg <sofie.vanlandeghem@gmail.com>
This commit is contained in:
parent
1d57d78758
commit
865cfbc903
|
@ -177,3 +177,87 @@ def wandb_logger(
|
|||
return log_step, finalize
|
||||
|
||||
return setup_logger
|
||||
|
||||
|
||||
@registry.loggers("spacy.WandbLogger.v3")
|
||||
def wandb_logger(
|
||||
project_name: str,
|
||||
remove_config_values: List[str] = [],
|
||||
model_log_interval: Optional[int] = None,
|
||||
log_dataset_dir: Optional[str] = None,
|
||||
entity: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
):
|
||||
try:
|
||||
import wandb
|
||||
|
||||
# test that these are available
|
||||
from wandb import init, log, join # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(Errors.E880)
|
||||
|
||||
console = console_logger(progress_bar=False)
|
||||
|
||||
def setup_logger(
|
||||
nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
|
||||
) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]:
|
||||
config = nlp.config.interpolate()
|
||||
config_dot = util.dict_to_dot(config)
|
||||
for field in remove_config_values:
|
||||
del config_dot[field]
|
||||
config = util.dot_to_dict(config_dot)
|
||||
run = wandb.init(project=project_name, config=config, entity=entity, reinit=True)
|
||||
|
||||
if run_name:
|
||||
wandb.run.name = run_name
|
||||
|
||||
console_log_step, console_finalize = console(nlp, stdout, stderr)
|
||||
|
||||
def log_dir_artifact(
|
||||
path: str,
|
||||
name: str,
|
||||
type: str,
|
||||
metadata: Optional[Dict[str, Any]] = {},
|
||||
aliases: Optional[List[str]] = [],
|
||||
):
|
||||
dataset_artifact = wandb.Artifact(name, type=type, metadata=metadata)
|
||||
dataset_artifact.add_dir(path, name=name)
|
||||
wandb.log_artifact(dataset_artifact, aliases=aliases)
|
||||
|
||||
if log_dataset_dir:
|
||||
log_dir_artifact(path=log_dataset_dir, name="dataset", type="dataset")
|
||||
|
||||
def log_step(info: Optional[Dict[str, Any]]):
|
||||
console_log_step(info)
|
||||
if info is not None:
|
||||
score = info["score"]
|
||||
other_scores = info["other_scores"]
|
||||
losses = info["losses"]
|
||||
wandb.log({"score": score})
|
||||
if losses:
|
||||
wandb.log({f"loss_{k}": v for k, v in losses.items()})
|
||||
if isinstance(other_scores, dict):
|
||||
wandb.log(other_scores)
|
||||
if model_log_interval and info.get("output_path"):
|
||||
if info["step"] % model_log_interval == 0 and info["step"] != 0:
|
||||
log_dir_artifact(
|
||||
path=info["output_path"],
|
||||
name="pipeline_" + run.id,
|
||||
type="checkpoint",
|
||||
metadata=info,
|
||||
aliases=[
|
||||
f"epoch {info['epoch']} step {info['step']}",
|
||||
"latest",
|
||||
"best"
|
||||
if info["score"] == max(info["checkpoints"])[0]
|
||||
else "",
|
||||
],
|
||||
)
|
||||
|
||||
def finalize() -> None:
|
||||
console_finalize()
|
||||
wandb.join()
|
||||
|
||||
return log_step, finalize
|
||||
|
||||
return setup_logger
|
||||
|
|
|
@ -462,7 +462,7 @@ start decreasing across epochs.
|
|||
|
||||
</Accordion>
|
||||
|
||||
#### spacy.WandbLogger.v2 {#WandbLogger tag="registered function"}
|
||||
#### spacy.WandbLogger.v3 {#WandbLogger tag="registered function"}
|
||||
|
||||
> #### Installation
|
||||
>
|
||||
|
@ -494,7 +494,7 @@ remain in the config file stored on your local system.
|
|||
>
|
||||
> ```ini
|
||||
> [training.logger]
|
||||
> @loggers = "spacy.WandbLogger.v2"
|
||||
> @loggers = "spacy.WandbLogger.v3"
|
||||
> project_name = "monitor_spacy_training"
|
||||
> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
|
||||
> log_dataset_dir = "corpus"
|
||||
|
@ -502,11 +502,13 @@ remain in the config file stored on your local system.
|
|||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `project_name` | The name of the project in the Weights & Biases interface. The project will be created automatically if it doesn't exist yet. ~~str~~ |
|
||||
| `remove_config_values` | A list of values to include from the config before it is uploaded to W&B (default: empty). ~~List[str]~~ |
|
||||
| `model_log_interval` | Steps to wait between logging model checkpoints to W&B dasboard (default: None). ~~Optional[int]~~ |
|
||||
| `log_dataset_dir` | Directory containing dataset to be logged and versioned as W&B artifact (default: None). ~~Optional[str]~~ |
|
||||
| `run_name` | The name of the run. If you don't specify a run_name, the name will be created by wandb library. (default: None ). ~~Optional[str]~~ |
|
||||
| `entity` | An entity is a username or team name where you're sending runs. If you don't specify an entity, the run will be sent to your default entity, which is usually your username. (default: None). ~~Optional[str]~~ |
|
||||
|
||||
<Project id="integrations/wandb">
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user