Tidy up, auto-format, types

This commit is contained in:
Ines Montani 2020-10-03 16:31:58 +02:00
parent 3b8f352eda
commit 989a96308f

View File

@ -1,5 +1,5 @@
from typing import TYPE_CHECKING, Dict, Any, Tuple, Callable, List, Optional, IO
import wasabi
from wasabi import Printer
import tqdm
import sys
@ -7,15 +7,16 @@ from ..util import registry
from .. import util
from ..errors import Errors
if TYPE_CHECKING:
from ..language import Language # noqa: F401
@registry.loggers("spacy.ConsoleLogger.v1")
def console_logger(progress_bar: bool=False):
def console_logger(progress_bar: bool = False):
def setup_printer(
nlp: "Language",
stdout: IO=sys.stdout,
stderr: IO=sys.stderr
) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable]:
msg = wasabi.Printer(no_print=True)
nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
msg = Printer(no_print=True)
# we assume here that only components are enabled that should be trained & logged
logged_pipes = nlp.pipe_names
eval_frequency = nlp.config["training"]["eval_frequency"]
@ -32,14 +33,14 @@ def console_logger(progress_bar: bool=False):
stdout.write(msg.row(["-" * width for width in table_widths]))
progress = None
def log_step(info: Optional[Dict[str, Any]]):
def log_step(info: Optional[Dict[str, Any]]) -> None:
nonlocal progress
if info is None:
# If we don't have a new checkpoint, just return.
if progress is not None:
progress.update(1)
return
return
try:
losses = [
"{0:.2f}".format(float(info["losses"][pipe_name]))
@ -78,14 +79,11 @@ def console_logger(progress_bar: bool=False):
if progress_bar:
# Set disable=None, so that it disables on non-TTY
progress = tqdm.tqdm(
total=eval_frequency,
disable=None,
leave=False,
file=stderr
total=eval_frequency, disable=None, leave=False, file=stderr
)
progress.set_description(f"Epoch {info['epoch']+1}")
def finalize():
def finalize() -> None:
pass
return log_step, finalize
@ -100,10 +98,8 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []):
console = console_logger(progress_bar=False)
def setup_logger(
nlp: "Language",
stdout: IO=sys.stdout,
stderr: IO=sys.stderr
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
) -> Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]:
config = nlp.config.interpolate()
config_dot = util.dict_to_dot(config)
for field in remove_config_values:
@ -124,7 +120,7 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []):
if isinstance(other_scores, dict):
wandb.log(other_scores)
def finalize():
def finalize() -> None:
console_finalize()
wandb.join()