spaCy/spacy/tests/regression/test_issue5230.py

94 lines
2.6 KiB
Python
Raw Normal View History

# coding: utf8
import warnings
import pytest
import srsly
2020-04-06 21:51:12 +03:00
from numpy import zeros
from spacy.kb import KnowledgeBase
from spacy.vectors import Vectors
2020-04-06 21:51:12 +03:00
from spacy.language import Language
from spacy.pipeline import Pipe
from spacy.tests.util import make_tempdir
2020-04-06 21:51:12 +03:00
def nlp():
    """Build and return a bare ``Language`` object with no components added."""
    blank_language = Language()
    return blank_language
2020-04-06 21:51:12 +03:00
def vectors():
    """Return a tiny ``Vectors`` table for serialization tests.

    Three keys ("cat", "dog", "rat") are each mapped to a single-column,
    all-zero row of dtype ``"f"``.
    """
    words = ["cat", "dog", "rat"]
    table = zeros((3, 1), dtype="f")
    return Vectors(keys=words, data=table)
2020-04-06 21:51:12 +03:00
def custom_pipe():
    """Return an instance of a minimal ``Pipe`` subclass backed by dummy objects.

    The pipe only partially implements the component interface: the goal is
    to exercise ``to_disk``, so its ``model`` and ``vocab`` are stand-ins whose
    (de)serialization methods are stubs.
    """

    class SerializableDummy(object):
        """Stand-in model/vocab whose ``to_bytes`` returns a fixed msgpack payload."""

        def __init__(self, **cfg):
            # Keep the keyword config, or None when no options were given.
            self.cfg = cfg if cfg else None
            super(SerializableDummy, self).__init__()

        def to_bytes(self, exclude=tuple(), disable=None, **kwargs):
            return srsly.msgpack_dumps({"dummy": srsly.json_dumps(None)})

        def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
            # ``exclude`` defaults like in the sibling methods, so a plain
            # ``from_bytes(data)`` call no longer raises TypeError.
            return self

        def to_disk(self, path, exclude=tuple(), **kwargs):
            pass

        def from_disk(self, path, exclude=tuple(), **kwargs):
            return self

    class MyPipe(Pipe):
        """``Pipe`` subclass whose model and vocab are ``SerializableDummy`` objects."""

        def __init__(self, vocab, model=True, **cfg):
            # ``vocab`` and ``model`` are accepted for signature compatibility
            # but ignored: both attributes are replaced by dummies below.
            # NOTE(review): ``Pipe.__init__`` is intentionally not called here,
            # presumably because the base initializer is not meant to be used
            # directly -- confirm against the Pipe base class.
            self.cfg = cfg if cfg else None
            self.model = SerializableDummy()
            self.vocab = SerializableDummy()

    return MyPipe(None)
2020-04-06 21:51:12 +03:00
def tagger():
    """Return a ``tagger`` component that has had ``begin_training`` called on it.

    The component is created inside a throwaway ``Language`` pipeline.
    """
    # Distinct local names avoid shadowing the module-level ``nlp``/``tagger``
    # factory functions.
    owner = Language()
    owner.add_pipe(owner.create_pipe("tagger"))
    component = owner.get_pipe("tagger")
    # A model has to be added for two reasons:
    # 1. serializing a component without a model raises an error, and
    # 2. the line this regression test targets is the model-serialization one.
    component.begin_training(pipeline=owner.pipeline)
    return component
2020-04-06 21:51:12 +03:00
def entity_linker():
    """Return an ``entity_linker`` component with a KB attached and training begun.

    The component lives in a throwaway ``Language`` pipeline and is given an
    empty ``KnowledgeBase`` with ``entity_vector_length=1`` before
    ``begin_training`` is called.
    """
    # Distinct local names avoid shadowing the module-level ``nlp``/``entity_linker``
    # factory functions.
    owner = Language()
    owner.add_pipe(owner.create_pipe("entity_linker"))
    component = owner.get_pipe("entity_linker")
    # A model has to be added for two reasons:
    # 1. serializing a component without a model raises an error, and
    # 2. the line this regression test targets is the model-serialization one.
    knowledge_base = KnowledgeBase(owner.vocab, entity_vector_length=1)
    component.set_kb(knowledge_base)
    component.begin_training(pipeline=owner.pipeline)
    return component
2020-04-06 21:51:12 +03:00
@pytest.mark.parametrize(
    "obj",
    [nlp(), vectors(), custom_pipe(), tagger(), entity_linker()],
    ids=["nlp", "vectors", "custom_pipe", "tagger", "entity_linker"],
)
def test_to_disk_resource_warning(obj):
    """Writing *obj* to disk must not emit any ``ResourceWarning``.

    Only ``ResourceWarning`` is checked. Other warning categories raised during
    serialization (e.g. deprecations surfaced by pytest's warning filters) are
    unrelated to this regression and must not fail the test.
    """
    with make_tempdir() as d:
        with warnings.catch_warnings(record=True) as warnings_list:
            # Force every ResourceWarning to be recorded rather than being
            # suppressed by the active filters or de-duplicated.
            warnings.filterwarnings("always", category=ResourceWarning)
            obj.to_disk(d)
        # Recorded entries are WarningMessage objects; match on their category.
        resource_warnings = [
            w for w in warnings_list if issubclass(w.category, ResourceWarning)
        ]
        assert not resource_warnings, [str(w.message) for w in resource_warnings]