from typing import Optional
from pathlib import Path
from wasabi import msg
import typer
import re

from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
from ._util import import_code, setup_gpu
from ..training.pretrain import pretrain
from ..util import load_config


@app.command(
    "pretrain",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def pretrain_cli(
    # fmt: off
    ctx: typer.Context,  # This is only used to read additional arguments
    config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False, allow_dash=True),
    output_dir: Path = Arg(..., help="Directory to write weights to on each epoch"),
    code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
    resume_path: Optional[Path] = Opt(None, "--resume-path", "-r", help="Path to pretrained weights from which to resume pretraining"),
    epoch_resume: Optional[int] = Opt(None, "--epoch-resume", "-er", help="The epoch to resume counting from when using --resume-path. Prevents unintended overwriting of existing weight files."),
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
    # fmt: on
):
    """
    Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
    using an approximate language-modelling objective. Two objective types
    are available, vector-based and character-based.

    In the vector-based objective, we load word vectors that have been trained
    using a word2vec-style distributional similarity algorithm, and train a
    component like a CNN, BiLSTM, etc to predict vectors which match the
    pretrained ones. The weights are saved to a directory after each epoch. You
    can then pass a path to one of these pretrained weights files to the
    'spacy train' command.

    This technique may be especially helpful if you have little labelled data.
    However, it's still quite experimental, so your mileage may vary.

    To load the weights back in during 'spacy train', you need to ensure
    all settings are the same between pretraining and training. Ideally,
    this is done by using the same config file for both commands.

    DOCS: https://spacy.io/api/cli#pretrain
    """
    config_overrides = parse_config_overrides(ctx.args)
    import_code(code_path)
    verify_cli_args(config_path, output_dir, resume_path, epoch_resume)
    setup_gpu(use_gpu)
    msg.info(f"Loading config from: {config_path}")

    with show_validation_error(config_path):
        raw_config = load_config(
            config_path, overrides=config_overrides, interpolate=False
        )
    config = raw_config.interpolate()
    if not config.get("pretraining"):
        # TODO: What's the solution here? How do we handle optional blocks?
        msg.fail("The [pretraining] block in your config is empty", exits=1)
    if not output_dir.exists():
        output_dir.mkdir()
        msg.good(f"Created output directory: {output_dir}")
    # Save non-interpolated config
    raw_config.to_disk(output_dir / "config.cfg")
    msg.good("Saved config file in the output directory")

    pretrain(
        config,
        output_dir,
        resume_path=resume_path,
        epoch_resume=epoch_resume,
        use_gpu=use_gpu,
        silent=False,
    )
    msg.good("Successfully finished pretrain")


def verify_cli_args(config_path, output_dir, resume_path, epoch_resume):
    if not config_path or (str(config_path) != "-" and not config_path.exists()):
        msg.fail("Config file not found", config_path, exits=1)
    if output_dir.exists() and [p for p in output_dir.iterdir()]:
        if resume_path:
            msg.warn(
                "Output directory is not empty.",
                "If you're resuming a run in this directory, the old weights "
                "for the consecutive epochs will be overwritten with the new ones.",
            )
        else:
            msg.warn(
                "Output directory is not empty. ",
                "It is better to use an empty directory or refer to a new output path, "
                "then the new directory will be created for you.",
            )
    if resume_path is not None:
        model_name = re.search(r"model\d+\.bin", str(resume_path))
        if not model_name and not epoch_resume:
            msg.fail(
                "You have to use the --epoch-resume setting when using a renamed weight file for --resume-path",
                exits=True,
            )
        elif not model_name and epoch_resume < 0:
            msg.fail(
                f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
                exits=True,
            )