issue5230: optimized unit test a bit

This commit is contained in:
Leander Fiedler 2020-04-06 20:51:12 +02:00 committed by lfiedler
parent 71cc903d65
commit cde96f6c64

View File

@ -1,41 +1,28 @@
# coding: utf8
import warnings
import numpy
import pytest
import srsly
from numpy import zeros
from spacy.kb import KnowledgeBase
from spacy.vectors import Vectors
from spacy.language import Language
from spacy.pipeline import Pipe
from spacy.tests.util import make_tempdir
def test_language_to_disk_resource_warning():
nlp = Language()
with make_tempdir() as d:
with warnings.catch_warnings(record=True) as w:
# catch only warnings raised in spacy.language since there may be others from other components or pipelines
warnings.filterwarnings(
"always", module="spacy.language", category=ResourceWarning
)
nlp.to_disk(d)
assert len(w) == 0
def nlp():
return Language()
def test_vectors_to_disk_resource_warning():
data = numpy.zeros((3, 300), dtype="f")
def vectors():
data = zeros((3, 1), dtype="f")
keys = ["cat", "dog", "rat"]
vectors = Vectors(data=data, keys=keys)
with make_tempdir() as d:
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings("always", category=ResourceWarning)
vectors.to_disk(d)
assert len(w) == 0
return Vectors(data=data, keys=keys)
def test_custom_pipes_to_disk_resource_warning():
def custom_pipe():
# create dummy pipe partially implementing interface -- only want to test to_disk
class SerializableDummy(object):
def __init__(self, **cfg):
@ -66,15 +53,10 @@ def test_custom_pipes_to_disk_resource_warning():
self.model = SerializableDummy()
self.vocab = SerializableDummy()
pipe = MyPipe(None)
with make_tempdir() as d:
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings("always", category=ResourceWarning)
pipe.to_disk(d)
assert len(w) == 0
return MyPipe(None)
def test_tagger_to_disk_resource_warning():
def tagger():
nlp = Language()
nlp.add_pipe(nlp.create_pipe("tagger"))
tagger = nlp.get_pipe("tagger")
@ -82,15 +64,10 @@ def test_tagger_to_disk_resource_warning():
# 1. no model leads to error in serialization,
# 2. the affected line is the one for model serialization
tagger.begin_training(pipeline=nlp.pipeline)
with make_tempdir() as d:
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings("always", category=ResourceWarning)
tagger.to_disk(d)
assert len(w) == 0
return tagger
def test_entity_linker_to_disk_resource_warning():
def entity_linker():
nlp = Language()
nlp.add_pipe(nlp.create_pipe("entity_linker"))
entity_linker = nlp.get_pipe("entity_linker")
@ -100,9 +77,17 @@ def test_entity_linker_to_disk_resource_warning():
kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
entity_linker.set_kb(kb)
entity_linker.begin_training(pipeline=nlp.pipeline)
return entity_linker
@pytest.mark.parametrize(
"obj",
[nlp(), vectors(), custom_pipe(), tagger(), entity_linker()],
ids=["nlp", "vectors", "custom_pipe", "tagger", "entity_linker"],
)
def test_to_disk_resource_warning(obj):
with make_tempdir() as d:
with warnings.catch_warnings(record=True) as w:
with warnings.catch_warnings(record=True) as warnings_list:
warnings.filterwarnings("always", category=ResourceWarning)
entity_linker.to_disk(d)
assert len(w) == 0
obj.to_disk(d)
assert len(warnings_list) == 0