"""Thinc layer to do simpler transition-based parsing, NER, etc."""
from typing import List, Tuple, Dict, Optional
import numpy
from thinc.api import Ops, Model, with_array, softmax_activation, padded2list
from thinc.api import to_numpy
from thinc.types import Padded, Ints1d, Ints3d, Floats2d, Floats3d

from ..tokens import Doc


def BILUO() -> Model[Padded, Padded]:
    """Create a layer that restricts per-token scores to valid BILUO moves.

    The layer's output width ``nO`` starts unset and is inferred by ``init``.
    The ``get_num_actions`` helper is exposed through the model's attrs so
    callers can size the action space for a given number of labels.
    """
    layer_dims = {"nO": None}
    layer_attrs = {"get_num_actions": get_num_actions}
    return Model("biluo", forward, init=init, dims=layer_dims, attrs=layer_attrs)


def init(model, X: Optional[Padded] = None, Y: Optional[Padded] = None):
    """Infer the layer's output width ``nO`` from sample data.

    ``nO`` is taken from the last axis of ``X.data`` when X is given,
    otherwise from ``Y.data``. When both are given, their data shapes must
    match. When neither is given, ``nO`` must already be set on the model.

    Raises:
        ValueError: if X and Y have mismatched data shapes, or if no sample
            data is given and ``nO`` has not been set.
    """
    if X is not None and Y is not None and X.data.shape != Y.data.shape:
        raise ValueError(
            "Mismatched shapes for BILUO: "
            f"X.data has shape {X.data.shape}, Y.data has shape {Y.data.shape}"
        )
    if X is not None:
        model.set_dim("nO", X.data.shape[2])
    elif Y is not None:
        model.set_dim("nO", Y.data.shape[2])
    # Thinc's Model.get_dim raises for an unset dimension instead of
    # returning None, so has_dim is the reliable way to test for it.
    elif not model.has_dim("nO"):
        raise ValueError("Dimension unset for BILUO: nO")


def forward(model: Model[Padded, Padded], Xp: Padded, is_train: bool):
    """Mask per-token action scores so that only valid BILUO transitions win.

    ``Xp.data`` holds scores shaped (n_tokens, n_docs, n_actions). Tokens are
    processed left to right. At each timestep, two things select a row of
    the transition table, which becomes a 0/1 validity mask for that
    timestep:

    - the previous step's argmax action;
    - whether the current token is the last one of its sequence.

    Invalid actions get a large penalty subtracted, so an argmax over the
    output always lands on a valid action.

    Returns the masked scores as a ``Padded`` sharing Xp's ``size_at_t``,
    ``lengths`` and ``indices``. Also returns a backprop callback that
    multiplies the incoming gradient in place by the masks. This zeroes
    gradients for invalid actions and out-of-bounds rows.
    """
    n_labels = (model.get_dim("nO") - 1) // 4
    n_tokens, n_docs, n_actions = Xp.data.shape
    # At each timestep we build a validity mask of shape (n_docs, n_actions)
    # saying which actions may come next for each sequence. The validity
    # table has shape (2, n_actions, n_actions) and is indexed as
    # [is_last, prev_action, next_action]. There is no separate 'null action'
    # for the first token. Instead, we start as though the previous action
    # was O (the last action index).
    valid_transitions = model.ops.asarray(_get_transition_table(n_labels))
    prev_actions = model.ops.alloc1i(n_docs)
    prev_actions.fill(n_actions - 1)
    Y = model.ops.alloc3f(*Xp.data.shape)
    masks = model.ops.alloc3f(*Y.shape)
    # Penalty scale for invalid actions. Let M be the largest absolute
    # score. Valid scores stay >= -M, while invalid ones become
    # <= M - 10 * (M + 1) < -M, so every valid action beats every invalid
    # one. A plain max gave zero or even a positive "penalty" whenever all
    # scores were <= 0 (e.g. a zero-initialised output layer). The size
    # check avoids reducing over an empty batch, which raises.
    if Xp.data.size:
        max_value = abs(Xp.data).max() + 1
    else:
        max_value = 1.0
    for t in range(n_tokens):
        # 1 where token t is the last token of its sequence (or past its end).
        is_last = (Xp.lengths < (t + 2)).astype("i")
        masks[t] = valid_transitions[is_last, prev_actions]
        # Don't train the out-of-bounds sequences. This assumes Padded orders
        # its sequences by descending length, so rows from size_at_t[t] on
        # have already ended. TODO confirm against thinc's Padded contract.
        masks[t, Xp.size_at_t[t]:] = 0
        # Valid actions keep their score; invalid ones get a large negative offset.
        Y[t] = Xp.data[t] + ((masks[t] - 1) * max_value * 10)
        prev_actions = Y[t].argmax(axis=-1)

    def backprop_biluo(dY: Padded) -> Padded:
        # Gradients only flow through the actions that were valid.
        dY.data *= masks
        return dY

    return Padded(Y, Xp.size_at_t, Xp.lengths, Xp.indices), backprop_biluo


def get_num_actions(n_labels: int) -> int:
    """Return the size of the BILUO action space for ``n_labels`` labels.

    Every label contributes four actions (Begin, In, Last, Unit). A single
    Out action, shared by all labels, is added on top.
    """
    actions_per_label = 4  # B, I, L, U
    shared_actions = 1  # O
    return actions_per_label * n_labels + shared_actions


def _get_transition_table(
    n_labels: int, *, _cache: Dict[int, Floats3d] = {}
) -> Floats3d:
    """Return the 0/1 BILUO validity table for ``n_labels`` labels.

    The table has shape (2, n_actions, n_actions) and is indexed as
    ``table[is_last, prev_action, next_action]``. A 1 means ``next_action``
    may follow ``prev_action``, given whether the current token is the last
    one of its sequence. Actions are laid out as [B..., I..., L..., U..., O]:
    one block of ``n_labels`` actions per prefix, followed by the single O
    action at index ``n_actions - 1``.

    The mutable ``_cache`` default is intentional. It memoises one table per
    action count for the life of the process. The cached array itself is
    returned, so callers must not modify it in place.
    """
    n_actions = get_num_actions(n_labels)
    if n_actions in _cache:
        return _cache[n_actions]
    table = numpy.zeros((2, n_actions, n_actions), dtype="f")
    B_start, B_end = (0, n_labels)
    I_start, I_end = (B_end, B_end + n_labels)
    L_start, L_end = (I_end, I_end + n_labels)
    U_start, U_end = (L_end, L_end + n_labels)
    # Using ranges allows us to set specific cells, which is necessary to express
    # that only actions of the same label are valid continuations.
    B_range = numpy.arange(B_start, B_end)
    I_range = numpy.arange(I_start, I_end)
    L_range = numpy.arange(L_start, L_end)
    # If this is the last token and the previous action was B or I, only L
    # of that label is valid.
    table[1, B_range, L_range] = 1
    table[1, I_range, L_range] = 1
    # If this isn't the last token and the previous action was B or I, only I or
    # L of that label are valid.
    table[0, B_range, I_range] = 1
    table[0, B_range, L_range] = 1
    table[0, I_range, I_range] = 1
    table[0, I_range, L_range] = 1
    # Slices starting at L_start cover L, U and the trailing O action.
    # If this isn't the last token and the previous was L, U or O, B is valid.
    table[0, L_start:, :B_end] = 1
    # Regardless of whether this is the last token, if the previous action was
    # {L, U, O}, U and O are valid (columns from U_start on are U... and O).
    table[:, L_start:, U_start:] = 1
    _cache[n_actions] = table
    return table