"""
Beam-search variant of the MALT-style (transition-based) dependency parser.
"""
# cython: profile=True
# cython: experimental_cpp_class_def=True
# cython: cdivision=True
# cython: infer_types=True
# coding: utf-8

from __future__ import unicode_literals, print_function
cimport cython

from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
from libc.stdint cimport uint32_t, uint64_t
from libc.string cimport memset, memcpy
from libc.stdlib cimport rand
from libc.math cimport log, exp, isnan, isinf
from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport real_hash64 as hash64
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from thinc.linear.features cimport ConjunctionExtracter
from thinc.structs cimport FeatureC, ExampleC
from thinc.extra.search cimport Beam, MaxViolation
from thinc.extra.eg cimport Example
from thinc.extra.mb cimport Minibatch

from ..structs cimport TokenC
from ..tokens.doc cimport Doc
from ..strings cimport StringStore
from .transition_system cimport TransitionSystem, Transition
from ..gold cimport GoldParse
from . import _parse_features
from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context
from .stateclass cimport StateClass
from .parser cimport Parser


# Module-wide debug switch. It is not read anywhere else in this chunk of the
# module; toggle it with set_debug().
DEBUG = False
def set_debug(val):
    """Store *val* as the module-level ``DEBUG`` flag.

    The value is stored exactly as given; it is not coerced to bool.
    """
    global DEBUG
    DEBUG = val


def get_templates(name):
    """Return the feature templates for the parser configuration *name*.

    'ner'   -> ``_parse_features.ner``
    'debug' -> ``_parse_features.unigrams`` only
    other   -> the full dependency-parsing set: the template groups below,
               concatenated in this order.
    """
    pf = _parse_features
    if name == 'ner':
        return pf.ner
    if name == 'debug':
        return pf.unigrams
    return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 +
            pf.s0_n1 + pf.n0_n1 + pf.tree_shape + pf.trigrams)


# Default beam width (passed to Beam as its width) used when BeamParser is
# constructed without a `beam_width` keyword argument.
cdef int BEAM_WIDTH = 16
# Default `beam_density`, used when BeamParser is constructed without that
# keyword. The decoding beam in BeamParser._parseC receives it as `min_density`.
cdef weight_t BEAM_DENSITY = 0.001

cdef class BeamParser(Parser):
    """Parser subclass that decodes and trains with beam search (thinc `Beam`).

    The keyword arguments ``beam_width`` (default BEAM_WIDTH) and
    ``beam_density`` (default BEAM_DENSITY) are read here. All arguments,
    including those two, are forwarded unchanged to ``Parser.__init__``.
    """
    def __init__(self, *args, **kwargs):
        self.beam_width = kwargs.get('beam_width', BEAM_WIDTH)
        self.beam_density = kwargs.get('beam_density', BEAM_DENSITY)
        Parser.__init__(self, *args, **kwargs)

    cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil:
        # nogil entry point. Beam search manipulates Python objects, so the GIL
        # is reacquired for the whole decode.
        # NOTE(review): this signature has no except clause. Whether an
        # exception raised by _parseC propagates to the caller depends on the
        # Cython version's default exception spec; confirm.
        with gil:
            self._parseC(tokens, length, nr_feat, self.moves.n_moves)

    cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1:
        # Beam-search decode of `tokens[0:length]`, written back in place.
        #
        # When search finishes, the state at beam index 0 (assumed to be the
        # best-scoring one) is finalized and its token annotations are copied
        # into `tokens`. If the beam is already done right after
        # initialization, `tokens` is left untouched.
        # `nr_class` is currently unused: the beam is sized by
        # self.moves.n_moves.
        cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
        try:
            # TODO: How do we handle new labels here? This increases nr_class
            beam.initialize(self.moves.init_beam_state, length, tokens)
            beam.check_done(_check_final_state, NULL)
            if beam.is_done:
                return 0
            while not beam.is_done:
                self._advance_beam(beam, None, False)
            state = <StateClass>beam.at(0)
            self.moves.finalize_state(state.c)
            for i in range(length):
                tokens[i] = state.c._sent[i]
        finally:
            # Release the beam's references to its state objects even when
            # decoding raises; otherwise they would leak.
            _cleanup(beam)
        return 0

    def update(self, Doc tokens, GoldParse gold_parse, itn=0):
        """Run one beam-search training update on a single document.

        A predicted beam and a gold-constrained beam are advanced together.
        Search stops early once the predicted beam has nonzero loss and its
        ``min_score`` exceeds ``gold.score + self.model.time``. The model is
        then updated from the histories collected by ``MaxViolation.check_crf``.
        If the predicted beam has zero loss, the update uses an empty history.

        tokens: the document being parsed.
        gold_parse: its gold-standard annotation (preprocessed in place by
            ``self.moves.preprocess_gold``).
        itn: training iteration; the update uses ``min_grad=0.001**(itn+1)``.

        Returns ``pred.loss``, the loss of the predicted beam.
        Raises Exception if either beam ends up empty. Also raises (via
        ``_advance_beam``) if the gold beam has no valid low-cost action.
        """
        self.moves.preprocess_gold(gold_parse)
        cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
        # NOTE(review): min_density=0.0 presumably disables density-based
        # pruning for the gold beam; confirm against thinc's Beam.
        cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0)
        try:
            pred.initialize(self.moves.init_beam_state, tokens.length, tokens.c)
            pred.check_done(_check_final_state, NULL)
            # Hack for NER. initialize_state is applied only to the predicted
            # beam's states, not to `gold` and not during _parseC.
            # NOTE(review): confirm this asymmetry is intended.
            for i in range(pred.size):
                stcls = <StateClass>pred.at(i)
                self.moves.initialize_state(stcls.c)

            gold.initialize(self.moves.init_beam_state, tokens.length, tokens.c)
            gold.check_done(_check_final_state, NULL)
            violn = MaxViolation()
            while not pred.is_done and not gold.is_done:
                # We search separately here, to allow for ambiguity in the gold parse.
                self._advance_beam(pred, gold_parse, False)
                self._advance_beam(gold, gold_parse, True)
                violn.check_crf(pred, gold)
                if pred.loss > 0 and pred.min_score > (gold.score + self.model.time):
                    break
            else:
                # The non-monotonic oracle makes it difficult to ensure final costs are
                # correct. Therefore do final correction
                for i in range(pred.size):
                    if is_gold(<StateClass>pred.at(i), gold_parse, self.moves.strings):
                        pred._states[i].loss = 0.0
                    elif pred._states[i].loss == 0.0:
                        pred._states[i].loss = 1.0
                violn.check_crf(pred, gold)
            if pred.size < 1:
                raise Exception("No candidates", tokens.length)
            if gold.size < 1:
                raise Exception("No gold", tokens.length)
            if pred.loss == 0:
                self.model.update_from_histories(self.moves, tokens, [(0.0, [])])
            else:
                # Debug aid: _check_train_integrity(pred, gold, gold_parse, self.moves)
                # validates the beam losses against is_gold() at this point.
                histories = list(zip(violn.p_probs, violn.p_hist)) + \
                            list(zip(violn.g_probs, violn.g_hist))
                self.model.update_from_histories(self.moves, tokens, histories,
                                                 min_grad=0.001**(itn+1))
        finally:
            # Always release state references, including on the error paths above.
            _cleanup(pred)
            _cleanup(gold)
        return pred.loss

    def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
        """Score every non-final state in `beam` and advance it by one transition.

        gold: if not None, per-action costs are computed for each non-final
            state with ``self.moves.set_costs``.
        follow_gold: if True (and `gold` is given), actions with cost >= 1 are
            marked invalid. The beam therefore only follows low-cost actions,
            and no state-hash function is passed to ``beam.advance``.

        Raises Exception("No gold") if `follow_gold` leaves no valid action
        in any state.
        """
        cdef atom_t[CONTEXT_SIZE] context
        cdef Pool mem = Pool()
        features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
        for i in range(beam.size):
            stcls = <StateClass>beam.at(i)
            if not stcls.is_final():
                nr_feat = self.model.set_featuresC(context, features, stcls.c)
                self.moves.set_valid(beam.is_valid[i], stcls.c)
                self.model.set_scoresC(beam.scores[i], features, nr_feat)
        if gold is not None:
            n_gold = 0
            for i in range(beam.size):
                stcls = <StateClass>beam.at(i)
                if not stcls.c.is_final():
                    self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold)
                    if follow_gold:
                        for j in range(self.moves.n_moves):
                            if beam.costs[i][j] >= 1:
                                beam.is_valid[i][j] = 0
                            n_gold += 1 if beam.is_valid[i][j] else 0
            if follow_gold and n_gold == 0:
                raise Exception("No gold")
        if follow_gold:
            beam.advance(_transition_state, NULL, <void*>self.moves.c)
        else:
            beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
        beam.check_done(_check_final_state, NULL)


# These are passed as callbacks to thinc.extra.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
    # Beam transition callback. It overwrites the StateClass at `_dest` with a
    # clone of the one at `_src`, then applies transition number `clas` from
    # the Transition table passed through `_moves` (self.moves.c at call sites).
    cdef const Transition* table = <const Transition*>_moves
    cdef StateClass target = <StateClass>_dest
    target.clone(<StateClass>_src)
    table[clas].do(target.c, table[clas].label)


cdef int _check_final_state(void* _state, void* extra_args) except -1:
    # Beam "is done" callback: nonzero iff the StateClass behind `_state` is
    # final. `extra_args` is unused.
    cdef StateClass candidate = <StateClass>_state
    return candidate.is_final()


def _cleanup(Beam beam):
    """Drop the references `beam` holds on its current and parent state objects.

    Each ``content`` pointer is reset to NULL after it is released. A second
    call on the same beam is therefore a no-op rather than a double decref,
    because Py_XDECREF ignores NULL.
    NOTE(review): assumes the contents were created as owned (INCREF'd)
    references by the beam's init function; that function is not visible here.
    """
    for i in range(beam.width):
        Py_XDECREF(<PyObject*>beam._states[i].content)
        beam._states[i].content = NULL
        Py_XDECREF(<PyObject*>beam._parents[i].content)
        beam._parents[i].content = NULL


cdef hash_t _hash_state(void* _state, void* _) except 0:
    # Beam state-hash callback. Every final state hashes to the constant 1;
    # any other state uses its StateC hash.
    # NOTE(review): with `except 0`, a genuine hash value of 0 would be read as
    # an error return. Presumably negligible, but worth confirming.
    cdef StateClass st = <StateClass>_state
    if not st.c.is_final():
        return st.c.hash()
    return 1


def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, TransitionSystem moves):
    """Debugging check that beam losses agree with is_gold().

    Raises Exception("Predicted state is gold-standard") if a finished
    predicted state with nonzero loss is in fact gold-standard. Raises
    Exception("Gold parse is not gold-standard") if a finished gold-beam state
    is not. Diagnostic output is printed before either raise.
    """
    for i in range(pred.size):
        if not pred._states[i].is_done or pred._states[i].loss == 0:
            continue
        state = <StateClass>pred.at(i)
        if is_gold(state, gold_parse, moves.strings):
            _print_gold_annot(gold_parse)
            print("Cost", pred._states[i].loss)
            _print_state_arcs(state, gold_parse, moves)
            acts = [moves.c[clas].move for clas in pred.histories[i]]
            labels = [moves.c[clas].label for clas in pred.histories[i]]
            print([moves.move_name(move, label) for move, label in zip(acts, labels)])
            raise Exception("Predicted state is gold-standard")
    for i in range(gold.size):
        if not gold._states[i].is_done:
            continue
        state = <StateClass>gold.at(i)
        if not is_gold(state, gold_parse, moves.strings):
            print("Truth")
            _print_gold_annot(gold_parse)
            print("Predicted good")
            _print_state_arcs(state, gold_parse, moves)
            raise Exception("Gold parse is not gold-standard")


def _print_gold_annot(GoldParse gold_parse):
    # Print (word, head, dep) for every row of the gold annotation.
    # Rows are (id, word, tag, head, dep, ner).
    for dep in gold_parse.orig_annot:
        print(dep[1], dep[3], dep[4])


def _print_state_arcs(StateClass state, GoldParse gold_parse, TransitionSystem moves):
    # Print each token's gold word next to the state's predicted head and label.
    for j in range(gold_parse.length):
        print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep])


def is_gold(StateClass state, GoldParse gold, StringStore strings):
    """Return True iff the arcs in `state` exactly match the gold annotation.

    Only tokens with a gold alignment (``gold.cand_to_gold[i] is not None``)
    are compared. Each side is a set of (index, head, label) triples. A
    predicted dep id of 0 is rendered as the label 'ROOT'.
    NOTE(review): predicted triples use candidate indices (i, state.H(i)),
    while gold triples use orig_annot's id/head. The two agree only when the
    tokenizations align one-to-one; verify that this is intended.
    """
    predicted = set()
    truth = set()
    for i in range(gold.length):
        gold_i = gold.cand_to_gold[i]
        if gold_i is None:
            continue
        label_id = state.safe_get(i).dep
        label = strings[label_id] if label_id else 'ROOT'
        predicted.add((i, state.H(i), label))
        id_, _word, _tag, head, dep, _ner = gold.orig_annot[gold_i]
        truth.add((id_, head, dep))
    return predicted == truth