Merge pull request #5992 from svlandeg/feature/wandb-restrict-config

This commit is contained in:
Ines Montani 2020-08-28 15:05:29 +02:00 committed by GitHub
commit 89f692bc8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
from typing import Dict, Any, Tuple, Callable
from typing import Dict, Any, Tuple, Callable, List
from ..util import registry
from .. import util
from ..errors import Errors
from wasabi import msg
@ -66,7 +67,7 @@ def console_logger():
@registry.loggers("spacy.WandbLogger.v1")
def wandb_logger(project_name: str):
def wandb_logger(project_name: str, remove_config_values: List[str] = []):
import wandb
console = console_logger()
@ -75,16 +76,19 @@ def wandb_logger(project_name: str):
nlp: "Language"
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
config = nlp.config.interpolate()
config_dot = util.dict_to_dot(config)
for field in remove_config_values:
del config_dot[field]
config = util.dot_to_dict(config_dot)
wandb.init(project=project_name, config=config)
console_log_step, console_finalize = console(nlp)
def log_step(info: Dict[str, Any]):
console_log_step(info)
epoch = info["epoch"]
score = info["score"]
other_scores = info["other_scores"]
losses = info["losses"]
wandb.log({"score": score, "epoch": epoch})
wandb.log({"score": score})
if losses:
wandb.log({f"loss_{k}": v for k, v in losses.items()})
if isinstance(other_scores, dict):