Weights & Biases logger for train CLI (#5971)

* quick test as part of train script

* train_logger in config, default ConsoleLogger in loggers catalogue

* fix "entitiy" typo

* add wandb_logger

* cleanup

* Update spacy/cli/train_logger.py

Co-authored-by: Ines Montani <ines@ines.io>

* move loggers to gold.loggers

Co-authored-by: Ines Montani <ines@ines.io>
Sofie Van Landeghem 2020-08-26 15:24:33 +02:00, committed by GitHub
parent cb54f0d779
commit 79d460e3a2
8 changed files with 118 additions and 60 deletions

spacy/__init__.py

@@ -20,6 +20,7 @@ from .errors import Errors
from .language import Language
from . import util

if sys.maxunicode == 65535:
    raise SystemError(Errors.E130)

spacy/cli/train.py

@@ -18,9 +18,6 @@ from .. import util
from ..gold.example import Example
from ..errors import Errors
# Don't remove - required to load the built-in architectures
from ..ml import models  # noqa: F401


@app.command(
    "train", context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
@@ -96,6 +93,7 @@ def train(
    train_corpus = T_cfg["train_corpus"]
    dev_corpus = T_cfg["dev_corpus"]
    batcher = T_cfg["batcher"]
+    train_logger = T_cfg["logger"]
    # Components that shouldn't be updated during training
    frozen_components = T_cfg["frozen_components"]
    # Sourced components that require resume_training
@@ -149,10 +147,11 @@ def train(
        exclude=frozen_components,
    )
    msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
-    print_row = setup_printer(T_cfg, nlp)
+    print_row, finalize_logger = train_logger(nlp)
    try:
        progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
+        progress.set_description(f"Epoch 1")
        for batch, info, is_best_checkpoint in training_step_iterator:
            progress.update(1)
            if is_best_checkpoint is not None:
@@ -162,7 +161,9 @@ def train(
                    update_meta(T_cfg, nlp, info)
                    nlp.to_disk(output_path / "model-best")
                progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
+                progress.set_description(f"Epoch {info['epoch']}")
    except Exception as e:
+        finalize_logger()
        if output_path is not None:
            # We don't want to swallow the traceback if we don't have a
            # specific error.
@@ -173,6 +174,7 @@ def train(
            nlp.to_disk(output_path / "model-final")
        raise e
    finally:
+        finalize_logger()
        if output_path is not None:
            final_model_path = output_path / "model-final"
            if optimizer.averages:
@@ -203,7 +205,7 @@ def create_train_batches(iterator, batcher, max_epochs: int):


def create_evaluation_callback(
-    nlp: Language, dev_corpus: Callable, weights: Dict[str, float],
+    nlp: Language, dev_corpus: Callable, weights: Dict[str, float]
) -> Callable[[], Tuple[float, Dict[str, float]]]:
    def evaluate() -> Tuple[float, Dict[str, float]]:
        dev_examples = list(dev_corpus(nlp))
@@ -353,57 +355,6 @@ def subdivide_batch(batch, accumulate_gradient):
        yield subbatch


-def setup_printer(
-    training: Union[Dict[str, Any], Config], nlp: Language
-) -> Callable[[Dict[str, Any]], None]:
-    score_cols = list(training["score_weights"])
-    score_widths = [max(len(col), 6) for col in score_cols]
-    loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
-    loss_widths = [max(len(col), 8) for col in loss_cols]
-    table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
-    table_header = [col.upper() for col in table_header]
-    table_widths = [3, 6] + loss_widths + score_widths + [6]
-    table_aligns = ["r" for _ in table_widths]
-    msg.row(table_header, widths=table_widths)
-    msg.row(["-" * width for width in table_widths])
-
-    def print_row(info: Dict[str, Any]) -> None:
-        try:
-            losses = [
-                "{0:.2f}".format(float(info["losses"][pipe_name]))
-                for pipe_name in nlp.pipe_names
-            ]
-        except KeyError as e:
-            raise KeyError(
-                Errors.E983.format(
-                    dict="scores (losses)", key=str(e), keys=list(info["losses"].keys())
-                )
-            ) from None
-        try:
-            scores = [
-                "{0:.2f}".format(float(info["other_scores"].get(col, 0.0)) * 100)
-                for col in score_cols
-            ]
-        except KeyError as e:
-            raise KeyError(
-                Errors.E983.format(
-                    dict="scores (other)",
-                    key=str(e),
-                    keys=list(info["other_scores"].keys()),
-                )
-            ) from None
-        data = (
-            [info["epoch"], info["step"]]
-            + losses
-            + scores
-            + ["{0:.2f}".format(float(info["score"]))]
-        )
-        msg.row(data, widths=table_widths, aligns=table_aligns)
-
-    return print_row


def update_meta(
    training: Union[Dict[str, Any], Config], nlp: Language, info: Dict[str, Any]
) -> None:
@@ -435,7 +386,7 @@ def load_from_paths(
    return raw_text, tag_map, morph_rules, weights_data


-def verify_cli_args(config_path: Path, output_path: Optional[Path] = None,) -> None:
+def verify_cli_args(config_path: Path, output_path: Optional[Path] = None) -> None:
    # Make sure all files and paths exists if they are needed
    if not config_path or not config_path.exists():
        msg.fail("Config file not found", config_path, exits=1)
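The new dependency has a small surface: a logger is a callable that receives the nlp object and returns a (log_step, finalize) pair, which train() drives exactly as the diff above shows. A minimal sketch of that contract using the built-in console logger, with a blank pipeline standing in for the real one and made-up info values (both are our assumptions, not part of the commit):

import spacy
from spacy.gold import console_logger

nlp = spacy.blank("en")  # stand-in; train() passes the pipeline built from the config
log_step, finalize = console_logger()(nlp)
try:
    # Keys mirror what train() collects per evaluation step; the values are fake.
    log_step({"epoch": 1, "step": 200, "score": 0.81, "other_scores": {}, "losses": {}})
finally:
    finalize()  # as in train(), cleanup runs in the finally block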

spacy/default_config.cfg

@@ -40,6 +40,9 @@ score_weights = {}
# Names of pipeline components that shouldn't be updated during training
frozen_components = []

+[training.logger]
+@loggers = "spacy.ConsoleLogger.v1"

[training.train_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths.train}
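To opt into the new Weights & Biases logger instead of the console default, a user config overrides the same block; project_name is the registered function's only argument, and the value below is a placeholder:

[training.logger]
@loggers = "spacy.WandbLogger.v1"
project_name = "my_spacy_project"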

spacy/gold/__init__.py

@@ -6,3 +6,4 @@ from .iob_utils import biluo_tags_from_offsets, offsets_from_biluo_tags  # noqa:
from .iob_utils import spans_from_biluo_tags, tags_to_entities  # noqa: F401
from .gold_io import docs_to_json, read_json_file  # noqa: F401
from .batchers import minibatch_by_padded_size, minibatch_by_words  # noqa: F401
+from .loggers import console_logger, wandb_logger  # noqa: F401

spacy/gold/loggers.py (new file, 99 lines)

@@ -0,0 +1,99 @@
from typing import Dict, Any, Tuple, Callable
from ..util import registry
from ..errors import Errors
from wasabi import msg


@registry.loggers("spacy.ConsoleLogger.v1")
def console_logger():
    def setup_printer(
        nlp: "Language"
    ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
        score_cols = list(nlp.config["training"]["score_weights"])
        score_widths = [max(len(col), 6) for col in score_cols]
        loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
        loss_widths = [max(len(col), 8) for col in loss_cols]
        table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
        table_header = [col.upper() for col in table_header]
        table_widths = [3, 6] + loss_widths + score_widths + [6]
        table_aligns = ["r" for _ in table_widths]
        msg.row(table_header, widths=table_widths)
        msg.row(["-" * width for width in table_widths])

        def log_step(info: Dict[str, Any]):
            try:
                losses = [
                    "{0:.2f}".format(float(info["losses"][pipe_name]))
                    for pipe_name in nlp.pipe_names
                ]
            except KeyError as e:
                raise KeyError(
                    Errors.E983.format(
                        dict="scores (losses)",
                        key=str(e),
                        keys=list(info["losses"].keys()),
                    )
                ) from None
            try:
                scores = [
                    "{0:.2f}".format(float(info["other_scores"].get(col, 0.0)) * 100)
                    for col in score_cols
                ]
            except KeyError as e:
                raise KeyError(
                    Errors.E983.format(
                        dict="scores (other)",
                        key=str(e),
                        keys=list(info["other_scores"].keys()),
                    )
                ) from None
            data = (
                [info["epoch"], info["step"]]
                + losses
                + scores
                + ["{0:.2f}".format(float(info["score"]))]
            )
            msg.row(data, widths=table_widths, aligns=table_aligns)

        def finalize():
            pass

        return log_step, finalize

    return setup_printer


@registry.loggers("spacy.WandbLogger.v1")
def wandb_logger(project_name: str):
    import wandb

    console = console_logger()

    def setup_logger(
        nlp: "Language"
    ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
        config = nlp.config.interpolate()
        wandb.init(project=project_name, config=config)
        console_log_step, console_finalize = console(nlp)

        def log_step(info: Dict[str, Any]):
            console_log_step(info)
            epoch = info["epoch"]
            score = info["score"]
            other_scores = info["other_scores"]
            losses = info["losses"]
            wandb.log({"score": score, "epoch": epoch})
            if losses:
                wandb.log({f"loss_{k}": v for k, v in losses.items()})
            if isinstance(other_scores, dict):
                wandb.log(other_scores)

        def finalize():
            console_finalize()

        return log_step, finalize

    return setup_logger
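Because loggers resolve through the registry, third-party code can register its own and reference it from the config. A sketch of a custom JSONL logger following the same pattern; the registered name, path argument, and file handling are illustrative, not part of this commit:

import json
from typing import Dict, Any, Tuple, Callable

from spacy.util import registry


@registry.loggers("my_jsonl_logger.v1")  # hypothetical name
def jsonl_logger(path: str):
    def setup_logger(
        nlp: "Language"
    ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
        file_ = open(path, "w", encoding="utf8")

        def log_step(info: Dict[str, Any]):
            # One JSON object per evaluation step, using the same keys train() provides.
            row = {key: info[key] for key in ("epoch", "step", "score", "losses")}
            file_.write(json.dumps(row) + "\n")

        def finalize():
            file_.close()

        return log_step, finalize

    return setup_logger

The config would then set @loggers = "my_jsonl_logger.v1" and path = "training.jsonl" under [training.logger], exactly as with the built-ins.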

spacy/pipeline/entityruler.py

@@ -68,7 +68,7 @@ class EntityRuler:
        ent_id_sep: str = DEFAULT_ENT_ID_SEP,
        patterns: Optional[List[PatternType]] = None,
    ) -> None:
-        """Initialize the entitiy ruler. If patterns are supplied here, they
+        """Initialize the entity ruler. If patterns are supplied here, they
        need to be a list of dictionaries with a `"label"` and `"pattern"`
        key. A pattern can either be a token pattern (list) or a phrase pattern
        (string). For example: `{'label': 'ORG', 'pattern': 'Apple'}`.
@@ -223,7 +223,7 @@ class EntityRuler:
        return all_patterns

    def add_patterns(self, patterns: List[PatternType]) -> None:
-        """Add patterns to the entitiy ruler. A pattern can either be a token
+        """Add patterns to the entity ruler. A pattern can either be a token
        pattern (list of dicts) or a phrase pattern (string). For example:
        {'label': 'ORG', 'pattern': 'Apple'}
        {'label': 'GPE', 'pattern': [{'lower': 'san'}, {'lower': 'francisco'}]}

spacy/schemas.py

@@ -1,4 +1,4 @@
-from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type
+from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type, Tuple
from typing import Iterable, TypeVar, TYPE_CHECKING
from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
ItemT = TypeVar("ItemT")
Batcher = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
Reader = Callable[["Language", str], Iterable["Example"]]
+Logger = Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]]


def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
@@ -209,6 +210,7 @@ class ConfigSchemaTraining(BaseModel):
    init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
    raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
    optimizer: Optimizer = Field(..., title="The optimizer to use")
+    logger: Logger = Field(..., title="The logger to track training progress")
    frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
    # fmt: on
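Any callable with this shape satisfies the new Logger field during config validation. The smallest conforming example (the name noop_logger is ours, for illustration only):

from typing import Dict, Any, Tuple, Callable


def noop_logger(nlp: "Language") -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
    def log_step(info: Dict[str, Any]) -> None:
        pass  # ignore every training step

    def finalize() -> None:
        pass  # nothing to clean up

    return log_step, finalize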

spacy/util.py

@@ -81,6 +81,7 @@ class registry(thinc.registry):
    callbacks = catalogue.create("spacy", "callbacks")
    batchers = catalogue.create("spacy", "batchers", entry_points=True)
    readers = catalogue.create("spacy", "readers", entry_points=True)
+    loggers = catalogue.create("spacy", "loggers", entry_points=True)
    # These are factories registered via third-party packages and the
    # spacy_factories entry point. This registry only exists so we can easily
    # load them via the entry points. The "true" factories are added via the
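With entry_points=True, installed packages can also expose loggers through the spacy_loggers entry point group, and in-process code can look them up by registered name. A quick sketch of resolving the built-in console logger (assuming importing spacy has already registered it):

from spacy.util import registry

# Importing spacy pulls in spacy.gold.loggers, which registers the built-ins.
make_console_logger = registry.loggers.get("spacy.ConsoleLogger.v1")
setup = make_console_logger()      # the factory returns the setup function
# log_step, finalize = setup(nlp)  # train() calls this with the loaded pipeline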