add disable_fields to wandb_logger

This commit is contained in:
svlandeg 2020-08-28 13:55:32 +02:00
parent 03dde511b4
commit 1d8c4070aa

View File

@ -1,6 +1,7 @@
from typing import Dict, Any, Tuple, Callable from typing import Dict, Any, Tuple, Callable
from ..util import registry from ..util import registry
from .. import util
from ..errors import Errors from ..errors import Errors
from wasabi import msg from wasabi import msg
@ -66,7 +67,7 @@ def console_logger():
@registry.loggers("spacy.WandbLogger.v1") @registry.loggers("spacy.WandbLogger.v1")
def wandb_logger(project_name: str): def wandb_logger(project_name: str, disable_fields: list = []):
import wandb import wandb
console = console_logger() console = console_logger()
@ -75,16 +76,19 @@ def wandb_logger(project_name: str):
nlp: "Language" nlp: "Language"
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]: ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
config = nlp.config.interpolate() config = nlp.config.interpolate()
config_dot = util.dict_to_dot(config)
for field in disable_fields:
del config_dot[field]
config = util.dot_to_dict(config_dot)
wandb.init(project=project_name, config=config) wandb.init(project=project_name, config=config)
console_log_step, console_finalize = console(nlp) console_log_step, console_finalize = console(nlp)
def log_step(info: Dict[str, Any]): def log_step(info: Dict[str, Any]):
console_log_step(info) console_log_step(info)
epoch = info["epoch"]
score = info["score"] score = info["score"]
other_scores = info["other_scores"] other_scores = info["other_scores"]
losses = info["losses"] losses = info["losses"]
wandb.log({"score": score, "epoch": epoch}) wandb.log({"score": score})
if losses: if losses:
wandb.log({f"loss_{k}": v for k, v in losses.items()}) wandb.log({f"loss_{k}": v for k, v in losses.items()})
if isinstance(other_scores, dict): if isinstance(other_scores, dict):