from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
from typing import Optional
import random
import itertools
from functools import partial

from ..util import registry
from .example import Example
from .iob_utils import split_bilu_label, _doc_to_biluo_tags_with_partial

if TYPE_CHECKING:
    from ..language import Language  # noqa: F401


@registry.augmenters("spacy.combined_augmenter.v1")
def create_combined_augmenter(
    lower_level: float,
    orth_level: float,
    orth_variants: Optional[Dict[str, List[Dict]]],
    whitespace_level: float,
    whitespace_per_token: float,
    whitespace_variants: Optional[List[str]],
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that combines lowercasing,
    orth-variant replacement and whitespace-token insertion.
    The callback can be added to a corpus or other data iterator during training.

    lower_level (float): Probability (0-1) that an example is lowercased.
    orth_level (float): Probability (0-1) that orth-variant replacement is
        applied to an example.
    orth_variants (Optional[Dict[str, List[Dict]]]): A dictionary containing the
        "single" and "paired" orth variants. Typically loaded from a JSON file.
        When empty or None, orth-variant replacement is skipped.
    whitespace_level (float): Probability (0-1) that whitespace tokens are
        inserted into an example.
    whitespace_per_token (float): Number of whitespace tokens to insert,
        expressed as a fraction of the doc length (truncated to an int).
    whitespace_variants (Optional[List[str]]): The whitespace token texts to
        choose from. When empty or None, whitespace insertion is skipped.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    # Bind every setting as a keyword so the result matches the
    # (nlp, example) augmenter signature.
    settings = {
        "lower_level": lower_level,
        "orth_level": orth_level,
        "orth_variants": orth_variants,
        "whitespace_level": whitespace_level,
        "whitespace_per_token": whitespace_per_token,
        "whitespace_variants": whitespace_variants,
    }
    return partial(combined_augmenter, **settings)


def combined_augmenter(
    nlp: "Language",
    example: Example,
    *,
    lower_level: float = 0.0,
    orth_level: float = 0.0,
    orth_variants: Optional[Dict[str, List[Dict]]] = None,
    whitespace_level: float = 0.0,
    whitespace_per_token: float = 0.0,
    whitespace_variants: Optional[List[str]] = None,
) -> Iterator[Example]:
    """Apply up to three augmentations in sequence and yield one example.

    Each step is drawn independently and in this order:
    1. lowercasing, with probability `lower_level`;
    2. orth-variant replacement, with probability `orth_level`, only when
       `orth_variants` is non-empty;
    3. insertion of `int(len(doc) * whitespace_per_token)` whitespace tokens,
       with probability `whitespace_level`, only when `whitespace_variants`
       is non-empty.
    """
    augmented = example
    if random.random() < lower_level:
        augmented = make_lowercase_variant(nlp, augmented)

    if orth_variants and random.random() < orth_level:
        text = augmented.text
        annotations = augmented.to_dict()
        annotations["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
            augmented.reference
        )
        new_text, new_tokens = make_orth_variants(
            nlp,
            text,
            annotations["token_annotation"],
            orth_variants,
            lower=False,
        )
        annotations["token_annotation"] = new_tokens
        augmented = augmented.from_dict(nlp.make_doc(new_text), annotations)

    if whitespace_variants and random.random() < whitespace_level:
        # The insertion count is fixed from the length before any insertion.
        n_insertions = int(len(augmented.reference) * whitespace_per_token)
        for _ in range(n_insertions):
            space = random.choice(whitespace_variants)
            # Draw the position from the current (already extended) doc.
            where = random.randrange(0, len(augmented.reference))
            augmented = make_whitespace_variant(nlp, augmented, space, where)

    yield augmented


@registry.augmenters("spacy.orth_variants.v1")
def create_orth_variants_augmenter(
    level: float, lower: float, orth_variants: Dict[str, List[Dict]]
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that uses orth-variant replacement.
    The callback can be added to a corpus or other data iterator during training.

    level (float): Probability (0-1) that an example is augmented.
    lower (float): Probability (0-1) that an augmented example is also
        lowercased.
    orth_variants (Dict[str, List[Dict]]): A dictionary containing the "single"
        and "paired" orth variants. Typically loaded from a JSON file.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    augmenter = partial(
        orth_variants_augmenter,
        orth_variants=orth_variants,
        level=level,
        lower=lower,
    )
    return augmenter


@registry.augmenters("spacy.lower_case.v1")
def create_lower_casing_augmenter(
    level: float,
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that converts documents to lowercase.
    The callback can be added to a corpus or other data iterator during training.

    level (float): Probability (0-1) that an example is lowercased.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    augmenter = partial(lower_casing_augmenter, level=level)
    return augmenter


def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
    """Identity augmenter: yield `example` unchanged. `nlp` is not used."""
    yield from (example,)


def lower_casing_augmenter(
    nlp: "Language", example: Example, *, level: float
) -> Iterator[Example]:
    """Yield a lowercased copy of `example` with probability `level`;
    otherwise yield `example` itself."""
    keep_original = random.random() >= level
    yield example if keep_original else make_lowercase_variant(nlp, example)


def make_lowercase_variant(nlp: "Language", example: Example):
    """Return a new Example whose text and reference token texts are lowercased.

    Entity annotation is re-encoded with `_doc_to_biluo_tags_with_partial`
    before rebuilding, so the remaining annotation carries over to the
    lowercased doc.
    """
    annotations = example.to_dict()
    reference = example.reference
    annotations["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
        reference
    )
    lowered_doc = nlp.make_doc(example.text.lower())
    annotations["token_annotation"]["ORTH"] = [token.lower_ for token in reference]
    return example.from_dict(lowered_doc, annotations)


def orth_variants_augmenter(
    nlp: "Language",
    example: Example,
    orth_variants: Dict[str, List[Dict]],
    *,
    level: float = 0.0,
    lower: float = 0.0,
) -> Iterator[Example]:
    """Yield `example` with orth variants applied, with probability `level`.

    When the example is augmented, it is additionally lowercased with
    probability `lower`. Otherwise the original example is yielded as-is.
    """
    if random.random() >= level:
        yield example
        return

    text = example.text
    annotations = example.to_dict()
    annotations["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
        example.reference
    )
    token_annotation = annotations["token_annotation"]
    lowercase = text is not None and random.random() < lower
    new_text, new_tokens = make_orth_variants(
        nlp, text, token_annotation, orth_variants, lower=lowercase
    )
    annotations["token_annotation"] = new_tokens
    yield example.from_dict(nlp.make_doc(new_text), annotations)


def make_orth_variants(
    nlp: "Language",
    raw: str,
    token_dict: Dict[str, List[str]],
    orth_variants: Dict[str, List[Dict[str, List[str]]]],
    *,
    lower: bool = False,
) -> Tuple[str, Dict[str, List[str]]]:
    """Replace token texts with randomly chosen orthographic variants.

    `orth_variants["single"]` is a list of groups, each with "tags" and a flat
    list of interchangeable "variants". `orth_variants["paired"]` is a list of
    groups whose "variants" are pairs (e.g. opening/closing quotes). One
    variant (or pair) is drawn per group and used for every matching token in
    the doc. A token matches a group when its TAG is in the group's "tags" and
    its text is one of the group's variants.

    `token_dict` is updated in place ("ORTH" is replaced) and also returned.
    When no words are present it is returned untouched. When there are no
    tags, only lowercasing (if `lower`) is applied. Otherwise the raw text is
    rebuilt from the new words and the "SPACY" flags. `nlp` is not used.

    RETURNS (Tuple[str, Dict[str, List[str]]]): The new raw text and token dict.
    """
    orths = token_dict.get("ORTH", [])
    pos_tags = token_dict.get("TAG", [])
    if not orths:
        return raw, token_dict
    if lower:
        raw = raw.lower()
        orths = [orth.lower() for orth in orths]
    if not pos_tags:
        token_dict["ORTH"] = orths
        return raw, token_dict

    # Single variants: each token may be rewritten by several groups in turn.
    single_groups = orth_variants.get("single", [])
    single_picks = [random.choice(group["variants"]) for group in single_groups]
    for i in range(len(orths)):
        for group, pick in zip(single_groups, single_picks):
            if pos_tags[i] in group["tags"] and orths[i] in group["variants"]:
                orths[i] = pick

    # Paired variants: pick the left or right member of the chosen pair.
    pair_groups = orth_variants.get("paired", [])
    pair_picks = [random.choice(group["variants"]) for group in pair_groups]
    for i in range(len(orths)):
        for group, pick in zip(pair_groups, pair_picks):
            if pos_tags[i] not in group["tags"]:
                continue
            if orths[i] not in itertools.chain.from_iterable(group["variants"]):
                continue
            # Fallback: a random side. The draw always happens, even when
            # overridden below.
            side = random.choice([0, 1])
            if len(group["tags"]) == 2:
                # Preferred: paired tags (like `` / '') determine the side.
                side = group["tags"].index(pos_tags[i])
            else:
                # Otherwise use the token's position inside its pair. The last
                # matching pair wins, so the order of variants matters.
                for pair in group["variants"]:
                    if orths[i] in pair:
                        side = pair.index(orths[i])
            orths[i] = pick[side]

    token_dict["ORTH"] = orths
    raw = construct_modified_raw_text(token_dict)
    return raw, token_dict


def make_whitespace_variant(
    nlp: "Language",
    example: Example,
    whitespace: str,
    position: int,
) -> Example:
    """Insert the whitespace token at the specified token offset in the doc.
    This is primarily intended for v2-compatible training data that doesn't
    include links or spans. If the document is empty, has no words, or
    includes links, spans, or partial dependency annotation, it is returned
    without modifications.

    The augmentation follows the basics of the v2 space attachment policy, but
    without a distinction between "real" and other tokens, so space tokens
    may be attached to space tokens:
    - at position 0, the space token's head is the following token
    - otherwise the space token's head is the preceding token

    The augmenter does not attempt to consolidate adjacent whitespace in the
    same way that the tokenizer would.

    The following annotation is used for the space token:
    TAG: "_SP"
    MORPH: ""
    POS: "SPACE"
    LEMMA: ORTH
    DEP: "dep"
    SENT_START: False
    ENT: "I-<type>" when inserted between a B-/I-<type> token and an
         I-/L-<type> token of the same type, otherwise "O"

    The annotation for each attribute is only set for the space token if there
    is already at least partial annotation for that attribute in the original
    example; otherwise the attribute is dropped from the rebuilt example.

    position (int): Token offset to insert at, in the range [0, len(doc)].
    RETURNS (Example): Example with one additional space token.
    """
    reference = example.reference
    example_dict = example.to_dict()
    example_dict["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
        reference
    )
    doc_dict = example_dict.get("doc_annotation", {})
    token_dict = example_dict.get("token_annotation", {})
    # Return unmodified if:
    # - doc is empty
    # - words are not defined
    # - links are defined (only character-based offsets, which is more a quirk
    #   of Example.to_dict than a technical constraint)
    # - spans are defined
    # - there are partial dependencies
    if (
        len(reference) == 0
        or "ORTH" not in token_dict
        or len(doc_dict.get("links", [])) > 0
        or len(reference.spans) > 0
        or (
            reference.has_annotation("DEP")
            and not reference.has_annotation("DEP", require_complete=True)
        )
    ):
        return example
    length = len(token_dict["ORTH"])
    assert 0 <= position <= length

    if reference.has_annotation("ENT_TYPE"):
        entities = doc_dict["entities"]
        entity = "O"
        # Any interior position (a token on both sides) can fall inside an
        # entity. That includes position 1, between tokens 0 and 1. Excluding
        # it would produce e.g. "B-X O L-X", an invalid BILUO sequence.
        if 0 < position < length:
            ent_prev = entities[position - 1]
            ent_next = entities[position]
            if "-" in ent_prev and "-" in ent_next:
                ent_iob_prev, ent_type_prev = split_bilu_label(ent_prev)
                ent_iob_next, ent_type_next = split_bilu_label(ent_next)
                if (
                    ent_iob_prev in ("B", "I")
                    and ent_iob_next in ("I", "L")
                    and ent_type_prev == ent_type_next
                ):
                    entity = f"I-{ent_type_prev}"
        entities.insert(position, entity)
    else:
        del doc_dict["entities"]

    token_dict["ORTH"].insert(position, whitespace)
    # The space token carries no trailing whitespace of its own.
    token_dict["SPACY"].insert(position, False)

    # Per-token attributes: extend them if annotated, otherwise drop them.
    for attr, value in (
        ("TAG", "_SP"),
        ("LEMMA", whitespace),
        ("POS", "SPACE"),
        ("MORPH", ""),
    ):
        if reference.has_annotation(attr):
            token_dict[attr].insert(position, value)
        else:
            del token_dict[attr]

    if reference.has_annotation("DEP", require_complete=True):
        # At position 0 the head is initially 0. The shift below turns it into
        # 1, the following token. Elsewhere the head is position - 1, the
        # preceding token, which is not shifted.
        token_dict["HEAD"].insert(position, 0 if position == 0 else position - 1)
        # Heads pointing at or after the insertion point move right by one.
        token_dict["HEAD"] = [
            head + 1 if head >= position else head for head in token_dict["HEAD"]
        ]
        token_dict["DEP"].insert(position, "dep")
    else:
        del token_dict["HEAD"]
        del token_dict["DEP"]

    if reference.has_annotation("SENT_START"):
        token_dict["SENT_START"].insert(position, False)
    else:
        del token_dict["SENT_START"]

    raw = construct_modified_raw_text(token_dict)
    return Example.from_dict(nlp.make_doc(raw), example_dict)


def construct_modified_raw_text(token_dict):
    """Construct modified raw text from words and spaces.

    token_dict (dict): Must contain parallel lists "ORTH" (token texts) and
        "SPACY" (truthy if the token is followed by a single space). Extra
        entries in the longer list are ignored, as with `zip`.
    RETURNS (str): The concatenated text.
    """
    # str.join avoids the quadratic cost of repeated string concatenation.
    return "".join(
        (orth + " ") if spacy else orth
        for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"])
    )