Move functionality to config training setting

thomashacker 2023-01-19 14:54:28 +01:00
parent 2bb8db88f8
commit 58502caa57
5 changed files with 27 additions and 27 deletions

View File

@@ -22,8 +22,7 @@ def train_cli(
     output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store trained pipeline in"),
     code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
     verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
-    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
-    use_rehearse: bool = Opt(False, "--use_rehearse", "-r", help="Perform 'rehearsal updates' on a pre-trained model")
+    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU")
     # fmt: on
 ):
     """
@@ -43,13 +42,7 @@ def train_cli(
     util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
     overrides = parse_config_overrides(ctx.args)
     import_code(code_path)
-    train(
-        config_path,
-        output_path,
-        use_gpu=use_gpu,
-        overrides=overrides,
-        use_rehearse=use_rehearse,
-    )
+    train(config_path, output_path, use_gpu=use_gpu, overrides=overrides)


 def train(
@@ -88,5 +81,4 @@ def train(
         use_gpu=use_gpu,
         stdout=sys.stdout,
         stderr=sys.stderr,
-        use_rehearse=use_rehearse,
     )
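With this change the CLI no longer exposes a rehearsal flag; the `train` helper in this file is called with only the config path, output path, GPU ID and overrides. A minimal sketch of the programmatic equivalent, assuming a `config.cfg` and an `./output` directory exist (both paths are placeholders):

from pathlib import Path

from spacy.cli.train import train

# Rehearsal is now configured via the [training] rehearse_components setting in
# config.cfg rather than a command-line flag (paths below are hypothetical).
train(Path("config.cfg"), Path("./output"), use_gpu=-1, overrides={})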

View File

@@ -1183,6 +1183,7 @@ class Language:
         losses: Optional[Dict[str, float]] = None,
         component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
         exclude: Iterable[str] = SimpleFrozenList(),
+        rehearse_components: List[str] = [],
     ) -> Dict[str, float]:
         """Make a "rehearsal" update to the models in the pipeline, to prevent
         forgetting. Rehearsal updates run an initial copy of the model over some
@@ -1195,6 +1196,7 @@ class Language:
         component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
             components, keyed by component name.
         exclude (Iterable[str]): Names of components that shouldn't be updated.
+        rehearse_components (List[str]): Names of components that should be rehearsed
         RETURNS (dict): Results from the update.

         EXAMPLE:
@@ -1216,7 +1218,11 @@ class Language:
             component_cfg = {}
         for name, proc in pipes:
-            if name in exclude or not hasattr(proc, "rehearse"):
+            if (
+                name in exclude
+                or not hasattr(proc, "rehearse")
+                or name not in rehearse_components
+            ):
                 continue
             proc.rehearse(  # type: ignore[attr-defined]
                 examples, sgd=None, losses=losses, **component_cfg.get(name, {})
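A rough sketch of how the extended `rehearse` signature would be used, assuming a pretrained pipeline (the "en_core_web_sm" model and "ner" component name are illustrative, not taken from the diff):

import spacy
from spacy.training import Example

# Hypothetical setup: a pretrained pipeline and a small batch of raw examples.
nlp = spacy.load("en_core_web_sm")
examples = [Example.from_dict(nlp.make_doc("Apple is looking at a U.K. startup"), {})]

optimizer = nlp.resume_training()
losses = {}
# Only components named in rehearse_components are rehearsed; the default
# empty list makes the call a no-op.
nlp.rehearse(examples, sgd=optimizer, losses=losses, rehearse_components=["ner"])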

View File

@@ -356,6 +356,7 @@ class ConfigSchemaTraining(BaseModel):
     logger: Logger = Field(..., title="The logger to track training progress")
     frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
     annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
+    rehearse_components: List[str] = Field(..., title="Pipeline components that should be rehearsed during training")
     before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
     before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
     # fmt: on
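The new schema field means `rehearse_components` lives in the `[training]` block of the config and can be supplied as an override like any other training setting. A sketch, assuming a `config.cfg` on disk (the path and component names are examples):

from spacy.util import load_config

# Override the new training setting when loading the config.
config = load_config(
    "config.cfg",
    overrides={"training.rehearse_components": ["ner", "textcat"]},
)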

View File

@@ -60,6 +60,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
     optimizer = T["optimizer"]
     # Components that shouldn't be updated during training
     frozen_components = T["frozen_components"]
+    # Components that should be rehearsed during training
+    rehearse_components = T["rehearse_components"]
     # Sourced components that require resume_training
     resume_components = [p for p in sourced if p not in frozen_components]
     logger.info(f"Pipeline: {nlp.pipe_names}")
@@ -67,6 +69,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
         with nlp.select_pipes(enable=resume_components):
             logger.info(f"Resuming training for: {resume_components}")
             nlp.resume_training(sgd=optimizer)
+    if rehearse_components:
+        logger.info(f"Rehearsing components: {rehearse_components}")
     # Make sure that listeners are defined before initializing further
     nlp._link_components()
     with nlp.select_pipes(disable=[*frozen_components, *resume_components]):

View File

@@ -26,7 +26,6 @@ def train(
     output_path: Optional[Path] = None,
     *,
     use_gpu: int = -1,
-    use_rehearse: bool = False,
     stdout: IO = sys.stdout,
     stderr: IO = sys.stderr,
 ) -> Tuple["Language", Optional[Path]]:
@@ -36,7 +35,6 @@ def train(
     output_path (Optional[Path]): Optional output path to save trained model to.
     use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
         before calling this function.
-    use_rehearse (bool): Use nlp.rehearse after nlp.update
     stdout (file): A file-like object to write output messages. To disable
         printing, set to io.StringIO.
     stderr (file): A second file-like object to write output messages. To disable
@@ -56,10 +54,7 @@ def train(
     T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
     dot_names = [T["train_corpus"], T["dev_corpus"]]
     train_corpus, dev_corpus = resolve_dot_names(config, dot_names)
-    if use_rehearse:
-        optimizer = nlp.resume_training()
-    else:
-        optimizer = T["optimizer"]
+    optimizer = T["optimizer"]
     score_weights = T["score_weights"]
     batcher = T["batcher"]
     train_logger = T["logger"]
@@ -82,6 +77,8 @@ def train(
     frozen_components = T["frozen_components"]
     # Components that should set annotations on update
     annotating_components = T["annotating_components"]
+    # Components that should be rehearsed after update
+    rehearse_components = T["rehearse_components"]
     # Create iterator, which yields out info after each optimization step.
     training_step_iterator = train_while_improving(
         nlp,
@@ -93,9 +90,9 @@ def train(
         patience=T["patience"],
         max_steps=T["max_steps"],
         eval_frequency=T["eval_frequency"],
-        use_rehearse=use_rehearse,
         exclude=frozen_components,
         annotating_components=annotating_components,
+        rehearse_components=rehearse_components,
         before_update=before_update,
     )
     clean_output_dir(output_path)
@@ -156,9 +153,9 @@ def train_while_improving(
     accumulate_gradient: int,
     patience: int,
     max_steps: int,
-    use_rehearse: bool = False,
     exclude: List[str],
     annotating_components: List[str],
+    rehearse_components: List[str],
     before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
 ):
     """Train until an evaluation stops improving. Works as a generator,
@@ -217,17 +214,17 @@ def train_while_improving(
                 subbatch,
                 drop=dropout,
                 losses=losses,
-                sgd=False,  # type: ignore[arg-type]
+                sgd=None,  # type: ignore[arg-type]
                 exclude=exclude,
                 annotates=annotating_components,
             )
-            if use_rehearse:
-                nlp.rehearse(
-                    subbatch,
-                    losses=losses,
-                    sgd=False,  # type: ignore[arg-type]
-                    exclude=exclude,
-                )
+            nlp.rehearse(
+                subbatch,
+                losses=losses,
+                sgd=None,  # type: ignore[arg-type]
+                exclude=exclude,
+                rehearse_components=rehearse_components,
+            )
             # TODO: refactor this so we don't have to run it separately in here
             for name, proc in nlp.pipeline:
                 if (
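Since the loop now calls `nlp.rehearse` on every subbatch, the opt-in gating lives entirely in the check added to `Language.rehearse` above. A standalone sketch restating that predicate (not spaCy code, just the logic for clarity):

from typing import Any, Iterable, List


def should_rehearse(
    name: str, proc: Any, exclude: Iterable[str], rehearse_components: List[str]
) -> bool:
    # Mirrors the condition in Language.rehearse: skip excluded components,
    # components without a rehearse method, and anything not opted in via config.
    return (
        name not in exclude
        and hasattr(proc, "rehearse")
        and name in rehearse_components
    )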