# coding: utf8
from __future__ import unicode_literals

import pytest
from thinc.neural.optimizers import Adam
from thinc.neural.ops import NumpyOps
from spacy.attrs import NORM
from spacy.gold import GoldParse
from spacy.vocab import Vocab
from spacy.tokens import Doc
from spacy.pipeline import DependencyParser, EntityRecognizer
from spacy.util import fix_random_seed


@pytest.fixture
def vocab():
    """Provide a Vocab whose NORM getter hands back its input unchanged."""
    getters = {NORM: lambda text: text}
    return Vocab(lex_attr_getters=getters)


@pytest.fixture
def parser(vocab):
    """Provide a fresh DependencyParser built on the ``vocab`` fixture."""
    return DependencyParser(vocab)


def test_init_parser(parser):
    """Smoke test: requesting the ``parser`` fixture must not raise."""


def _train_parser(parser):
    """Run a short training session on ``parser`` and return it.

    Fixes the random seed, registers the "left" label, begins training and
    then applies five updates, each on a freshly built four-word example.
    """

    def make_example():
        # A new Doc/GoldParse pair is created for every update step.
        example_doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
        example_gold = GoldParse(
            example_doc,
            heads=[1, 1, 3, 3],
            deps=["left", "ROOT", "left", "ROOT"],
        )
        return example_doc, example_gold

    fix_random_seed(1)
    parser.add_label("left")
    parser.begin_training([], **parser.cfg)
    optimizer = Adam(NumpyOps(), 0.001)

    n_steps = 5
    for _ in range(n_steps):
        step_losses = {}
        doc, gold = make_example()
        parser.update([doc], [gold], sgd=optimizer, losses=step_losses)
    return parser


def test_add_label(parser):
    """Check that a label added after initial training can still be learned.

    The parser is first trained via ``_train_parser``, then the "right" label
    is added and ten more updates are applied. Afterwards the parser is run
    on a fresh document and must predict "right" for the first token and
    "left" for the third.
    """
    words = ["a", "b", "c", "d"]

    trained = _train_parser(parser)
    trained.add_label("right")
    optimizer = Adam(NumpyOps(), 0.001)

    for _ in range(10):
        step_losses = {}
        train_doc = Doc(trained.vocab, words=list(words))
        heads = [1, 1, 3, 3]
        deps = ["right", "ROOT", "left", "ROOT"]
        train_gold = GoldParse(train_doc, heads=heads, deps=deps)
        trained.update([train_doc], [train_gold], sgd=optimizer, losses=step_losses)

    parsed = trained(Doc(trained.vocab, words=list(words)))
    assert parsed[0].dep_ == "right"
    assert parsed[2].dep_ == "left"


def test_add_label_deserializes_correctly():
    """Check that added labels survive a to_bytes/from_bytes round trip.

    Labels "C", "B" and "A" are added in that order to one EntityRecognizer;
    a second one is loaded from its bytes. Both must report the same number
    of moves and the same class name at every move index.
    """
    original = EntityRecognizer(Vocab())
    for label in ("C", "B", "A"):
        original.add_label(label)
    original.begin_training([])

    payload = original.to_bytes()
    restored = EntityRecognizer(Vocab()).from_bytes(payload)

    assert original.moves.n_moves == restored.moves.n_moves
    move_id = 0
    while move_id < original.moves.n_moves:
        assert original.moves.get_class_name(move_id) == restored.moves.get_class_name(move_id)
        move_id += 1