from typing import Optional, List
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
from thinc.types import Floats2d

from ...errors import Errors
from ...compat import Literal
from ...util import registry
from .._precomputable_affine import PrecomputableAffine
from ..tb_framework import TransitionModel
from ...tokens import Doc


@registry.architectures("spacy.TransitionBasedParser.v2")
def build_tb_parser_model(
    tok2vec: Model[List[Doc], List[Floats2d]],
    state_type: Literal["parser", "ner"],
    extra_state_tokens: bool,
    hidden_width: int,
    maxout_pieces: int,
    use_upper: bool,
    nO: Optional[int] = None,
) -> Model:
    """
    Build a transition-based parser model. Can apply to NER or dependency-parsing.

    Transition-based parsing is an approach to structured prediction where the
    task of predicting the structure is mapped to a series of state transitions.
    You might find this tutorial helpful as background:
    https://explosion.ai/blog/parsing-english-in-python

    The neural network state prediction model consists of either two or three
    subnetworks:

    * tok2vec: Map each token into a vector representations. This subnetwork
        is run once for each batch.
    * lower: Construct a feature-specific vector for each (token, feature) pair.
        This is also run once for each batch. Constructing the state
        representation is then simply a matter of summing the component features
        and applying the non-linearity.
    * upper (optional): A feed-forward network that predicts scores from the
        state representation. If not present, the output from the lower model is
        used as action scores directly.

    tok2vec (Model[List[Doc], List[Floats2d]]):
        Subnetwork to map tokens into vector representations.
    state_type (str):
        String value denoting the type of parser model: "parser" or "ner"
    extra_state_tokens (bool): Whether or not to use additional tokens in the context
        to construct the state vector. Defaults to `False`, which means 3 and 8
        for the NER and parser respectively. When set to `True`, this would become 6
        feature sets (for the NER) or 13 (for the parser).
    hidden_width (int): The width of the hidden layer.
    maxout_pieces (int): How many pieces to use in the state prediction layer.
        Recommended values are 1, 2 or 3. If 1, the maxout non-linearity
        is replaced with a ReLu non-linearity if use_upper=True, and no
        non-linearity if use_upper=False.
    use_upper (bool): Whether to use an additional hidden layer after the state
        vector in order to predict the action scores. It is recommended to set
        this to False for large pretrained models such as transformers, and False
        for smaller networks. The upper layer is computed on CPU, which becomes
        a bottleneck on larger GPU-based models, where it's also less necessary.
    nO (int or None): The number of actions the model will predict between.
        Usually inferred from data at the beginning of training, or loaded from
        disk.
    """
    if state_type == "parser":
        nr_feature_tokens = 13 if extra_state_tokens else 8
    elif state_type == "ner":
        nr_feature_tokens = 6 if extra_state_tokens else 3
    else:
        raise ValueError(Errors.E917.format(value=state_type))
    t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
    tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width))
    tok2vec.set_dim("nO", hidden_width)
    lower = _define_lower(
        nO=hidden_width if use_upper else nO,
        nF=nr_feature_tokens,
        nI=tok2vec.get_dim("nO"),
        nP=maxout_pieces,
    )
    upper = None
    if use_upper:
        with use_ops("numpy"):
            # Initialize weights at zero, as it's a classification layer.
            upper = _define_upper(nO=nO, nI=None)
    return TransitionModel(tok2vec, lower, upper, resize_output)


def _define_upper(nO, nI):
    return Linear(nO=nO, nI=nI, init_W=zero_init)


def _define_lower(nO, nF, nI, nP):
    return PrecomputableAffine(nO=nO, nF=nF, nI=nI, nP=nP)


def resize_output(model, new_nO):
    if model.attrs["has_upper"]:
        return _resize_upper(model, new_nO)
    return _resize_lower(model, new_nO)


def _resize_upper(model, new_nO):
    upper = model.get_ref("upper")
    if upper.has_dim("nO") is None:
        upper.set_dim("nO", new_nO)
        return model
    elif new_nO == upper.get_dim("nO"):
        return model

    smaller = upper
    nI = smaller.maybe_get_dim("nI")
    with use_ops("numpy"):
        larger = _define_upper(nO=new_nO, nI=nI)
    # it could be that the model is not initialized yet, then skip this bit
    if smaller.has_param("W"):
        larger_W = larger.ops.alloc2f(new_nO, nI)
        larger_b = larger.ops.alloc1f(new_nO)
        smaller_W = smaller.get_param("W")
        smaller_b = smaller.get_param("b")
        # Weights are stored in (nr_out, nr_in) format, so we're basically
        # just adding rows here.
        if smaller.has_dim("nO"):
            old_nO = smaller.get_dim("nO")
            larger_W[:old_nO] = smaller_W
            larger_b[:old_nO] = smaller_b
            for i in range(old_nO, new_nO):
                model.attrs["unseen_classes"].add(i)

        larger.set_param("W", larger_W)
        larger.set_param("b", larger_b)
    model._layers[-1] = larger
    model.set_ref("upper", larger)
    return model


def _resize_lower(model, new_nO):
    lower = model.get_ref("lower")
    if lower.has_dim("nO") is None:
        lower.set_dim("nO", new_nO)
        return model

    smaller = lower
    nI = smaller.maybe_get_dim("nI")
    nF = smaller.maybe_get_dim("nF")
    nP = smaller.maybe_get_dim("nP")
    larger = _define_lower(nO=new_nO, nI=nI, nF=nF, nP=nP)
    # it could be that the model is not initialized yet, then skip this bit
    if smaller.has_param("W"):
        larger_W = larger.ops.alloc4f(nF, new_nO, nP, nI)
        larger_b = larger.ops.alloc2f(new_nO, nP)
        larger_pad = larger.ops.alloc4f(1, nF, new_nO, nP)
        smaller_W = smaller.get_param("W")
        smaller_b = smaller.get_param("b")
        smaller_pad = smaller.get_param("pad")
        # Copy the old weights and padding into the new layer
        if smaller.has_dim("nO"):
            old_nO = smaller.get_dim("nO")
            larger_W[:, 0:old_nO, :, :] = smaller_W
            larger_pad[:, :, 0:old_nO, :] = smaller_pad
            larger_b[0:old_nO, :] = smaller_b
            for i in range(old_nO, new_nO):
                model.attrs["unseen_classes"].add(i)

        larger.set_param("W", larger_W)
        larger.set_param("b", larger_b)
        larger.set_param("pad", larger_pad)
    model._layers[1] = larger
    model.set_ref("lower", larger)
    return model