Weights & Biases logger for train CLI (#5971)
* quick test as part of train script
* train_logger in config, default ConsoleLogger in loggers catalogue
* entitiy typo
* add wandb_logger
* cleanup
* Update spacy/cli/train_logger.py

Co-authored-by: Ines Montani <ines@ines.io>

* move loggers to gold.loggers

Co-authored-by: Ines Montani <ines@ines.io>
parent cb54f0d779
commit 79d460e3a2
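The diffs below wire a pluggable training logger through the config. To opt into the new Weights & Biases logger instead of the console default, a user config would point the [training.logger] block at the registered name (a sketch based on the functions registered below; the project name is a placeholder):

[training.logger]
@loggers = "spacy.WandbLogger.v1"
project_name = "my-spacy-project"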
spacy/__init__.py
@@ -20,6 +20,7 @@ from .errors import Errors
from .language import Language
from . import util


if sys.maxunicode == 65535:
    raise SystemError(Errors.E130)
spacy/cli/train.py
@@ -18,9 +18,6 @@ from .. import util
from ..gold.example import Example
-from ..errors import Errors

# Don't remove - required to load the built-in architectures
from ..ml import models  # noqa: F401


@app.command(
    "train", context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
@@ -96,6 +93,7 @@ def train(
    train_corpus = T_cfg["train_corpus"]
    dev_corpus = T_cfg["dev_corpus"]
    batcher = T_cfg["batcher"]
+    train_logger = T_cfg["logger"]
    # Components that shouldn't be updated during training
    frozen_components = T_cfg["frozen_components"]
    # Sourced components that require resume_training
@@ -149,10 +147,11 @@ def train(
        exclude=frozen_components,
    )
    msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
-    print_row = setup_printer(T_cfg, nlp)
+    print_row, finalize_logger = train_logger(nlp)

    try:
        progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
        progress.set_description(f"Epoch 1")
        for batch, info, is_best_checkpoint in training_step_iterator:
            progress.update(1)
            if is_best_checkpoint is not None:
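As the hunk above shows, train() now takes the logger from the resolved config and calls it with the nlp object, getting back a (log_step, finalize) pair. Any function with that shape can be registered as a logger. A minimal sketch of a custom one (the "spacy.JsonlLogger.v1" name, the path argument, and the JSONL format are illustrative, not part of this commit):

from typing import Dict, Any, Tuple, Callable
import json

from spacy.util import registry


@registry.loggers("spacy.JsonlLogger.v1")
def jsonl_logger(path: str = "training.jsonl"):
    def setup_logger(nlp) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
        output = open(path, "w", encoding="utf8")

        def log_step(info: Dict[str, Any]):
            # Mirror the info dict that train() passes for each logged step.
            record = {
                "epoch": info["epoch"],
                "step": info["step"],
                "score": info["score"],
                "losses": info["losses"],
            }
            output.write(json.dumps(record) + "\n")

        def finalize():
            output.close()

        return log_step, finalize

    return setup_logger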
@@ -162,7 +161,9 @@ def train(
                    update_meta(T_cfg, nlp, info)
                    nlp.to_disk(output_path / "model-best")
                progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
                progress.set_description(f"Epoch {info['epoch']}")
    except Exception as e:
+        finalize_logger()
        if output_path is not None:
            # We don't want to swallow the traceback if we don't have a
            # specific error.
@@ -173,6 +174,7 @@ def train(
                nlp.to_disk(output_path / "model-final")
        raise e
    finally:
+        finalize_logger()
        if output_path is not None:
            final_model_path = output_path / "model-final"
            if optimizer.averages:
@@ -203,7 +205,7 @@ def create_train_batches(iterator, batcher, max_epochs: int):


def create_evaluation_callback(
-    nlp: Language, dev_corpus: Callable, weights: Dict[str, float],
+    nlp: Language, dev_corpus: Callable, weights: Dict[str, float]
) -> Callable[[], Tuple[float, Dict[str, float]]]:
    def evaluate() -> Tuple[float, Dict[str, float]]:
        dev_examples = list(dev_corpus(nlp))
@@ -353,57 +355,6 @@ def subdivide_batch(batch, accumulate_gradient):
        yield subbatch


-def setup_printer(
-    training: Union[Dict[str, Any], Config], nlp: Language
-) -> Callable[[Dict[str, Any]], None]:
-    score_cols = list(training["score_weights"])
-    score_widths = [max(len(col), 6) for col in score_cols]
-    loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
-    loss_widths = [max(len(col), 8) for col in loss_cols]
-    table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
-    table_header = [col.upper() for col in table_header]
-    table_widths = [3, 6] + loss_widths + score_widths + [6]
-    table_aligns = ["r" for _ in table_widths]
-    msg.row(table_header, widths=table_widths)
-    msg.row(["-" * width for width in table_widths])
-
-    def print_row(info: Dict[str, Any]) -> None:
-        try:
-            losses = [
-                "{0:.2f}".format(float(info["losses"][pipe_name]))
-                for pipe_name in nlp.pipe_names
-            ]
-        except KeyError as e:
-            raise KeyError(
-                Errors.E983.format(
-                    dict="scores (losses)", key=str(e), keys=list(info["losses"].keys())
-                )
-            ) from None
-
-        try:
-            scores = [
-                "{0:.2f}".format(float(info["other_scores"].get(col, 0.0)) * 100)
-                for col in score_cols
-            ]
-        except KeyError as e:
-            raise KeyError(
-                Errors.E983.format(
-                    dict="scores (other)",
-                    key=str(e),
-                    keys=list(info["other_scores"].keys()),
-                )
-            ) from None
-        data = (
-            [info["epoch"], info["step"]]
-            + losses
-            + scores
-            + ["{0:.2f}".format(float(info["score"]))]
-        )
-        msg.row(data, widths=table_widths, aligns=table_aligns)
-
-    return print_row


def update_meta(
    training: Union[Dict[str, Any], Config], nlp: Language, info: Dict[str, Any]
) -> None:
@@ -435,7 +386,7 @@ def load_from_paths(
    return raw_text, tag_map, morph_rules, weights_data


-def verify_cli_args(config_path: Path, output_path: Optional[Path] = None,) -> None:
+def verify_cli_args(config_path: Path, output_path: Optional[Path] = None) -> None:
    # Make sure all files and paths exists if they are needed
    if not config_path or not config_path.exists():
        msg.fail("Config file not found", config_path, exits=1)
spacy/default_config.cfg
@@ -40,6 +40,9 @@ score_weights = {}
# Names of pipeline components that shouldn't be updated during training
frozen_components = []

+[training.logger]
+@loggers = "spacy.ConsoleLogger.v1"
+
[training.train_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths.train}
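The default config keeps the previous console behaviour by selecting spacy.ConsoleLogger.v1. The @loggers reference resolves through the new registry; done by hand, that resolution looks roughly like this (a sketch; the config system normally performs this step when the training config is loaded):

from spacy.util import registry

make_logger = registry.loggers.get("spacy.ConsoleLogger.v1")
setup = make_logger()  # console_logger takes no arguments
# setup(nlp) would return the (log_step, finalize) pair used by train()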
spacy/gold/__init__.py
@@ -6,3 +6,4 @@ from .iob_utils import biluo_tags_from_offsets, offsets_from_biluo_tags  # noqa: F401
from .iob_utils import spans_from_biluo_tags, tags_to_entities  # noqa: F401
from .gold_io import docs_to_json, read_json_file  # noqa: F401
from .batchers import minibatch_by_padded_size, minibatch_by_words  # noqa: F401
+from .loggers import console_logger, wandb_logger  # noqa: F401
spacy/gold/loggers.py (new file, 99 lines)
@@ -0,0 +1,99 @@
from typing import Dict, Any, Tuple, Callable

from ..util import registry
from ..errors import Errors
from wasabi import msg


@registry.loggers("spacy.ConsoleLogger.v1")
def console_logger():
    def setup_printer(
        nlp: "Language"
    ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
        score_cols = list(nlp.config["training"]["score_weights"])
        score_widths = [max(len(col), 6) for col in score_cols]
        loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
        loss_widths = [max(len(col), 8) for col in loss_cols]
        table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
        table_header = [col.upper() for col in table_header]
        table_widths = [3, 6] + loss_widths + score_widths + [6]
        table_aligns = ["r" for _ in table_widths]
        msg.row(table_header, widths=table_widths)
        msg.row(["-" * width for width in table_widths])

        def log_step(info: Dict[str, Any]):
            try:
                losses = [
                    "{0:.2f}".format(float(info["losses"][pipe_name]))
                    for pipe_name in nlp.pipe_names
                ]
            except KeyError as e:
                raise KeyError(
                    Errors.E983.format(
                        dict="scores (losses)",
                        key=str(e),
                        keys=list(info["losses"].keys()),
                    )
                ) from None

            try:
                scores = [
                    "{0:.2f}".format(float(info["other_scores"].get(col, 0.0)) * 100)
                    for col in score_cols
                ]
            except KeyError as e:
                raise KeyError(
                    Errors.E983.format(
                        dict="scores (other)",
                        key=str(e),
                        keys=list(info["other_scores"].keys()),
                    )
                ) from None
            data = (
                [info["epoch"], info["step"]]
                + losses
                + scores
                + ["{0:.2f}".format(float(info["score"]))]
            )
            msg.row(data, widths=table_widths, aligns=table_aligns)

        def finalize():
            pass

        return log_step, finalize

    return setup_printer


@registry.loggers("spacy.WandbLogger.v1")
def wandb_logger(project_name: str):
    import wandb

    console = console_logger()

    def setup_logger(
        nlp: "Language"
    ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
        config = nlp.config.interpolate()
        wandb.init(project=project_name, config=config)
        console_log_step, console_finalize = console(nlp)

        def log_step(info: Dict[str, Any]):
            console_log_step(info)
            epoch = info["epoch"]
            score = info["score"]
            other_scores = info["other_scores"]
            losses = info["losses"]
            wandb.log({"score": score, "epoch": epoch})
            if losses:
                wandb.log({f"loss_{k}": v for k, v in losses.items()})
            if isinstance(other_scores, dict):
                wandb.log(other_scores)

        def finalize():
            console_finalize()
            pass

        return log_step, finalize

    return setup_logger
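Both loggers share the same two-phase shape: calling the registered function returns a setup callable, and calling that with nlp prints the table header (and, for wandb, starts the run) before returning the per-step and cleanup callbacks. A quick interactive check of the console logger might look like this (a sketch against the codebase at this commit; the info values are made up, and an empty pipeline yields only the E, # and SCORE columns):

import spacy
from spacy.gold.loggers import console_logger

nlp = spacy.blank("en")
log_step, finalize = console_logger()(nlp)
log_step({"epoch": 1, "step": 100, "losses": {}, "other_scores": {}, "score": 0.0})
finalize()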
spacy/pipeline/entityruler.py
@@ -68,7 +68,7 @@ class EntityRuler:
        ent_id_sep: str = DEFAULT_ENT_ID_SEP,
        patterns: Optional[List[PatternType]] = None,
    ) -> None:
-        """Initialize the entitiy ruler. If patterns are supplied here, they
+        """Initialize the entity ruler. If patterns are supplied here, they
        need to be a list of dictionaries with a `"label"` and `"pattern"`
        key. A pattern can either be a token pattern (list) or a phrase pattern
        (string). For example: `{'label': 'ORG', 'pattern': 'Apple'}`.

@@ -223,7 +223,7 @@ class EntityRuler:
        return all_patterns

    def add_patterns(self, patterns: List[PatternType]) -> None:
-        """Add patterns to the entitiy ruler. A pattern can either be a token
+        """Add patterns to the entity ruler. A pattern can either be a token
        pattern (list of dicts) or a phrase pattern (string). For example:
        {'label': 'ORG', 'pattern': 'Apple'}
        {'label': 'GPE', 'pattern': [{'lower': 'san'}, {'lower': 'francisco'}]}
spacy/schemas.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type
+from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type, Tuple
from typing import Iterable, TypeVar, TYPE_CHECKING
from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator

@@ -18,6 +18,7 @@ if TYPE_CHECKING:
ItemT = TypeVar("ItemT")
Batcher = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
Reader = Callable[["Language", str], Iterable["Example"]]
+Logger = Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]]


def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:

@@ -209,6 +210,7 @@ class ConfigSchemaTraining(BaseModel):
    init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
    raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
    optimizer: Optimizer = Field(..., title="The optimizer to use")
+    logger: Logger = Field(..., title="The logger to track training progress")
    frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
    # fmt: on
spacy/util.py
@@ -81,6 +81,7 @@ class registry(thinc.registry):
    callbacks = catalogue.create("spacy", "callbacks")
    batchers = catalogue.create("spacy", "batchers", entry_points=True)
    readers = catalogue.create("spacy", "readers", entry_points=True)
+    loggers = catalogue.create("spacy", "loggers", entry_points=True)
    # These are factories registered via third-party packages and the
    # spacy_factories entry point. This registry only exists so we can easily
    # load them via the entry points. The "true" factories are added via the
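Because the loggers registry is created with entry_points=True, a third-party package should be able to ship loggers that spaCy discovers automatically by declaring them under the corresponding entry-point group (a sketch of a package's setup.cfg, assuming catalogue's usual namespace_name group naming, here spacy_loggers; the package and function names are placeholders):

[options.entry_points]
spacy_loggers =
    my_logger = my_package.loggers:create_my_logger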