import warnings
from unittest import TestCase
import pytest
import srsly
from numpy import zeros
from spacy.kb import KnowledgeBase, Writer
from spacy.vectors import Vectors
from spacy.language import Language
from spacy.pipeline import TrainablePipe
from spacy.vocab import Vocab

from ..util import make_tempdir


def nlp():
    return Language()


def vectors():
    data = zeros((3, 1), dtype="f")
    keys = ["cat", "dog", "rat"]
    return Vectors(data=data, keys=keys)


def custom_pipe():
    # create dummy pipe partially implementing interface -- only want to test to_disk
    class SerializableDummy:
        def __init__(self, **cfg):
            if cfg:
                self.cfg = cfg
            else:
                self.cfg = None
            super(SerializableDummy, self).__init__()

        def to_bytes(self, exclude=tuple(), disable=None, **kwargs):
            return srsly.msgpack_dumps({"dummy": srsly.json_dumps(None)})

        def from_bytes(self, bytes_data, exclude):
            return self

        def to_disk(self, path, exclude=tuple(), **kwargs):
            pass

        def from_disk(self, path, exclude=tuple(), **kwargs):
            return self

    class MyPipe(TrainablePipe):
        def __init__(self, vocab, model=True, **cfg):
            if cfg:
                self.cfg = cfg
            else:
                self.cfg = None
            self.model = SerializableDummy()
            self.vocab = vocab

    return MyPipe(Vocab())


def tagger():
    nlp = Language()
    tagger = nlp.add_pipe("tagger")
    # need to add model for two reasons:
    # 1. no model leads to error in serialization,
    # 2. the affected line is the one for model serialization
    tagger.add_label("A")
    nlp.initialize()
    return tagger


def entity_linker():
    nlp = Language()

    def create_kb(vocab):
        kb = KnowledgeBase(vocab, entity_vector_length=1)
        kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
        return kb

    entity_linker = nlp.add_pipe("entity_linker")
    entity_linker.set_kb(create_kb)
    # need to add model for two reasons:
    # 1. no model leads to error in serialization,
    # 2. the affected line is the one for model serialization
    nlp.initialize()
    return entity_linker


objects_to_test = (
    [nlp(), vectors(), custom_pipe(), tagger(), entity_linker()],
    ["nlp", "vectors", "custom_pipe", "tagger", "entity_linker"],
)


def write_obj_and_catch_warnings(obj):
    with make_tempdir() as d:
        with warnings.catch_warnings(record=True) as warnings_list:
            warnings.filterwarnings("always", category=ResourceWarning)
            obj.to_disk(d)
            # in python3.5 it seems that deprecation warnings are not filtered by filterwarnings
            return list(filter(lambda x: isinstance(x, ResourceWarning), warnings_list))


@pytest.mark.parametrize("obj", objects_to_test[0], ids=objects_to_test[1])
def test_to_disk_resource_warning(obj):
    warnings_list = write_obj_and_catch_warnings(obj)
    assert len(warnings_list) == 0


def test_writer_with_path_py35():
    writer = None
    with make_tempdir() as d:
        path = d / "test"
        try:
            writer = Writer(path)
        except Exception as e:
            pytest.fail(str(e))
        finally:
            if writer:
                writer.close()


def test_save_and_load_knowledge_base():
    nlp = Language()
    kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
    with make_tempdir() as d:
        path = d / "kb"
        try:
            kb.to_disk(path)
        except Exception as e:
            pytest.fail(str(e))

        try:
            kb_loaded = KnowledgeBase(nlp.vocab, entity_vector_length=1)
            kb_loaded.from_disk(path)
        except Exception as e:
            pytest.fail(str(e))


class TestToDiskResourceWarningUnittest(TestCase):
    def test_resource_warning(self):
        scenarios = zip(*objects_to_test)

        for scenario in scenarios:
            with self.subTest(msg=scenario[1]):
                warnings_list = write_obj_and_catch_warnings(scenario[0])
                self.assertEqual(len(warnings_list), 0)