diff --git a/spacy/gold/loggers.py b/spacy/gold/loggers.py index 10a153014..e58d4a4aa 100644 --- a/spacy/gold/loggers.py +++ b/spacy/gold/loggers.py @@ -1,6 +1,7 @@ -from typing import Dict, Any, Tuple, Callable +from typing import Dict, Any, Tuple, Callable, List from ..util import registry +from .. import util from ..errors import Errors from wasabi import msg @@ -66,7 +67,7 @@ def console_logger(): @registry.loggers("spacy.WandbLogger.v1") -def wandb_logger(project_name: str): +def wandb_logger(project_name: str, remove_config_values: List[str] = []): import wandb console = console_logger() @@ -75,16 +76,19 @@ def wandb_logger(project_name: str): nlp: "Language" ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]: config = nlp.config.interpolate() + config_dot = util.dict_to_dot(config) + for field in remove_config_values: + del config_dot[field] + config = util.dot_to_dict(config_dot) wandb.init(project=project_name, config=config) console_log_step, console_finalize = console(nlp) def log_step(info: Dict[str, Any]): console_log_step(info) - epoch = info["epoch"] score = info["score"] other_scores = info["other_scores"] losses = info["losses"] - wandb.log({"score": score, "epoch": epoch}) + wandb.log({"score": score}) if losses: wandb.log({f"loss_{k}": v for k, v in losses.items()}) if isinstance(other_scores, dict):