mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			116 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			116 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import Optional
 | 
						|
from pathlib import Path
 | 
						|
from wasabi import msg
 | 
						|
import typer
 | 
						|
import re
 | 
						|
 | 
						|
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
 | 
						|
from ._util import import_code, setup_gpu
 | 
						|
from ..training.pretrain import pretrain
 | 
						|
from ..util import load_config
 | 
						|
 | 
						|
 | 
						|
@app.command(
 | 
						|
    "pretrain",
 | 
						|
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
 | 
						|
)
 | 
						|
def pretrain_cli(
 | 
						|
    # fmt: off
 | 
						|
    ctx: typer.Context,  # This is only used to read additional arguments
 | 
						|
    config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False, allow_dash=True),
 | 
						|
    output_dir: Path = Arg(..., help="Directory to write weights to on each epoch"),
 | 
						|
    code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
 | 
						|
    resume_path: Optional[Path] = Opt(None, "--resume-path", "-r", help="Path to pretrained weights from which to resume pretraining"),
 | 
						|
    epoch_resume: Optional[int] = Opt(None, "--epoch-resume", "-er", help="The epoch to resume counting from when using --resume-path. Prevents unintended overwriting of existing weight files."),
 | 
						|
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
 | 
						|
    # fmt: on
 | 
						|
):
 | 
						|
    """
 | 
						|
    Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
 | 
						|
    using an approximate language-modelling objective. Two objective types
 | 
						|
    are available, vector-based and character-based.
 | 
						|
 | 
						|
    In the vector-based objective, we load word vectors that have been trained
 | 
						|
    using a word2vec-style distributional similarity algorithm, and train a
 | 
						|
    component like a CNN, BiLSTM, etc to predict vectors which match the
 | 
						|
    pretrained ones. The weights are saved to a directory after each epoch. You
 | 
						|
    can then pass a path to one of these pretrained weights files to the
 | 
						|
    'spacy train' command.
 | 
						|
 | 
						|
    This technique may be especially helpful if you have little labelled data.
 | 
						|
    However, it's still quite experimental, so your mileage may vary.
 | 
						|
 | 
						|
    To load the weights back in during 'spacy train', you need to ensure
 | 
						|
    all settings are the same between pretraining and training. Ideally,
 | 
						|
    this is done by using the same config file for both commands.
 | 
						|
 | 
						|
    DOCS: https://spacy.io/api/cli#pretrain
 | 
						|
    """
 | 
						|
    config_overrides = parse_config_overrides(ctx.args)
 | 
						|
    import_code(code_path)
 | 
						|
    verify_cli_args(config_path, output_dir, resume_path, epoch_resume)
 | 
						|
    setup_gpu(use_gpu)
 | 
						|
    msg.info(f"Loading config from: {config_path}")
 | 
						|
 | 
						|
    with show_validation_error(config_path):
 | 
						|
        raw_config = load_config(
 | 
						|
            config_path, overrides=config_overrides, interpolate=False
 | 
						|
        )
 | 
						|
    config = raw_config.interpolate()
 | 
						|
    if not config.get("pretraining"):
 | 
						|
        # TODO: What's the solution here? How do we handle optional blocks?
 | 
						|
        msg.fail("The [pretraining] block in your config is empty", exits=1)
 | 
						|
    if not output_dir.exists():
 | 
						|
        output_dir.mkdir(parents=True)
 | 
						|
        msg.good(f"Created output directory: {output_dir}")
 | 
						|
    # Save non-interpolated config
 | 
						|
    raw_config.to_disk(output_dir / "config.cfg")
 | 
						|
    msg.good("Saved config file in the output directory")
 | 
						|
 | 
						|
    pretrain(
 | 
						|
        config,
 | 
						|
        output_dir,
 | 
						|
        resume_path=resume_path,
 | 
						|
        epoch_resume=epoch_resume,
 | 
						|
        use_gpu=use_gpu,
 | 
						|
        silent=False,
 | 
						|
    )
 | 
						|
    msg.good("Successfully finished pretrain")
 | 
						|
 | 
						|
 | 
						|
def verify_cli_args(config_path, output_dir, resume_path, epoch_resume):
 | 
						|
    if not config_path or (str(config_path) != "-" and not config_path.exists()):
 | 
						|
        msg.fail("Config file not found", config_path, exits=1)
 | 
						|
    if output_dir.exists() and [p for p in output_dir.iterdir()]:
 | 
						|
        if resume_path:
 | 
						|
            msg.warn(
 | 
						|
                "Output directory is not empty.",
 | 
						|
                "If you're resuming a run in this directory, the old weights "
 | 
						|
                "for the consecutive epochs will be overwritten with the new ones.",
 | 
						|
            )
 | 
						|
        else:
 | 
						|
            msg.warn(
 | 
						|
                "Output directory is not empty. ",
 | 
						|
                "It is better to use an empty directory or refer to a new output path, "
 | 
						|
                "then the new directory will be created for you.",
 | 
						|
            )
 | 
						|
    if resume_path is not None:
 | 
						|
        if resume_path.is_dir():
 | 
						|
            # This is necessary because Windows gives a Permission Denied when we
 | 
						|
            # try to open the directory later, which is confusing. See #7878
 | 
						|
            msg.fail(
 | 
						|
                "--resume-path should be a weights file, but {resume_path} is a directory.",
 | 
						|
                exits=True,
 | 
						|
            )
 | 
						|
        model_name = re.search(r"model\d+\.bin", str(resume_path))
 | 
						|
        if not model_name and not epoch_resume:
 | 
						|
            msg.fail(
 | 
						|
                "You have to use the --epoch-resume setting when using a renamed weight file for --resume-path",
 | 
						|
                exits=True,
 | 
						|
            )
 | 
						|
        elif not model_name and epoch_resume < 0:
 | 
						|
            msg.fail(
 | 
						|
                f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
 | 
						|
                exits=True,
 | 
						|
            )
 |