diff --git a/spacy/training/loggers.py b/spacy/training/loggers.py index be2da4bd8..e8c948f54 100644 --- a/spacy/training/loggers.py +++ b/spacy/training/loggers.py @@ -1,5 +1,5 @@ from typing import TYPE_CHECKING, Dict, Any, Tuple, Callable, List, Optional, IO -import wasabi +from wasabi import Printer import tqdm import sys @@ -7,15 +7,16 @@ from ..util import registry from .. import util from ..errors import Errors +if TYPE_CHECKING: + from ..language import Language # noqa: F401 + @registry.loggers("spacy.ConsoleLogger.v1") -def console_logger(progress_bar: bool=False): +def console_logger(progress_bar: bool = False): def setup_printer( - nlp: "Language", - stdout: IO=sys.stdout, - stderr: IO=sys.stderr - ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable]: - msg = wasabi.Printer(no_print=True) + nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr + ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]: + msg = Printer(no_print=True) # we assume here that only components are enabled that should be trained & logged logged_pipes = nlp.pipe_names eval_frequency = nlp.config["training"]["eval_frequency"] @@ -32,14 +33,14 @@ def console_logger(progress_bar: bool=False): stdout.write(msg.row(["-" * width for width in table_widths])) progress = None - def log_step(info: Optional[Dict[str, Any]]): + def log_step(info: Optional[Dict[str, Any]]) -> None: nonlocal progress if info is None: # If we don't have a new checkpoint, just return. if progress is not None: progress.update(1) - return + return try: losses = [ "{0:.2f}".format(float(info["losses"][pipe_name])) @@ -78,14 +79,11 @@ def console_logger(progress_bar: bool=False): if progress_bar: # Set disable=None, so that it disables on non-TTY progress = tqdm.tqdm( - total=eval_frequency, - disable=None, - leave=False, - file=stderr + total=eval_frequency, disable=None, leave=False, file=stderr ) progress.set_description(f"Epoch {info['epoch']+1}") - def finalize(): + def finalize() -> None: pass return log_step, finalize @@ -100,10 +98,8 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []): console = console_logger(progress_bar=False) def setup_logger( - nlp: "Language", - stdout: IO=sys.stdout, - stderr: IO=sys.stderr - ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]: + nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr + ) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]: config = nlp.config.interpolate() config_dot = util.dict_to_dot(config) for field in remove_config_values: @@ -124,7 +120,7 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []): if isinstance(other_scores, dict): wandb.log(other_scores) - def finalize(): + def finalize() -> None: console_finalize() wandb.join()