from itertools import chain
from pathlib import Path
from typing import Iterable, List, Optional, Union, cast

import srsly
import tqdm
from wasabi import msg

from ..tokens import Doc, DocBin
from ..util import ensure_path, load_model
from ..vocab import Vocab
from ._util import Arg, Opt, app, import_code, setup_gpu, walk_directory

# Help texts and messages shown by the `apply` command.
path_help = "\n".join(
    (
        "Location of the documents to predict on.",
        "Can be a single file in .spacy format or a .jsonl file.",
        "Files with other extensions are treated as single plain text documents.",
        "If a directory is provided it is traversed recursively to grab",
        "all files to be processed.",
        "The files can be a mixture of .spacy, .jsonl and text files.",
        "If .jsonl is provided the specified field is going",
        'to be grabbed ("text" by default).',
    )
)

out_help = "Path to save the resulting .spacy file"
code_help = (
    "Path to Python file with additional code (registered functions) to be imported"
)
gold_help = "Use gold preprocessing provided in the .spacy files"
# Shown when the output file exists and overwriting was not requested.
force_msg = " ".join(
    (
        "The provided output file already exists.",
        "To force overwriting the output file, set the --force or -F flag.",
    )
)


# Type of the input streams built in this module and handed to nlp.pipe:
# either an iterable of raw text strings or an iterable of Doc objects.
DocOrStrStream = Union[Iterable[str], Iterable[Doc]]


def _stream_docbin(path: Path, vocab: Vocab) -> Iterable[Doc]:
    """
    Load the DocBin stored at path and yield its Doc objects,
    created with the given vocab.
    """
    yield from DocBin().from_disk(path).get_docs(vocab)


def _stream_jsonl(path: Path, field: str) -> Iterable[str]:
    """
    Stream the value stored under `field` from every record of a JSONL
    file. A record without that key is reported through msg.fail with
    exit code 1.
    """
    for record in srsly.read_jsonl(path):
        if field in record:
            yield record[field]
        else:
            msg.fail(f"{path} does not contain the required '{field}' field.", exits=1)


def _stream_texts(paths: Iterable[Path]) -> Iterable[str]:
    """
    Yield the full contents of each text file in paths, one string per file.

    paths (Iterable[Path]): Files to read, processed in iteration order.
    YIELDS (str): The contents of one file, decoded as UTF-8.
    """
    for path in paths:
        # Decode as UTF-8 explicitly instead of relying on the platform's
        # locale encoding, so the same file gives the same text everywhere.
        with open(path, "r", encoding="utf8") as fin:
            text = fin.read()
        # Yield only after the file is closed so no handle stays open while
        # the consumer works on the text.
        yield text


@app.command("apply")
def apply_cli(
    # fmt: off
    model: str = Arg(..., help="Model name or path"),
    data_path: Path = Arg(..., help=path_help, exists=True),
    output_file: Path = Arg(..., help=out_help, dir_okay=False),
    code_path: Optional[Path] = Opt(None, "--code", "-c", help=code_help),
    text_key: str = Opt("text", "--text-key", "-tk", help="Key containing text string for JSONL"),
    force_overwrite: bool = Opt(False, "--force", "-F", help="Force overwriting the output file"),
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."),
    batch_size: int = Opt(1, "--batch-size", "-b", help="Batch size."),
    n_process: int = Opt(1, "--n-process", "-n", help="number of processors to use.")
):
    """
    Apply a trained pipeline to documents to get predictions.
    Expects a loadable spaCy pipeline and path to the data, which
    can be a directory or a file.
    The data files can be provided in multiple formats:
        1. .spacy files
        2. .jsonl files with a specified "field" to read the text from.
        3. Files with any other extension are assumed to be containing
           a single document.
    DOCS: https://spacy.io/api/cli#apply
    """
    data_path = ensure_path(data_path)
    output_file = ensure_path(output_file)
    code_path = ensure_path(code_path)
    # The predictions end up in a ".spacy" file when no suffix is given, so
    # resolve the final file name first: the overwrite guard below has to
    # look at the file that will actually be written.
    if output_file.suffix == "":
        output_file = output_file.with_suffix(".spacy")
    if output_file.exists() and not force_overwrite:
        msg.fail(force_msg, exits=1)
    if not data_path.exists():
        msg.fail(f"Couldn't find data path: {data_path}", exits=1)
    # Import user code before the pipeline is loaded inside apply().
    import_code(code_path)
    setup_gpu(use_gpu)
    apply(data_path, output_file, model, text_key, batch_size, n_process)


def apply(
    data_path: Path,
    output_file: Path,
    model: str,
    json_field: str,
    batch_size: int,
    n_process: int,
):
    """
    Run a pipeline over every file found at data_path and save all
    resulting documents to a single DocBin on disk.

    data_path (Path): File or directory passed to walk_directory.
    output_file (Path): Destination file. ".spacy" is appended if the
        path has no suffix.
    model (str): Name or path passed to load_model.
    json_field (str): Key to read the text from in .jsonl records.
    batch_size (int): Batch size passed to nlp.pipe.
    n_process (int): Number of processes passed to nlp.pipe.
    """
    # Resolve the final file name once, up front, so the empty-input branch
    # and the regular branch write to the same location.
    if output_file.suffix == "":
        output_file = output_file.with_suffix(".spacy")
    docbin = DocBin(store_user_data=True)
    paths = walk_directory(data_path)
    if not paths:
        # Still write an (empty) output file so the command always produces
        # its output, then stop before loading the model.
        docbin.to_disk(output_file)
        msg.warn(
            "Did not find data to process,"
            f" {data_path} seems to be an empty directory."
        )
        return
    nlp = load_model(model)
    msg.good(f"Loaded model {model}")
    vocab = nlp.vocab
    streams: List[DocOrStrStream] = []
    text_files = []
    # One lazy stream per .spacy / .jsonl file; every other file is treated
    # as plain text and collected into a single trailing stream.
    for path in paths:
        if path.suffix == ".spacy":
            streams.append(_stream_docbin(path, vocab))
        elif path.suffix == ".jsonl":
            streams.append(_stream_jsonl(path, json_field))
        else:
            text_files.append(path)
    if text_files:
        streams.append(_stream_texts(text_files))
    # Chain the per-file streams into one input for nlp.pipe.
    datagen = cast(DocOrStrStream, chain(*streams))
    for doc in tqdm.tqdm(
        nlp.pipe(datagen, batch_size=batch_size, n_process=n_process), disable=None
    ):
        docbin.add(doc)
    docbin.to_disk(output_file)