mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 21:21:10 +03:00 
			
		
		
		
	* Add optional artifacts logging * Update docs * Update spacy/training/loggers.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Update spacy/training/loggers.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Update spacy/training/loggers.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Bump WandbLogger Version * Add documentation of v1 to legacy docs * bump spacy-legacy to 3.0.2 (to be released) Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: svlandeg <sofie.vanlandeghem@gmail.com>
		
			
				
	
	
		
			178 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			178 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import TYPE_CHECKING, Dict, Any, Tuple, Callable, List, Optional, IO
 | |
| from wasabi import Printer
 | |
| import tqdm
 | |
| import sys
 | |
| 
 | |
| from ..util import registry
 | |
| from .. import util
 | |
| from ..errors import Errors
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from ..language import Language  # noqa: F401
 | |
| 
 | |
| 
 | |
def setup_table(
    *, cols: List[str], widths: List[int], max_width: int = 13
) -> Tuple[List[str], List[int], List[str]]:
    """Compute the header, column widths and alignments of the console table.

    cols (List[str]): Column titles, in display order.
    widths (List[int]): Minimum width of each column, parallel to `cols`.
    max_width (int): Titles longer than this are cut down to exactly
        `max_width` characters, ending in "...".
    RETURNS (Tuple[List[str], List[int], List[str]]): The upper-cased titles,
        the effective widths (never narrower than the possibly-truncated
        title) and an "r" (right-align) marker for every column.
    """
    header: List[str] = []
    col_widths: List[int] = []
    for title, min_width in zip(cols, widths):
        # Keep the table compact by truncating over-long titles.
        too_long = len(title) > max_width
        shortened = f"{title[: max_width - 3]}..." if too_long else title
        header.append(shortened.upper())
        col_widths.append(max(min_width, len(shortened)))
    aligns = ["r"] * len(col_widths)
    return header, col_widths, aligns
 | |
| 
 | |
| 
 | |
@registry.loggers("spacy.ConsoleLogger.v1")
def console_logger(progress_bar: bool = False):
    """Create a logger that prints training progress as a table.

    progress_bar (bool): Whether to show a tqdm progress bar (written to
        `stderr`) for the steps between two evaluations.
    RETURNS (Callable): A setup function that takes the `nlp` object and the
        output streams, writes the table header, and returns a
        `(log_step, finalize)` pair of callbacks.
    """

    def setup_printer(
        nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
    ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
        def write(text: str) -> None:
            stdout.write(f"{text}\n")
            # Flush so each row appears immediately even when stdout is
            # block-buffered (e.g. redirected to a file or pipe), instead of
            # lagging behind the progress bar written to stderr.
            stdout.flush()

        msg = Printer(no_print=True)
        # ensure that only trainable components are logged
        logged_pipes = [
            name
            for name, proc in nlp.pipeline
            if hasattr(proc, "is_trainable") and proc.is_trainable
        ]
        eval_frequency = nlp.config["training"]["eval_frequency"]
        score_weights = nlp.config["training"]["score_weights"]
        # Scores whose weight is None get no column in the table.
        score_cols = [col for col, value in score_weights.items() if value is not None]
        loss_cols = [f"Loss {pipe}" for pipe in logged_pipes]
        spacing = 2
        table_header, table_widths, table_aligns = setup_table(
            cols=["E", "#"] + loss_cols + score_cols + ["Score"],
            widths=[3, 6] + [8 for _ in loss_cols] + [6 for _ in score_cols] + [6],
        )
        write(msg.row(table_header, widths=table_widths, spacing=spacing))
        write(msg.row(["-" * width for width in table_widths], spacing=spacing))
        # Current tqdm bar; only ever created when `progress_bar` is True.
        progress = None

        def log_step(info: Optional[Dict[str, Any]]) -> None:
            """Print a table row for a checkpoint, or advance the progress bar.

            info (Optional[Dict[str, Any]]): None for a step without a new
                checkpoint; otherwise a dict from which "epoch", "step",
                "losses", "other_scores" and "score" are read.
            RAISES (ValueError): If a value in "other_scores" for a displayed
                column cannot be converted to float.
            """
            nonlocal progress

            if info is None:
                # If we don't have a new checkpoint, just return.
                if progress is not None:
                    progress.update(1)
                return
            losses = [
                "{0:.2f}".format(float(info["losses"][pipe_name]))
                for pipe_name in logged_pipes
            ]

            scores = []
            for col in score_cols:
                score = info["other_scores"].get(col, 0.0)
                try:
                    score = float(score)
                except TypeError:
                    err = Errors.E916.format(name=col, score_type=type(score))
                    raise ValueError(err) from None
                # Every column except "speed" is scaled by 100 for display.
                if col != "speed":
                    score *= 100
                scores.append("{0:.2f}".format(score))

            data = (
                [info["epoch"], info["step"]]
                + losses
                + scores
                + ["{0:.2f}".format(float(info["score"]))]
            )
            if progress is not None:
                progress.close()
            write(
                msg.row(data, widths=table_widths, aligns=table_aligns, spacing=spacing)
            )
            if progress_bar:
                # Set disable=None, so that it disables on non-TTY
                progress = tqdm.tqdm(
                    total=eval_frequency, disable=None, leave=False, file=stderr
                )
                progress.set_description(f"Epoch {info['epoch']+1}")

        def finalize() -> None:
            """Close the progress bar left open after the last checkpoint."""
            nonlocal progress
            if progress is not None:
                progress.close()
                progress = None

        return log_step, finalize

    return setup_printer
 | |
| 
 | |
| 
 | |
@registry.loggers("spacy.WandbLogger.v2")
def wandb_logger(
    project_name: str,
    remove_config_values: List[str] = [],
    model_log_interval: Optional[int] = None,
    log_dataset_dir: Optional[str] = None,
):
    """Create a logger that reports training progress to Weights & Biases.

    Progress is also printed to the console via `console_logger` (without a
    progress bar).

    project_name (str): The W&B project the run is logged to.
    remove_config_values (List[str]): Keys in the dot notation produced by
        `util.dict_to_dot` to delete from the config before it is sent to
        W&B. The list itself is only read, never mutated.
    model_log_interval (Optional[int]): If set, the directory at
        `info["output_path"]` is uploaded as a "checkpoint" artifact on every
        non-zero step that is a multiple of this value.
    log_dataset_dir (Optional[str]): If set, this directory is uploaded as a
        "dataset" artifact when the logger is set up.
    RETURNS (Callable): A setup function that returns a
        `(log_step, finalize)` pair of callbacks.
    RAISES (ImportError): If `wandb` (or the names used from it) is missing.
    """
    try:
        import wandb
        from wandb import init, log, join  # test that these are available
    except ImportError:
        raise ImportError(Errors.E880)

    console = console_logger(progress_bar=False)

    def setup_logger(
        nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
    ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
        config = nlp.config.interpolate()
        config_dot = util.dict_to_dot(config)
        for field in remove_config_values:
            del config_dot[field]
        config = util.dot_to_dict(config_dot)
        run = wandb.init(project=project_name, config=config, reinit=True)
        console_log_step, console_finalize = console(nlp, stdout, stderr)

        def log_dir_artifact(
            path: str,
            name: str,
            artifact_type: str,
            metadata: Optional[Dict[str, Any]] = None,
            aliases: Optional[List[str]] = None,
        ) -> None:
            """Upload the directory at `path` to W&B as an artifact."""
            # Fresh containers per call instead of shared mutable defaults.
            artifact = wandb.Artifact(
                name,
                type=artifact_type,
                metadata=metadata if metadata is not None else {},
            )
            artifact.add_dir(path, name=name)
            wandb.log_artifact(artifact, aliases=aliases if aliases is not None else [])

        if log_dataset_dir:
            log_dir_artifact(
                path=log_dataset_dir, name="dataset", artifact_type="dataset"
            )

        def log_step(info: Optional[Dict[str, Any]]) -> None:
            """Forward to the console logger and report checkpoint metrics.

            info (Optional[Dict[str, Any]]): None for a step without a new
                checkpoint (only the console logger is called); otherwise a
                dict from which "score", "other_scores", "losses", "step",
                "epoch", "checkpoints" and optionally "output_path" are read.
            """
            console_log_step(info)
            if info is None:
                return
            score = info["score"]
            other_scores = info["other_scores"]
            losses = info["losses"]
            # By default every wandb.log() call advances the W&B step, so
            # collect this checkpoint's metrics and log them together to keep
            # them on the same step. "score" is set last so it always reflects
            # info["score"].
            metrics: Dict[str, Any] = {}
            if isinstance(other_scores, dict):
                metrics.update(other_scores)
            if losses:
                metrics.update({f"loss_{k}": v for k, v in losses.items()})
            metrics["score"] = score
            wandb.log(metrics)

            if model_log_interval and info.get("output_path"):
                step = info["step"]
                if step % model_log_interval == 0 and step != 0:
                    aliases = [f"epoch {info['epoch']} step {step}", "latest"]
                    # Assumes "checkpoints" holds tuples whose first element
                    # is the score, so max(...)[0] is the best score so far.
                    # Only add "best" when it applies; never pass "" as alias.
                    if score == max(info["checkpoints"])[0]:
                        aliases.append("best")
                    log_dir_artifact(
                        path=info["output_path"],
                        name="pipeline_" + run.id,
                        artifact_type="checkpoint",
                        metadata=info,
                        aliases=aliases,
                    )

        def finalize() -> None:
            """Finish the console logger, then the W&B run."""
            console_finalize()
            wandb.join()

        return log_step, finalize

    return setup_logger
 |