import pytest
from spacy.vocab import Vocab
from spacy import registry
from spacy.training import Example
from spacy.pipeline import DependencyParser
from spacy.tokens import Doc
from spacy.pipeline._parser_internals.nonproj import projectivize
from spacy.pipeline._parser_internals.arc_eager import ArcEager
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL


def get_sequence_costs(M, words, heads, deps, transitions):
    """Replay ``transitions`` on a fresh parse state, recording move costs.

    A ``Doc`` is built from ``words`` and turned into a gold ``Example`` using
    ``heads`` and ``deps``. A single (state, gold) pair is obtained from
    ``M.init_gold_batch``. Before each transition in ``transitions`` is
    applied, the cost of every move known to ``M`` is recorded.

    Returns:
        ``(state, cost_history)``: the state after all transitions were
        applied, and a list holding one ``{move_name: cost}`` dict per step.
    """
    example = Example.from_dict(
        Doc(Vocab(), words=words), {"heads": heads, "deps": deps}
    )
    batch_states, batch_golds, _ = M.init_gold_batch([example])
    state, gold = batch_states[0], batch_golds[0]
    cost_history = []
    for action_name in transitions:
        # Update the gold object against the current state before costing.
        gold.update(state)
        step_costs = {}
        for clas in range(M.n_moves):
            move_name = M.class_name(clas)
            step_costs[move_name] = M.get_cost(state, gold, clas)
        cost_history.append(step_costs)
        M.transition(state, action_name)
    return state, cost_history


@pytest.fixture
def vocab():
    """Provide a new, empty ``Vocab`` for each test."""
    empty_vocab = Vocab()
    return empty_vocab


@pytest.fixture
def arc_eager(vocab):
    """Provide an ``ArcEager`` transition system with two labels pre-added.

    Move 2 (Left) gets the label "left" and move 3 (Right) gets "right".
    """
    transition_system = ArcEager(vocab.strings, ArcEager.get_actions())
    for move, label in ((2, "left"), (3, "right")):
        transition_system.add_action(move, label)
    return transition_system


def test_oracle_four_words(arc_eager, vocab):
    """Check the zero-cost moves at every step of a four-word oracle parse.

    The gold analysis is two two-word trees ("a" <- "b" and "c" <- "d"), and
    the transition sequence includes a "B-ROOT" move between them.
    """
    words = ["a", "b", "c", "d"]
    heads = [1, 1, 3, 3]
    deps = ["left", "ROOT", "left", "ROOT"]
    for dep in deps:
        arc_eager.add_action(2, dep)  # Left
        arc_eager.add_action(3, dep)  # Right
    actions = ["S", "L-left", "B-ROOT", "S", "D", "S", "L-left", "S", "D"]
    state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
    expected_gold = [
        ["S"],
        ["B-ROOT", "L-left"],
        ["B-ROOT"],
        ["S"],
        ["D"],
        ["S"],
        ["L-left"],
        ["S"],
        ["D"],
    ]
    assert state.is_final()
    # One cost snapshot is recorded per transition. Check the lengths agree so
    # a short history cannot silently skip the later expectations.
    assert len(cost_history) == len(expected_gold), (
        len(cost_history),
        len(expected_gold),
    )
    for i, (state_costs, expected) in enumerate(zip(cost_history, expected_gold)):
        # Gold moves are the zero-cost ones (cost below 1).
        golds = [act for act, cost in state_costs.items() if cost < 1]
        assert golds == expected, (i, golds, expected)


# Gold annotation for a single sentence, one row per token, in the order
# (token index, word, tag, head token index, dependency label, entity tag).
# The root token ("talking", index 44) lists itself as its own head.
annot_tuples = [
    (0, "When", "WRB", 11, "advmod", "O"),
    (1, "Walter", "NNP", 2, "compound", "B-PERSON"),
    (2, "Rodgers", "NNP", 11, "nsubj", "L-PERSON"),
    (3, ",", ",", 2, "punct", "O"),
    (4, "our", "PRP$", 6, "poss", "O"),
    (5, "embedded", "VBN", 6, "amod", "O"),
    (6, "reporter", "NN", 2, "appos", "O"),
    (7, "with", "IN", 6, "prep", "O"),
    (8, "the", "DT", 10, "det", "B-ORG"),
    (9, "3rd", "NNP", 10, "compound", "I-ORG"),
    (10, "Cavalry", "NNP", 7, "pobj", "L-ORG"),
    (11, "says", "VBZ", 44, "advcl", "O"),
    (12, "three", "CD", 13, "nummod", "U-CARDINAL"),
    (13, "battalions", "NNS", 16, "nsubj", "O"),
    (14, "of", "IN", 13, "prep", "O"),
    (15, "troops", "NNS", 14, "pobj", "O"),
    (16, "are", "VBP", 11, "ccomp", "O"),
    (17, "on", "IN", 16, "prep", "O"),
    (18, "the", "DT", 19, "det", "O"),
    (19, "ground", "NN", 17, "pobj", "O"),
    (20, ",", ",", 17, "punct", "O"),
    (21, "inside", "IN", 17, "prep", "O"),
    (22, "Baghdad", "NNP", 21, "pobj", "U-GPE"),
    (23, "itself", "PRP", 22, "appos", "O"),
    (24, ",", ",", 16, "punct", "O"),
    (25, "have", "VBP", 26, "aux", "O"),
    (26, "taken", "VBN", 16, "dep", "O"),
    (27, "up", "RP", 26, "prt", "O"),
    (28, "positions", "NNS", 26, "dobj", "O"),
    (29, "they", "PRP", 31, "nsubj", "O"),
    (30, "'re", "VBP", 31, "aux", "O"),
    (31, "going", "VBG", 26, "parataxis", "O"),
    (32, "to", "TO", 33, "aux", "O"),
    (33, "spend", "VB", 31, "xcomp", "O"),
    (34, "the", "DT", 35, "det", "B-TIME"),
    (35, "night", "NN", 33, "dobj", "L-TIME"),
    (36, "there", "RB", 33, "advmod", "O"),
    (37, "presumably", "RB", 33, "advmod", "O"),
    (38, ",", ",", 44, "punct", "O"),
    (39, "how", "WRB", 40, "advmod", "O"),
    (40, "many", "JJ", 41, "amod", "O"),
    (41, "soldiers", "NNS", 44, "pobj", "O"),
    (42, "are", "VBP", 44, "aux", "O"),
    (43, "we", "PRP", 44, "nsubj", "O"),
    (44, "talking", "VBG", 44, "ROOT", "O"),
    (45, "about", "IN", 44, "prep", "O"),
    (46, "right", "RB", 47, "advmod", "O"),
    (47, "now", "RB", 44, "advmod", "O"),
    (48, "?", ".", 44, "punct", "O"),
]


def test_get_oracle_actions():
    """Check that an oracle action sequence can be derived for ``annot_tuples``.

    The gold heads/deps are projectivized. A Left (2) or Right (3) action is
    then registered for each arc, depending on its direction, before the
    oracle sequence is requested.
    """
    words, tags, heads, deps = [], [], [], []
    for _id, word, tag, head, dep, _ent in annot_tuples:
        words.append(word)
        tags.append(tag)
        heads.append(head)
        deps.append(dep)
    doc = Doc(Vocab(), words=words)
    config = {
        "learn_tokens": False,
        "min_action_freq": 0,
        "update_with_oracle_cut_size": 100,
    }
    cfg = {"model": DEFAULT_PARSER_MODEL}
    model = registry.resolve(cfg, validate=True)["model"]
    parser = DependencyParser(doc.vocab, model, **config)
    # Unlabelled moves 0 and 1, plus move 4 labelled "ROOT".
    parser.moves.add_action(0, "")
    parser.moves.add_action(1, "")
    parser.moves.add_action(4, "ROOT")
    heads, deps = projectivize(heads, deps)
    for i, (head, dep) in enumerate(zip(heads, deps)):
        if head > i:
            parser.moves.add_action(2, dep)  # Left: head follows the token
        elif head < i:
            parser.moves.add_action(3, dep)  # Right: head precedes the token
    example = Example.from_dict(
        doc, {"words": words, "tags": tags, "heads": heads, "deps": deps}
    )
    oracle_actions = parser.moves.get_oracle_sequence(example)
    # A non-empty sentence must produce at least one oracle action.
    assert oracle_actions


def test_oracle_dev_sentence(vocab, arc_eager):
    """Check the full oracle transition sequence for a dev-set sentence.

    Each line of ``words_deps_heads`` is "<word> <dep label> <head word>". The
    head words are mapped to indices by their first occurrence in the word
    list.
    """
    words_deps_heads = """
        Rolls-Royce nn Inc.
        Motor nn Inc.
        Cars nn Inc.
        Inc. nsubj said
        said ROOT said
        it nsubj expects
        expects ccomp said
        its poss sales
        U.S. nn sales
        sales nsubj steady
        to aux steady
        remain cop steady
        steady xcomp expects
        at prep steady
        about quantmod 1,200
        1,200 num cars
        cars pobj at
        in prep steady
        1990 pobj in
        . punct said
    """
    expected_transitions = [
        "S",  # Shift "Rolls-Royce"
        "S",  # Shift "Motor"
        "S",  # Shift "Cars"
        "L-nn",  # Attach "Cars" to "Inc."
        "L-nn",  # Attach "Motor" to "Inc."
        "L-nn",  # Attach "Rolls-Royce" to "Inc."
        "S",  # Shift "Inc."
        "L-nsubj",  # Attach "Inc." to "said"
        "S",  # Shift "said"
        "S",  # Shift "it"
        "L-nsubj",  # Attach "it" to "expects"
        "R-ccomp",  # Attach "expects" to "said"
        "S",  # Shift "its"
        "S",  # Shift "U.S."
        "L-nn",  # Attach "U.S." to "sales"
        "L-poss",  # Attach "its" to "sales"
        "S",  # Shift "sales"
        "S",  # Shift "to"
        "S",  # Shift "remain"
        "L-cop",  # Attach "remain" to "steady"
        "L-aux",  # Attach "to" to "steady"
        "L-nsubj",  # Attach "sales" to "steady"
        "R-xcomp",  # Attach "steady" to "expects"
        "R-prep",  # Attach "at" to "steady"
        "S",  # Shift "about"
        "L-quantmod",  # Attach "about" to "1,200"
        "S",  # Shift "1,200"
        "L-num",  # Attach "1,200" to "cars"
        "R-pobj",  # Attach "cars" to "at"
        "D",  # Reduce "cars"
        "D",  # Reduce "at"
        "R-prep",  # Attach "in" to "steady"
        "R-pobj",  # Attach "1990" to "in"
        "D",  # Reduce "1990"
        "D",  # Reduce "in"
        "D",  # Reduce "steady"
        "D",  # Reduce "expects"
        "R-punct",  # Attach "." to "said"
        "D",  # Reduce "."
        "D",  # Reduce "said"
    ]

    gold_words, gold_deps, head_words = [], [], []
    for raw_line in words_deps_heads.strip().split("\n"):
        fields = raw_line.split()
        if not fields:
            continue
        word, dep, head = fields
        gold_words.append(word)
        gold_deps.append(dep)
        head_words.append(head)
    gold_heads = [gold_words.index(head_word) for head_word in head_words]
    for dep in gold_deps:
        arc_eager.add_action(2, dep)  # Left
        arc_eager.add_action(3, dep)  # Right
    gold_doc = Doc(Vocab(), words=gold_words)
    example = Example.from_dict(
        gold_doc, {"heads": gold_heads, "deps": gold_deps}
    )
    oracle_ids = arc_eager.get_oracle_sequence(example, _debug=False)
    oracle_names = [arc_eager.get_class_name(clas) for clas in oracle_ids]
    assert oracle_names == expected_transitions


def test_oracle_bad_tokenization(vocab, arc_eager):
    """Check that the oracle yields actions when predicted tokens differ from gold.

    The reference doc treats "[catalase]" as a single token, while the
    predicted doc splits it into "[", "catalase" and "]".
    """
    words_deps_heads = """
        [catalase] dep is
        : punct is
        that nsubj is
        is root is
        bad comp is
    """

    gold_words, gold_deps, head_words = [], [], []
    for raw_line in words_deps_heads.strip().split("\n"):
        fields = raw_line.split()
        if not fields:
            continue
        word, dep, head = fields
        gold_words.append(word)
        gold_deps.append(dep)
        head_words.append(head)
    gold_heads = [gold_words.index(head_word) for head_word in head_words]
    for dep in gold_deps:
        arc_eager.add_action(2, dep)  # Left
        arc_eager.add_action(3, dep)  # Right
    reference = Doc(Vocab(), words=gold_words, deps=gold_deps, heads=gold_heads)
    predicted_words = ["[", "catalase", "]", ":", "that", "is", "bad"]
    predicted = Doc(reference.vocab, words=predicted_words)
    example = Example(predicted=predicted, reference=reference)
    oracle_ids = arc_eager.get_oracle_sequence(example, _debug=False)
    oracle_names = [arc_eager.get_class_name(clas) for clas in oracle_ids]
    assert oracle_names