import warnings
from unittest import TestCase

import pytest
import srsly
from numpy import zeros
from spacy.kb import KnowledgeBase, Writer
from spacy.vectors import Vectors
from spacy.language import Language
from spacy.pipeline import TrainablePipe
from spacy.vocab import Vocab

from ..util import make_tempdir


def nlp():
    """Return a blank ``Language`` object."""
    return Language()


def vectors():
    """Return a ``Vectors`` table with three keys mapped to 1-dimensional zero vectors."""
    data = zeros((3, 1), dtype="f")
    keys = ["cat", "dog", "rat"]
    return Vectors(data=data, keys=keys)


def custom_pipe():
    """Return a dummy ``TrainablePipe`` whose model is a no-op serializable stub."""

    # create dummy pipe partially implementing interface -- only want to test to_disk
    class SerializableDummy:
        def __init__(self, **cfg):
            # An empty config is stored as None rather than {}.
            self.cfg = cfg or None
            super().__init__()

        def to_bytes(self, exclude=tuple(), disable=None, **kwargs):
            return srsly.msgpack_dumps({"dummy": srsly.json_dumps(None)})

        def from_bytes(self, bytes_data, exclude=tuple()):
            return self

        def to_disk(self, path, exclude=tuple(), **kwargs):
            pass

        def from_disk(self, path, exclude=tuple(), **kwargs):
            return self

    class MyPipe(TrainablePipe):
        def __init__(self, vocab, model=True, **cfg):
            # Deliberately does not call TrainablePipe.__init__: cfg, model and
            # vocab are set directly. The ``model`` argument is accepted but
            # ignored; the stub model is always used.
            self.cfg = cfg or None
            self.model = SerializableDummy()
            self.vocab = vocab

    return MyPipe(Vocab())


def tagger():
    """Return an initialized tagger pipe with the single label "A"."""
    nlp = Language()
    tagger = nlp.add_pipe("tagger")
    # need to add model for two reasons:
    # 1. no model leads to error in serialization,
    # 2. the affected line is the one for model serialization
    tagger.add_label("A")
    nlp.initialize()
    return tagger


def entity_linker():
    """Return an initialized entity linker backed by a one-entity knowledge base."""
    nlp = Language()

    def create_kb(vocab):
        kb = KnowledgeBase(vocab, entity_vector_length=1)
        # One 1-D vector whose length equals entity_vector_length.
        kb.add_entity("test", 0.0, zeros((1,), dtype="f"))
        return kb

    entity_linker = nlp.add_pipe("entity_linker")
    entity_linker.set_kb(create_kb)
    # need to add model for two reasons:
    # 1. no model leads to error in serialization,
    # 2. the affected line is the one for model serialization
    nlp.initialize()
    return entity_linker


# Built once at import time: (objects to serialize, matching test ids).
objects_to_test = (
    [nlp(), vectors(), custom_pipe(), tagger(), entity_linker()],
    ["nlp", "vectors", "custom_pipe", "tagger", "entity_linker"],
)


def write_obj_and_catch_warnings(obj):
    """Call ``obj.to_disk`` on a temp dir and return the ResourceWarnings it raised.

    ``warnings.catch_warnings(record=True)`` records ``warnings.WarningMessage``
    objects rather than warning instances, so each record is matched on its
    ``category``. Checking ``isinstance(record, ResourceWarning)`` is always
    False and would make every caller's assertion pass vacuously.
    """
    with make_tempdir() as d:
        with warnings.catch_warnings(record=True) as warnings_list:
            # Force ResourceWarnings to be recorded even if the active filters
            # would otherwise ignore them (they are ignored by default).
            warnings.filterwarnings("always", category=ResourceWarning)
            obj.to_disk(d)
            return [
                w for w in warnings_list if issubclass(w.category, ResourceWarning)
            ]


@pytest.mark.parametrize("obj", objects_to_test[0], ids=objects_to_test[1])
def test_to_disk_resource_warning(obj):
    warnings_list = write_obj_and_catch_warnings(obj)
    assert len(warnings_list) == 0, [str(w.message) for w in warnings_list]


def test_writer_with_path_py35():
    """``Writer`` can be built from the path object yielded by make_tempdir and closed."""
    with make_tempdir() as d:
        # Let any exception propagate: pytest reports it with a full traceback.
        writer = Writer(d / "test")
        writer.close()


def test_save_and_load_knowledge_base():
    """An empty KnowledgeBase round-trips through to_disk/from_disk without errors."""
    nlp = Language()
    kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
    with make_tempdir() as d:
        path = d / "kb"
        kb.to_disk(path)
        kb_loaded = KnowledgeBase(nlp.vocab, entity_vector_length=1)
        kb_loaded.from_disk(path)


class TestToDiskResourceWarningUnittest(TestCase):
    """unittest-style variant of test_to_disk_resource_warning using subTests."""

    def test_resource_warning(self):
        for obj, name in zip(*objects_to_test):
            with self.subTest(msg=name):
                warnings_list = write_obj_and_catch_warnings(obj)
                self.assertEqual(len(warnings_list), 0)