# cython: infer_types=True, profile=True, binding=True
from collections import defaultdict
from typing import Optional, Iterable, Callable
from thinc.api import Model, Config

from ._parser_internals.transition_system import TransitionSystem
from .transition_parser cimport Parser
from ._parser_internals.ner cimport BiluoPushDown
from ..language import Language
from ..scorer import get_ner_prf, PRFScore
from ..util import registry
from ..training import remove_bilu_prefix


default_model_config = """
[model]
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = true

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""
DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]


@Language.factory(
    "ner",
    assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
    default_config={
        "moves": None,
        "update_with_oracle_cut_size": 100,
        "model": DEFAULT_NER_MODEL,
        "incorrect_spans_key": None,
        "scorer": {"@scorers": "spacy.ner_scorer.v1"},
    },
    default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},

)
def make_ner(
    nlp: Language,
    name: str,
    model: Model,
    moves: Optional[TransitionSystem],
    update_with_oracle_cut_size: int,
    incorrect_spans_key: Optional[str],
    scorer: Optional[Callable],
):
    """Create a transition-based EntityRecognizer component. The entity recognizer
    identifies non-overlapping labelled spans of tokens.

    The transition-based algorithm used encodes certain assumptions that are
    effective for "traditional" named entity recognition tasks, but may not be
    a good fit for every span identification problem. Specifically, the loss
    function optimizes for whole entity accuracy, so if your inter-annotator
    agreement on boundary tokens is low, the component will likely perform poorly
    on your problem. The transition-based algorithm also assumes that the most
    decisive information about your entities will be close to their initial tokens.
    If your entities are long and characterised by tokens in their middle, the
    component will likely do poorly on your task.

    model (Model): The model for the transition-based parser. The model needs
        to have a specific substructure of named components --- see the
        spacy.ml.tb_framework.TransitionModel for details.
    moves (Optional[TransitionSystem]): This defines how the parse-state is created,
        updated and evaluated. If 'moves' is None, a new instance is
        created with `self.TransitionSystem()`. Defaults to `None`.
    update_with_oracle_cut_size (int): During training, cut long sequences into
        shorter segments by creating intermediate states based on the gold-standard
        history. The model is not very sensitive to this parameter, so you usually
        won't need to change it. 100 is a good default.
    incorrect_spans_key (Optional[str]): Identifies spans that are known
        to be incorrect entity annotations. The incorrect entity annotations
        can be stored in the span group, under this key.
    scorer (Optional[Callable]): The scoring method.
    """
    return EntityRecognizer(
        nlp.vocab,
        model,
        name,
        moves=moves,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        incorrect_spans_key=incorrect_spans_key,
        multitasks=[],
        beam_width=1,
        beam_density=0.0,
        beam_update_prob=0.0,
        scorer=scorer,
    )

@Language.factory(
    "beam_ner",
    assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
    default_config={
        "moves": None,
        "update_with_oracle_cut_size": 100,
        "model": DEFAULT_NER_MODEL,
        "beam_density": 0.01,
        "beam_update_prob": 0.5,
        "beam_width": 32,
        "incorrect_spans_key": None,
        "scorer": None,
    },
    default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
)
def make_beam_ner(
    nlp: Language,
    name: str,
    model: Model,
    moves: Optional[TransitionSystem],
    update_with_oracle_cut_size: int,
    beam_width: int,
    beam_density: float,
    beam_update_prob: float,
    incorrect_spans_key: Optional[str],
    scorer: Optional[Callable],
):
    """Create a transition-based EntityRecognizer component that uses beam-search.
    The entity recognizer identifies non-overlapping labelled spans of tokens.

    The transition-based algorithm used encodes certain assumptions that are
    effective for "traditional" named entity recognition tasks, but may not be
    a good fit for every span identification problem. Specifically, the loss
    function optimizes for whole entity accuracy, so if your inter-annotator
    agreement on boundary tokens is low, the component will likely perform poorly
    on your problem. The transition-based algorithm also assumes that the most
    decisive information about your entities will be close to their initial tokens.
    If your entities are long and characterised by tokens in their middle, the
    component will likely do poorly on your task.

    model (Model): The model for the transition-based parser. The model needs
        to have a specific substructure of named components --- see the
        spacy.ml.tb_framework.TransitionModel for details.
    moves (Optional[TransitionSystem]): This defines how the parse-state is created,
        updated and evaluated. If 'moves' is None, a new instance is
        created with `self.TransitionSystem()`. Defaults to `None`.
    update_with_oracle_cut_size (int): During training, cut long sequences into
        shorter segments by creating intermediate states based on the gold-standard
        history. The model is not very sensitive to this parameter, so you usually
        won't need to change it. 100 is a good default.
    beam_width (int): The number of candidate analyses to maintain.
    beam_density (float): The minimum ratio between the scores of the first and
        last candidates in the beam. This allows the parser to avoid exploring
        candidates that are too far behind. This is mostly intended to improve
        efficiency, but it can also improve accuracy as deeper search is not
        always better.
    beam_update_prob (float): The chance of making a beam update, instead of a
        greedy update. Greedy updates are an approximation for the beam updates,
        and are faster to compute.
    incorrect_spans_key (Optional[str]): Optional key into span groups of
        entities known to be non-entities.
    scorer (Optional[Callable]): The scoring method.
    """
    return EntityRecognizer(
        nlp.vocab,
        model,
        name,
        moves=moves,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        multitasks=[],
        beam_width=beam_width,
        beam_density=beam_density,
        beam_update_prob=beam_update_prob,
        incorrect_spans_key=incorrect_spans_key,
        scorer=scorer,
    )


def ner_score(examples, **kwargs):
    return get_ner_prf(examples, **kwargs)


@registry.scorers("spacy.ner_scorer.v1")
def make_ner_scorer():
    return ner_score


cdef class EntityRecognizer(Parser):
    """Pipeline component for named entity recognition.

    DOCS: https://spacy.io/api/entityrecognizer
    """
    TransitionSystem = BiluoPushDown

    def __init__(
        self,
        vocab,
        model,
        name="ner",
        moves=None,
        *,
        update_with_oracle_cut_size=100,
        beam_width=1,
        beam_density=0.0,
        beam_update_prob=0.0,
        multitasks=tuple(),
        incorrect_spans_key=None,
        scorer=ner_score,
    ):
        """Create an EntityRecognizer.
        """
        super().__init__(
            vocab,
            model,
            name,
            moves,
            update_with_oracle_cut_size=update_with_oracle_cut_size,
            min_action_freq=1,   # not relevant for NER
            learn_tokens=False,  # not relevant for NER
            beam_width=beam_width,
            beam_density=beam_density,
            beam_update_prob=beam_update_prob,
            multitasks=multitasks,
            incorrect_spans_key=incorrect_spans_key,
            scorer=scorer,
        )

    def add_multitask_objective(self, mt_component):
        """Register another component as a multi-task objective. Experimental."""
        self._multitasks.append(mt_component)

    def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
        """Setup multi-task objective components. Experimental and internal."""
        # TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
        for labeller in self._multitasks:
            labeller.model.set_dim("nO", len(self.labels))
            if labeller.model.has_ref("output_layer"):
                labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
            labeller.initialize(get_examples, nlp=nlp)

    @property
    def labels(self):
        # Get the labels from the model by looking at the available moves, e.g.
        # B-PERSON, I-PERSON, L-PERSON, U-PERSON
        labels = set(remove_bilu_prefix(move) for move in self.move_names
                     if move[0] in ("B", "I", "L", "U"))
        return tuple(sorted(labels))

    def scored_ents(self, beams):
        """Return a dictionary of (start, end, label) tuples with corresponding scores
        for each beam/doc that was processed.
        """
        entity_scores = []
        for beam in beams:
            score_dict = defaultdict(float)
            for score, ents in self.moves.get_beam_parses(beam):
                for start, end, label in ents:
                    score_dict[(start, end, label)] += score
            entity_scores.append(score_dict)
        return entity_scores