# cython: infer_types=True, profile=True, binding=True
from collections import defaultdict
from typing import Callable, Optional

from thinc.api import Config, Model

from ..language import Language
from ..scorer import Scorer
from ..training import remove_bilu_prefix
from ..util import registry
from ._parser_internals import nonproj
from ._parser_internals.arc_eager import ArcEager
from ._parser_internals.nonproj import DELIMITER
from ._parser_internals.transition_system import TransitionSystem
from .functions import merge_subtokens
from .transition_parser import Parser

default_model_config = """
[model]
@architectures = "spacy.TransitionBasedParser.v3"
state_type = "parser"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""
DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"]


@Language.factory(
    "parser",
    assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
    default_config={
        "moves": None,
        "update_with_oracle_cut_size": 100,
        "learn_tokens": False,
        "min_action_freq": 30,
        "model": DEFAULT_PARSER_MODEL,
        "scorer": {"@scorers": "spacy.parser_scorer.v1"},
    },
    default_score_weights={
        "dep_uas": 0.5,
        "dep_las": 0.5,
        "dep_las_per_type": None,
        "sents_p": None,
        "sents_r": None,
        "sents_f": 0.0,
    },
)
def make_parser(
    nlp: Language,
    name: str,
    model: Model,
    moves: Optional[TransitionSystem],
    update_with_oracle_cut_size: int,
    learn_tokens: bool,
    min_action_freq: int,
    scorer: Optional[Callable],
):
    """Create a transition-based DependencyParser component. The dependency parser
    jointly learns sentence segmentation and labelled dependency parsing, and can
    optionally learn to merge tokens that had been over-segmented by the tokenizer.

    The parser uses a variant of the non-monotonic arc-eager transition-system
    described by Honnibal and Johnson (2014), with the addition of a "break"
    transition to perform the sentence segmentation. Nivre's pseudo-projective
    dependency transformation is used to allow the parser to predict
    non-projective parses.

    The parser is trained using an imitation learning objective. The parser follows
    the actions predicted by the current weights, and at each state, determines
    which actions are compatible with the optimal parse that could be reached
    from the current state. The weights such that the scores assigned to the
    set of optimal actions is increased, while scores assigned to other
    actions are decreased. Note that more than one action may be optimal for
    a given state.

    model (Model): The model for the transition-based parser. The model needs
        to have a specific substructure of named components --- see the
        spacy.ml.tb_framework.TransitionModel for details.
    moves (Optional[TransitionSystem]): This defines how the parse-state is created,
        updated and evaluated. If 'moves' is None, a new instance is
        created with `self.TransitionSystem()`. Defaults to `None`.
    update_with_oracle_cut_size (int): During training, cut long sequences into
        shorter segments by creating intermediate states based on the gold-standard
        history. The model is not very sensitive to this parameter, so you usually
        won't need to change it. 100 is a good default.
    learn_tokens (bool): Whether to learn to merge subtokens that are split
        relative to the gold standard. Experimental.
    min_action_freq (int): The minimum frequency of labelled actions to retain.
        Rarer labelled actions have their label backed-off to "dep". While this
        primarily affects the label accuracy, it can also affect the attachment
        structure, as the labels are used to represent the pseudo-projectivity
        transformation.
    scorer (Optional[Callable]): The scoring method.
    """
    return DependencyParser(
        nlp.vocab,
        model,
        name,
        moves=moves,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        multitasks=[],
        learn_tokens=learn_tokens,
        min_action_freq=min_action_freq,
        beam_width=1,
        beam_density=0.0,
        beam_update_prob=0.0,
        # At some point in the future we can try to implement support for
        # partial annotations, perhaps only in the beam objective.
        incorrect_spans_key=None,
        scorer=scorer,
    )


@Language.factory(
    "beam_parser",
    assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
    default_config={
        "beam_width": 8,
        "beam_density": 0.01,
        "beam_update_prob": 0.5,
        "moves": None,
        "update_with_oracle_cut_size": 100,
        "learn_tokens": False,
        "min_action_freq": 30,
        "model": DEFAULT_PARSER_MODEL,
        "scorer": {"@scorers": "spacy.parser_scorer.v1"},
    },
    default_score_weights={
        "dep_uas": 0.5,
        "dep_las": 0.5,
        "dep_las_per_type": None,
        "sents_p": None,
        "sents_r": None,
        "sents_f": 0.0,
    },
)
def make_beam_parser(
    nlp: Language,
    name: str,
    model: Model,
    moves: Optional[TransitionSystem],
    update_with_oracle_cut_size: int,
    learn_tokens: bool,
    min_action_freq: int,
    beam_width: int,
    beam_density: float,
    beam_update_prob: float,
    scorer: Optional[Callable],
):
    """Create a transition-based DependencyParser component that uses beam-search.
    The dependency parser jointly learns sentence segmentation and labelled
    dependency parsing, and can optionally learn to merge tokens that had been
    over-segmented by the tokenizer.

    The parser uses a variant of the non-monotonic arc-eager transition-system
    described by Honnibal and Johnson (2014), with the addition of a "break"
    transition to perform the sentence segmentation. Nivre's pseudo-projective
    dependency transformation is used to allow the parser to predict
    non-projective parses.

    The parser is trained using a global objective. That is, it learns to assign
    probabilities to whole parses.

    model (Model): The model for the transition-based parser. The model needs
        to have a specific substructure of named components --- see the
        spacy.ml.tb_framework.TransitionModel for details.
    moves (Optional[TransitionSystem]): This defines how the parse-state is created,
        updated and evaluated. If 'moves' is None, a new instance is
        created with `self.TransitionSystem()`. Defaults to `None`.
    update_with_oracle_cut_size (int): During training, cut long sequences into
        shorter segments by creating intermediate states based on the gold-standard
        history. The model is not very sensitive to this parameter, so you usually
        won't need to change it. 100 is a good default.
    beam_width (int): The number of candidate analyses to maintain.
    beam_density (float): The minimum ratio between the scores of the first and
        last candidates in the beam. This allows the parser to avoid exploring
        candidates that are too far behind. This is mostly intended to improve
        efficiency, but it can also improve accuracy as deeper search is not
        always better.
    beam_update_prob (float): The chance of making a beam update, instead of a
        greedy update. Greedy updates are an approximation for the beam updates,
        and are faster to compute.
    learn_tokens (bool): Whether to learn to merge subtokens that are split
        relative to the gold standard. Experimental.
    min_action_freq (int): The minimum frequency of labelled actions to retain.
        Rarer labelled actions have their label backed-off to "dep". While this
        primarily affects the label accuracy, it can also affect the attachment
        structure, as the labels are used to represent the pseudo-projectivity
        transformation.
    """
    return DependencyParser(
        nlp.vocab,
        model,
        name,
        moves=moves,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        beam_width=beam_width,
        beam_density=beam_density,
        beam_update_prob=beam_update_prob,
        multitasks=[],
        learn_tokens=learn_tokens,
        min_action_freq=min_action_freq,
        # At some point in the future we can try to implement support for
        # partial annotations, perhaps only in the beam objective.
        incorrect_spans_key=None,
        scorer=scorer,
    )


def parser_score(examples, **kwargs):
    """Score a batch of examples.

    examples (Iterable[Example]): The examples to score.
    RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans
        and Scorer.score_deps.

    DOCS: https://spacy.io/api/dependencyparser#score
    """

    def has_sents(doc):
        return doc.has_annotation("SENT_START")

    def dep_getter(token, attr):
        dep = getattr(token, attr)
        dep = token.vocab.strings.as_string(dep).lower()
        return dep

    results = {}
    results.update(
        Scorer.score_spans(examples, "sents", has_annotation=has_sents, **kwargs)
    )
    kwargs.setdefault("getter", dep_getter)
    kwargs.setdefault("ignore_labels", ("p", "punct"))
    results.update(Scorer.score_deps(examples, "dep", **kwargs))
    del results["sents_per_type"]
    return results


@registry.scorers("spacy.parser_scorer.v1")
def make_parser_scorer():
    return parser_score


class DependencyParser(Parser):
    """Pipeline component for dependency parsing.

    DOCS: https://spacy.io/api/dependencyparser
    """

    TransitionSystem = ArcEager

    def __init__(
        self,
        vocab,
        model,
        name="parser",
        moves=None,
        *,
        update_with_oracle_cut_size=100,
        min_action_freq=30,
        learn_tokens=False,
        beam_width=1,
        beam_density=0.0,
        beam_update_prob=0.0,
        multitasks=tuple(),
        incorrect_spans_key=None,
        scorer=parser_score,
    ):
        """Create a DependencyParser."""
        super().__init__(
            vocab,
            model,
            name,
            moves,
            update_with_oracle_cut_size=update_with_oracle_cut_size,
            min_action_freq=min_action_freq,
            learn_tokens=learn_tokens,
            beam_width=beam_width,
            beam_density=beam_density,
            beam_update_prob=beam_update_prob,
            multitasks=multitasks,
            incorrect_spans_key=incorrect_spans_key,
            scorer=scorer,
        )

    @property
    def postprocesses(self):
        output = [nonproj.deprojectivize]
        if self.cfg.get("learn_tokens") is True:
            output.append(merge_subtokens)
        return tuple(output)

    def add_multitask_objective(self, mt_component):
        self._multitasks.append(mt_component)

    def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
        # TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
        for labeller in self._multitasks:
            labeller.model.set_dim("nO", len(self.labels))
            if labeller.model.has_ref("output_layer"):
                labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
            labeller.initialize(get_examples, nlp=nlp)

    @property
    def labels(self):
        labels = set()
        # Get the labels from the model by looking at the available moves
        for move in self.move_names:
            if "-" in move:
                label = remove_bilu_prefix(move)
                if DELIMITER in label:
                    label = label.split(DELIMITER)[1]
                labels.add(label)
        return tuple(sorted(labels))

    def scored_parses(self, beams):
        """Return two dictionaries with scores for each beam/doc that was processed:
        one containing (i, head) keys, and another containing (i, label) keys.
        """
        head_scores = []
        label_scores = []
        for beam in beams:
            score_head_dict = defaultdict(float)
            score_label_dict = defaultdict(float)
            for score, parses in self.moves.get_beam_parses(beam):
                for head, i, label in parses:
                    score_head_dict[(i, head)] += score
                    score_label_dict[(i, label)] += score
            head_scores.append(score_head_dict)
            label_scores.append(score_label_dict)
        return head_scores, label_scores

    def _ensure_labels_are_added(self, docs):
        # This gives the parser a chance to add labels it's missing for a batch
        # of documents. However, this isn't desirable for the dependency parser,
        # because we instead have a label frequency cut-off and back off rare
        # labels to 'dep'.
        pass