mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Merge pull request #5992 from svlandeg/feature/wandb-restrict-config
This commit is contained in:
commit
89f692bc8a
|
@ -1,6 +1,7 @@
|
||||||
from typing import Dict, Any, Tuple, Callable
|
from typing import Dict, Any, Tuple, Callable, List
|
||||||
|
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
|
from .. import util
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
|
|
||||||
|
@ -66,7 +67,7 @@ def console_logger():
|
||||||
|
|
||||||
|
|
||||||
@registry.loggers("spacy.WandbLogger.v1")
|
@registry.loggers("spacy.WandbLogger.v1")
|
||||||
def wandb_logger(project_name: str):
|
def wandb_logger(project_name: str, remove_config_values: List[str] = []):
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
console = console_logger()
|
console = console_logger()
|
||||||
|
@ -75,16 +76,19 @@ def wandb_logger(project_name: str):
|
||||||
nlp: "Language"
|
nlp: "Language"
|
||||||
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
|
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
|
||||||
config = nlp.config.interpolate()
|
config = nlp.config.interpolate()
|
||||||
|
config_dot = util.dict_to_dot(config)
|
||||||
|
for field in remove_config_values:
|
||||||
|
del config_dot[field]
|
||||||
|
config = util.dot_to_dict(config_dot)
|
||||||
wandb.init(project=project_name, config=config)
|
wandb.init(project=project_name, config=config)
|
||||||
console_log_step, console_finalize = console(nlp)
|
console_log_step, console_finalize = console(nlp)
|
||||||
|
|
||||||
def log_step(info: Dict[str, Any]):
|
def log_step(info: Dict[str, Any]):
|
||||||
console_log_step(info)
|
console_log_step(info)
|
||||||
epoch = info["epoch"]
|
|
||||||
score = info["score"]
|
score = info["score"]
|
||||||
other_scores = info["other_scores"]
|
other_scores = info["other_scores"]
|
||||||
losses = info["losses"]
|
losses = info["losses"]
|
||||||
wandb.log({"score": score, "epoch": epoch})
|
wandb.log({"score": score})
|
||||||
if losses:
|
if losses:
|
||||||
wandb.log({f"loss_{k}": v for k, v in losses.items()})
|
wandb.log({f"loss_{k}": v for k, v in losses.items()})
|
||||||
if isinstance(other_scores, dict):
|
if isinstance(other_scores, dict):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user