from typing import List, Tuple, Callable, Optional, Sequence, cast
from thinc.initializers import glorot_uniform_init
from thinc.util import partial
from thinc.types import Ragged, Floats2d, Floats1d, Ints1d
from thinc.api import Model, Ops, registry

from ..tokens import Doc
from ..errors import Errors
from ..vectors import Mode
from ..vocab import Vocab


@registry.layers("spacy.StaticVectors.v2")
def StaticVectors(
    nO: Optional[int] = None,
    nM: Optional[int] = None,
    *,
    dropout: Optional[float] = None,
    init_W: Callable = glorot_uniform_init,
    key_attr: str = "ORTH"
) -> Model[List[Doc], Ragged]:
    """Embed Doc objects with their vocab's vectors table, applying a learned
    linear projection to control the dimensionality. If a dropout rate is
    specified, the dropout is applied per dimension over the whole batch.
    """
    return Model(
        "static_vectors",
        forward,
        init=partial(init, init_W),
        params={"W": None},
        attrs={"key_attr": key_attr, "dropout_rate": dropout},
        dims={"nO": nO, "nM": nM},
    )


def forward(
    model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
) -> Tuple[Ragged, Callable]:
    token_count = sum(len(doc) for doc in docs)
    if not token_count:
        return _handle_empty(model.ops, model.get_dim("nO"))
    key_attr: int = model.attrs["key_attr"]
    keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
    vocab: Vocab = docs[0].vocab
    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
    if vocab.vectors.mode == Mode.default:
        V = model.ops.asarray(vocab.vectors.data)
        rows = vocab.vectors.find(keys=keys)
        V = model.ops.as_contig(V[rows])
    elif vocab.vectors.mode == Mode.floret:
        V = vocab.vectors.get_batch(keys)
        V = model.ops.as_contig(V)
    else:
        raise RuntimeError(Errors.E896)
    try:
        vectors_data = model.ops.gemm(V, W, trans2=True)
    except ValueError:
        raise RuntimeError(Errors.E896)
    if vocab.vectors.mode == Mode.default:
        # Convert negative indices to 0-vectors
        # TODO: more options for UNK tokens
        vectors_data[rows < 0] = 0
    output = Ragged(vectors_data, model.ops.asarray1i([len(doc) for doc in docs]))
    mask = None
    if is_train:
        mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
        if mask is not None:
            output.data *= mask

    def backprop(d_output: Ragged) -> List[Doc]:
        if mask is not None:
            d_output.data *= mask
        model.inc_grad(
            "W",
            model.ops.gemm(
                cast(Floats2d, d_output.data),
                cast(Floats2d, model.ops.as_contig(V)),
                trans1=True,
            ),
        )
        return []

    return output, backprop


def init(
    init_W: Callable,
    model: Model[List[Doc], Ragged],
    X: Optional[List[Doc]] = None,
    Y: Optional[Ragged] = None,
) -> Model[List[Doc], Ragged]:
    nM = model.get_dim("nM") if model.has_dim("nM") else None
    nO = model.get_dim("nO") if model.has_dim("nO") else None
    if X is not None and len(X):
        nM = X[0].vocab.vectors.shape[1]
    if Y is not None:
        nO = Y.data.shape[1]

    if nM is None:
        raise ValueError(Errors.E905)
    if nO is None:
        raise ValueError(Errors.E904)
    model.set_dim("nM", nM)
    model.set_dim("nO", nO)
    model.set_param("W", init_W(model.ops, (nO, nM)))
    return model


def _handle_empty(ops: Ops, nO: int):
    return Ragged(ops.alloc2f(0, nO), ops.alloc1i(0)), lambda d_ragged: []


def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
    if rate is not None:
        mask = ops.get_dropout_mask((nO,), rate)
        return mask  # type: ignore
    return None