From 1d8c4070aae601d1161f74284f857e9bc476050a Mon Sep 17 00:00:00 2001 From: svlandeg Date: Fri, 28 Aug 2020 13:55:32 +0200 Subject: [PATCH] add disable_fields to wandb_logger --- spacy/gold/loggers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/spacy/gold/loggers.py b/spacy/gold/loggers.py index 10a153014..7c5a3d317 100644 --- a/spacy/gold/loggers.py +++ b/spacy/gold/loggers.py @@ -1,6 +1,7 @@ from typing import Dict, Any, Tuple, Callable from ..util import registry +from .. import util from ..errors import Errors from wasabi import msg @@ -66,7 +67,7 @@ def console_logger(): @registry.loggers("spacy.WandbLogger.v1") -def wandb_logger(project_name: str): +def wandb_logger(project_name: str, disable_fields: list = []): import wandb console = console_logger() @@ -75,16 +76,19 @@ def wandb_logger(project_name: str): nlp: "Language" ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]: config = nlp.config.interpolate() + config_dot = util.dict_to_dot(config) + for field in disable_fields: + del config_dot[field] + config = util.dot_to_dict(config_dot) wandb.init(project=project_name, config=config) console_log_step, console_finalize = console(nlp) def log_step(info: Dict[str, Any]): console_log_step(info) - epoch = info["epoch"] score = info["score"] other_scores = info["other_scores"] losses = info["losses"] - wandb.log({"score": score, "epoch": epoch}) + wandb.log({"score": score}) if losses: wandb.log({f"loss_{k}": v for k, v in losses.items()}) if isinstance(other_scores, dict):