mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
add disable_fields to wandb_logger
This commit is contained in:
parent
03dde511b4
commit
1d8c4070aa
|
@ -1,6 +1,7 @@
|
|||
from typing import Dict, Any, Tuple, Callable
|
||||
|
||||
from ..util import registry
|
||||
from .. import util
|
||||
from ..errors import Errors
|
||||
from wasabi import msg
|
||||
|
||||
|
@ -66,7 +67,7 @@ def console_logger():
|
|||
|
||||
|
||||
@registry.loggers("spacy.WandbLogger.v1")
|
||||
def wandb_logger(project_name: str):
|
||||
def wandb_logger(project_name: str, disable_fields: list = []):
|
||||
import wandb
|
||||
|
||||
console = console_logger()
|
||||
|
@ -75,16 +76,19 @@ def wandb_logger(project_name: str):
|
|||
nlp: "Language"
|
||||
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
|
||||
config = nlp.config.interpolate()
|
||||
config_dot = util.dict_to_dot(config)
|
||||
for field in disable_fields:
|
||||
del config_dot[field]
|
||||
config = util.dot_to_dict(config_dot)
|
||||
wandb.init(project=project_name, config=config)
|
||||
console_log_step, console_finalize = console(nlp)
|
||||
|
||||
def log_step(info: Dict[str, Any]):
|
||||
console_log_step(info)
|
||||
epoch = info["epoch"]
|
||||
score = info["score"]
|
||||
other_scores = info["other_scores"]
|
||||
losses = info["losses"]
|
||||
wandb.log({"score": score, "epoch": epoch})
|
||||
wandb.log({"score": score})
|
||||
if losses:
|
||||
wandb.log({f"loss_{k}": v for k, v in losses.items()})
|
||||
if isinstance(other_scores, dict):
|
||||
|
|
Loading…
Reference in New Issue
Block a user