# cython: infer_types=True
# cython: profile=True
cimport numpy as np
import numpy
from cpython.ref cimport PyObject, Py_XDECREF
from thinc.extra.search cimport Beam
from thinc.extra.search import MaxViolation
from thinc.extra.search cimport MaxViolation

from ...typedefs cimport hash_t, class_t
from .transition_system cimport TransitionSystem, Transition
from ...errors import Errors
from .stateclass cimport StateC, StateClass


# These are passed as callbacks to thinc.search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
    dest = <StateC*>_dest
    src = <StateC*>_src
    moves = <const Transition*>_moves
    dest.clone(src)
    moves[clas].do(dest, moves[clas].label)


cdef int check_final_state(void* _state, void* extra_args) except -1:
    state = <StateC*>_state
    return state.is_final()


cdef class BeamBatch(object):
    cdef public TransitionSystem moves
    cdef public object states
    cdef public object docs
    cdef public object golds
    cdef public object beams

    def __init__(self, TransitionSystem moves, states, golds,
                 int width, float density=0.):
        cdef StateClass state
        self.moves = moves
        self.states = states
        self.docs = [state.doc for state in states]
        self.golds = golds
        self.beams = []
        cdef Beam beam
        cdef StateC* st
        for state in states:
            beam = Beam(self.moves.n_moves, width, min_density=density)
            beam.initialize(self.moves.init_beam_state,
                            self.moves.del_beam_state, state.c.length,
                            <void*>state.c._sent)
            for i in range(beam.width):
                st = <StateC*>beam.at(i)
                st.offset = state.c.offset
            beam.check_done(check_final_state, NULL)
            self.beams.append(beam)

    @property
    def is_done(self):
        return all(b.is_done for b in self.beams)

    def __getitem__(self, i):
        return self.beams[i]

    def __len__(self):
        return len(self.beams)

    def get_states(self):
        cdef Beam beam
        cdef StateC* state
        cdef StateClass stcls
        states = []
        for beam, doc in zip(self, self.docs):
            for i in range(beam.size):
                state = <StateC*>beam.at(i)
                stcls = StateClass.borrow(state, doc)
                states.append(stcls)
        return states

    def get_unfinished_states(self):
        return [st for st in self.get_states() if not st.is_final()]

    def advance(self, float[:, ::1] scores, follow_gold=False):
        cdef Beam beam
        cdef int nr_class = scores.shape[1]
        cdef const float* c_scores = &scores[0, 0]
        docs = self.docs
        for i, beam in enumerate(self):
            if not beam.is_done:
                nr_state = self._set_scores(beam, c_scores, nr_class)
                assert nr_state
                if self.golds is not None:
                    self._set_costs(
                        beam,
                        docs[i],
                        self.golds[i],
                        follow_gold=follow_gold
                    )
                c_scores += nr_state * nr_class
                beam.advance(transition_state, NULL, <void*>self.moves.c)
                beam.check_done(check_final_state, NULL)

    cdef int _set_scores(self, Beam beam, const float* scores, int nr_class) except -1:
        cdef int nr_state = 0
        for i in range(beam.size):
            state = <StateC*>beam.at(i)
            if not state.is_final():
                for j in range(nr_class):
                    beam.scores[i][j] = scores[nr_state * nr_class + j]
                self.moves.set_valid(beam.is_valid[i], state)
                nr_state += 1
            else:
                for j in range(beam.nr_class):
                    beam.scores[i][j] = 0
                    beam.costs[i][j] = 0
        return nr_state

    def _set_costs(self, Beam beam, doc, gold, int follow_gold=False):
        cdef const StateC* state
        for i in range(beam.size):
            state = <const StateC*>beam.at(i)
            if state.is_final():
                for j in range(beam.nr_class):
                    beam.is_valid[i][j] = 0
                    beam.costs[i][j] = 9000
            else:
                self.moves.set_costs(beam.is_valid[i], beam.costs[i],
                                     state, gold)
                if follow_gold:
                    min_cost = 0
                    for j in range(beam.nr_class):
                        if beam.is_valid[i][j] and beam.costs[i][j] < min_cost:
                            min_cost = beam.costs[i][j]
                    for j in range(beam.nr_class):
                        if beam.costs[i][j] > min_cost:
                            beam.is_valid[i][j] = 0


def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0):
    cdef MaxViolation violn
    pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
    gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
    cdef StateClass state
    beam_maps = []
    backprops = []
    violns = [MaxViolation() for _ in range(len(states))]
    dones = [False for _ in states]
    while not pbeam.is_done or not gbeam.is_done:
        # The beam maps let us find the right row in the flattened scores
        # array for each state. States are identified by (example id,
        # history). We keep a different beam map for each step (since we'll
        # have a flat scores array for each step). The beam map will let us
        # take the per-state losses, and compute the gradient for each (step,
        # state, class).
        # Gather all states from the two beams in a list. Some stats may occur
        # in both beams. To figure out which beam each state belonged to,
        # we keep two lists of indices, p_indices and g_indices
        states, p_indices, g_indices, beam_map = get_unique_states(pbeam, gbeam)
        beam_maps.append(beam_map)
        if not states:
            break
        # Now that we have our flat list of states, feed them through the model
        scores, bp_scores = model.begin_update(states)
        assert scores.size != 0
        # Store the callbacks for the backward pass
        backprops.append(bp_scores)
        # Unpack the scores for the two beams. The indices arrays
        # tell us which example and state the scores-row refers to.
        # Now advance the states in the beams. The gold beam is constrained to
        # to follow only gold analyses.
        if not pbeam.is_done:
            pbeam.advance(model.ops.as_contig(scores[p_indices]))
        if not gbeam.is_done:
            gbeam.advance(model.ops.as_contig(scores[g_indices]), follow_gold=True)
        # Track the "maximum violation", to use in the update.
        for i, violn in enumerate(violns):
            if not dones[i]:
                violn.check_crf(pbeam[i], gbeam[i])
                if pbeam[i].is_done and gbeam[i].is_done:
                    dones[i] = True
    histories = []
    grads = []
    for violn in violns:
        if violn.p_hist:
            histories.append(violn.p_hist + violn.g_hist)
            d_loss = [d_l * violn.cost for d_l in violn.p_probs + violn.g_probs]
            grads.append(d_loss)
        else:
            histories.append([])
            grads.append([])
    loss = 0.0
    states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, grads)
    for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)):
        loss += (d_scores**2).mean()
        bp_scores(d_scores)
    return loss


def collect_states(beams, docs):
    cdef StateClass state
    cdef Beam beam
    states = []
    for state_or_beam, doc in zip(beams, docs):
        if isinstance(state_or_beam, StateClass):
            states.append(state_or_beam)
        else:
            beam = state_or_beam
            state = StateClass.borrow(<StateC*>beam.at(0), doc)
            states.append(state)
    return states


def get_unique_states(pbeams, gbeams):
    seen = {}
    states = []
    p_indices = []
    g_indices = []
    beam_map = {}
    docs = pbeams.docs
    cdef Beam pbeam, gbeam
    if len(pbeams) != len(gbeams):
        raise ValueError(Errors.E079.format(pbeams=len(pbeams), gbeams=len(gbeams)))
    for eg_id, (pbeam, gbeam, doc) in enumerate(zip(pbeams, gbeams, docs)):
        if not pbeam.is_done:
            for i in range(pbeam.size):
                state = StateClass.borrow(<StateC*>pbeam.at(i), doc)
                if not state.is_final():
                    key = tuple([eg_id] + pbeam.histories[i])
                    if key in seen:
                        raise ValueError(Errors.E080.format(key=key))
                    seen[key] = len(states)
                    p_indices.append(len(states))
                    states.append(state)
            beam_map.update(seen)
        if not gbeam.is_done:
            for i in range(gbeam.size):
                state = StateClass.borrow(<StateC*>gbeam.at(i), doc)
                if not state.is_final():
                    key = tuple([eg_id] + gbeam.histories[i])
                    if key in seen:
                        g_indices.append(seen[key])
                    else:
                        g_indices.append(len(states))
                        beam_map[key] = len(states)
                        states.append(state)
    p_indices = numpy.asarray(p_indices, dtype='i')
    g_indices = numpy.asarray(g_indices, dtype='i')
    return states, p_indices, g_indices, beam_map


def get_gradient(nr_class, beam_maps, histories, losses):
    """The global model assigns a loss to each parse. The beam scores
    are additive, so the same gradient is applied to each action
    in the history. This gives the gradient of a single *action*
    for a beam state -- so we have "the gradient of loss for taking
    action i given history H."

    Histories: Each history is a list of actions
    Each candidate has a history
    Each beam has multiple candidates
    Each batch has multiple beams
    So history is list of lists of lists of ints
    """
    grads = []
    nr_steps = []
    for eg_id, hists in enumerate(histories):
        nr_step = 0
        for loss, hist in zip(losses[eg_id], hists):
            assert not numpy.isnan(loss)
            if loss != 0.0:
                nr_step = max(nr_step, len(hist))
        nr_steps.append(nr_step)
    for i in range(max(nr_steps)):
        grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class),
                                 dtype='f'))
    if len(histories) != len(losses):
        raise ValueError(Errors.E081.format(n_hist=len(histories), losses=len(losses)))
    for eg_id, hists in enumerate(histories):
        for loss, hist in zip(losses[eg_id], hists):
            assert not numpy.isnan(loss)
            if loss == 0.0:
                continue
            key = tuple([eg_id])
            # Adjust loss for length
            # We need to do this because each state in a short path is scored
            # multiple times, as we add in the average cost when we run out
            # of actions.
            avg_loss = loss / len(hist)
            loss += avg_loss * (nr_steps[eg_id] - len(hist))
            for step, clas in enumerate(hist):
                i = beam_maps[step][key]
                # In step j, at state i action clas
                # resulted in loss
                grads[step][i, clas] += loss
                key = key + tuple([clas])
    return grads