import warnings
from unittest import TestCase

import pytest
import srsly
from numpy import zeros

from spacy.kb.kb_in_memory import InMemoryLookupKB, Writer
from spacy.language import Language
from spacy.pipeline import TrainablePipe
from spacy.vectors import Vectors
from spacy.vocab import Vocab

from ..util import make_tempdir


def nlp():
    """Return a new ``Language`` instance created with default arguments."""
    pipeline = Language()
    return pipeline


def vectors():
    """Return a ``Vectors`` table with three keys mapped to 1-dimensional zero rows."""
    return Vectors(
        data=zeros((3, 1), dtype="f"),
        keys=["cat", "dog", "rat"],
    )


def custom_pipe():
    """Build a dummy ``TrainablePipe`` subclass instance for serialization tests.

    The pipe only partially implements the pipe interface: the goal is to
    exercise ``to_disk``, so the "model" is a stub whose (de)serialization
    hooks do no real work.

    Returns an instance of the locally defined ``MyPipe`` bound to a fresh
    ``Vocab``.
    """

    class SerializableDummy:
        """Stub model exposing no-op ``to_bytes``/``from_bytes``/``to_disk``/``from_disk``."""

        def __init__(self, **cfg):
            # An empty config is stored as None rather than as an empty dict.
            self.cfg = cfg or None
            super().__init__()

        def to_bytes(self, exclude=tuple(), disable=None, **kwargs):
            return srsly.msgpack_dumps({"dummy": srsly.json_dumps(None)})

        def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
            # `exclude` is optional, like in the sibling hooks, so that callers
            # passing only the payload do not fail with a TypeError.
            return self

        def to_disk(self, path, exclude=tuple(), **kwargs):
            pass

        def from_disk(self, path, exclude=tuple(), **kwargs):
            return self

    class MyPipe(TrainablePipe):
        """Pipe that sets ``cfg``, ``model`` and ``vocab`` directly in ``__init__``."""

        def __init__(self, vocab, model=True, **cfg):
            # `model` is accepted for signature compatibility but not used; the
            # stub model is always installed. TrainablePipe.__init__ is not called.
            self.cfg = cfg or None
            self.model = SerializableDummy()
            self.vocab = vocab

    return MyPipe(Vocab())


def tagger():
    """Return a "tagger" pipe with one label, taken from an initialized pipeline."""
    pipeline = Language()
    tagger_pipe = pipeline.add_pipe("tagger")
    # A model is required for two reasons:
    # 1. serialization errors out when no model is present,
    # 2. the code path under test is the one that serializes the model.
    tagger_pipe.add_label("A")
    pipeline.initialize()
    return tagger_pipe


def entity_linker():
    """Return an "entity_linker" pipe with a one-entity in-memory KB, after initialization."""

    def kb_factory(vocab):
        # Knowledge base with 1-dimensional entity vectors and a single entity.
        knowledge_base = InMemoryLookupKB(vocab, entity_vector_length=1)
        knowledge_base.add_entity("test", 0.0, zeros((1,), dtype="f"))
        return knowledge_base

    pipeline = Language()
    linker = pipeline.add_pipe("entity_linker")
    linker.set_kb(kb_factory)
    # A model is required for two reasons:
    # 1. serialization errors out when no model is present,
    # 2. the code path under test is the one that serializes the model.
    pipeline.initialize()
    return linker


# Pair of parallel lists: the objects whose ``to_disk`` is checked for
# ResourceWarnings, and the matching human-readable test ids (same order).
# The objects are built once, at import time.
objects_to_test = (
    [nlp(), vectors(), custom_pipe(), tagger(), entity_linker()],
    ["nlp", "vectors", "custom_pipe", "tagger", "entity_linker"],
)


def write_obj_and_catch_warnings(obj):
    """Write ``obj`` to a temporary directory and collect its ResourceWarnings.

    obj: any object exposing ``to_disk(path)``.

    Returns a list of the recorded ``warnings.WarningMessage`` entries whose
    category is ``ResourceWarning`` (or a subclass of it). The list is empty
    when no such warning was recorded while ``to_disk`` ran.
    """
    with make_tempdir() as d:
        with warnings.catch_warnings(record=True) as warnings_list:
            # Make sure ResourceWarnings are recorded whatever filters are active.
            warnings.filterwarnings("always", category=ResourceWarning)
            obj.to_disk(d)
            # Other categories (e.g. deprecation warnings) may be recorded as
            # well, so keep only ResourceWarnings. The recorded entries are
            # WarningMessage wrappers, not warning instances, so the check has
            # to be made on their `category` attribute; an isinstance() test on
            # the entry itself never matches.
            return [
                entry
                for entry in warnings_list
                if issubclass(entry.category, ResourceWarning)
            ]


@pytest.mark.parametrize("obj", objects_to_test[0], ids=objects_to_test[1])
def test_to_disk_resource_warning(obj):
    """Writing each test object to disk must not produce any ResourceWarning."""
    resource_warnings = write_obj_and_catch_warnings(obj)
    assert not resource_warnings


def test_writer_with_path_py35():
    """A ``Writer`` can be constructed from a path object inside a temp directory."""
    with make_tempdir() as tmp_dir:
        target = tmp_dir / "test"
        handle = None
        try:
            handle = Writer(target)
        except Exception as err:
            # Report any construction error as a test failure.
            pytest.fail(str(err))
        finally:
            # Close the writer if it was created, whether or not the test failed.
            if handle:
                handle.close()


def test_save_and_load_knowledge_base():
    """An ``InMemoryLookupKB`` can be written to disk and read back without errors."""
    pipeline = Language()
    original_kb = InMemoryLookupKB(pipeline.vocab, entity_vector_length=1)
    with make_tempdir() as tmp_dir:
        kb_path = tmp_dir / "kb"
        # Any exception raised while saving or loading is turned into a test
        # failure carrying the exception's message.
        try:
            original_kb.to_disk(kb_path)
            restored_kb = InMemoryLookupKB(pipeline.vocab, entity_vector_length=1)
            restored_kb.from_disk(kb_path)
        except Exception as err:
            pytest.fail(str(err))


class TestToDiskResourceWarningUnittest(TestCase):
    """unittest-style variant of the ``to_disk`` ResourceWarning check."""

    def test_resource_warning(self):
        """Run the check for every object in ``objects_to_test`` as a sub-test."""
        # Pair each object with its label; the label names the sub-test.
        for candidate, label in zip(*objects_to_test):
            with self.subTest(msg=label):
                caught = write_obj_and_catch_warnings(candidate)
                self.assertEqual(len(caught), 0)