From 1d8c4070aae601d1161f74284f857e9bc476050a Mon Sep 17 00:00:00 2001 From: svlandeg Date: Fri, 28 Aug 2020 13:55:32 +0200 Subject: [PATCH 1/3] add disable_fields to wandb_logger --- spacy/gold/loggers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/spacy/gold/loggers.py b/spacy/gold/loggers.py index 10a153014..7c5a3d317 100644 --- a/spacy/gold/loggers.py +++ b/spacy/gold/loggers.py @@ -1,6 +1,7 @@ from typing import Dict, Any, Tuple, Callable from ..util import registry +from .. import util from ..errors import Errors from wasabi import msg @@ -66,7 +67,7 @@ def console_logger(): @registry.loggers("spacy.WandbLogger.v1") -def wandb_logger(project_name: str): +def wandb_logger(project_name: str, disable_fields: list = []): import wandb console = console_logger() @@ -75,16 +76,19 @@ def wandb_logger(project_name: str): nlp: "Language" ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]: config = nlp.config.interpolate() + config_dot = util.dict_to_dot(config) + for field in disable_fields: + del config_dot[field] + config = util.dot_to_dict(config_dot) wandb.init(project=project_name, config=config) console_log_step, console_finalize = console(nlp) def log_step(info: Dict[str, Any]): console_log_step(info) - epoch = info["epoch"] score = info["score"] other_scores = info["other_scores"] losses = info["losses"] - wandb.log({"score": score, "epoch": epoch}) + wandb.log({"score": score}) if losses: wandb.log({f"loss_{k}": v for k, v in losses.items()}) if isinstance(other_scores, dict): From 33883aa764daf71a2e4b50ee26e324beeba488ba Mon Sep 17 00:00:00 2001 From: svlandeg Date: Fri, 28 Aug 2020 14:06:23 +0200 Subject: [PATCH 2/3] rename field --- spacy/gold/loggers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spacy/gold/loggers.py b/spacy/gold/loggers.py index 7c5a3d317..9feec558f 100644 --- a/spacy/gold/loggers.py +++ b/spacy/gold/loggers.py @@ -67,7 +67,7 @@ def console_logger(): @registry.loggers("spacy.WandbLogger.v1") -def wandb_logger(project_name: str, disable_fields: list = []): +def wandb_logger(project_name: str, remove_config_values: list = []): import wandb console = console_logger() @@ -77,7 +77,7 @@ def wandb_logger(project_name: str, disable_fields: list = []): ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]: config = nlp.config.interpolate() config_dot = util.dict_to_dot(config) - for field in disable_fields: + for field in remove_config_values: del config_dot[field] config = util.dot_to_dict(config_dot) wandb.init(project=project_name, config=config) From 05a1bafa158a79fe56e5152599d2d5206c5bea2b Mon Sep 17 00:00:00 2001 From: svlandeg Date: Fri, 28 Aug 2020 14:08:33 +0200 Subject: [PATCH 3/3] fix type --- spacy/gold/loggers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spacy/gold/loggers.py b/spacy/gold/loggers.py index 9feec558f..e58d4a4aa 100644 --- a/spacy/gold/loggers.py +++ b/spacy/gold/loggers.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Tuple, Callable +from typing import Dict, Any, Tuple, Callable, List from ..util import registry from .. import util @@ -67,7 +67,7 @@ def console_logger(): @registry.loggers("spacy.WandbLogger.v1") -def wandb_logger(project_name: str, remove_config_values: list = []): +def wandb_logger(project_name: str, remove_config_values: List[str] = []): import wandb console = console_logger()