issue5230: optimized unit test a bit

This commit is contained in:
Leander Fiedler 2020-04-06 20:51:12 +02:00 committed by lfiedler
parent 71cc903d65
commit cde96f6c64

View File

@ -1,41 +1,28 @@
# coding: utf8 # coding: utf8
import warnings import warnings
import numpy
import pytest import pytest
import srsly import srsly
from numpy import zeros
from spacy.kb import KnowledgeBase from spacy.kb import KnowledgeBase
from spacy.vectors import Vectors from spacy.vectors import Vectors
from spacy.language import Language from spacy.language import Language
from spacy.pipeline import Pipe from spacy.pipeline import Pipe
from spacy.tests.util import make_tempdir from spacy.tests.util import make_tempdir
def test_language_to_disk_resource_warning(): def nlp():
nlp = Language() return Language()
with make_tempdir() as d:
with warnings.catch_warnings(record=True) as w:
# catch only warnings raised in spacy.language since there may be others from other components or pipelines
warnings.filterwarnings(
"always", module="spacy.language", category=ResourceWarning
)
nlp.to_disk(d)
assert len(w) == 0
def test_vectors_to_disk_resource_warning(): def vectors():
data = numpy.zeros((3, 300), dtype="f") data = zeros((3, 1), dtype="f")
keys = ["cat", "dog", "rat"] keys = ["cat", "dog", "rat"]
vectors = Vectors(data=data, keys=keys) return Vectors(data=data, keys=keys)
with make_tempdir() as d:
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings("always", category=ResourceWarning)
vectors.to_disk(d)
assert len(w) == 0
def test_custom_pipes_to_disk_resource_warning(): def custom_pipe():
# create dummy pipe partially implementing interface -- only want to test to_disk # create dummy pipe partially implementing interface -- only want to test to_disk
class SerializableDummy(object): class SerializableDummy(object):
def __init__(self, **cfg): def __init__(self, **cfg):
@ -66,15 +53,10 @@ def test_custom_pipes_to_disk_resource_warning():
self.model = SerializableDummy() self.model = SerializableDummy()
self.vocab = SerializableDummy() self.vocab = SerializableDummy()
pipe = MyPipe(None) return MyPipe(None)
with make_tempdir() as d:
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings("always", category=ResourceWarning)
pipe.to_disk(d)
assert len(w) == 0
def test_tagger_to_disk_resource_warning(): def tagger():
nlp = Language() nlp = Language()
nlp.add_pipe(nlp.create_pipe("tagger")) nlp.add_pipe(nlp.create_pipe("tagger"))
tagger = nlp.get_pipe("tagger") tagger = nlp.get_pipe("tagger")
@ -82,15 +64,10 @@ def test_tagger_to_disk_resource_warning():
# 1. no model leads to error in serialization, # 1. no model leads to error in serialization,
# 2. the affected line is the one for model serialization # 2. the affected line is the one for model serialization
tagger.begin_training(pipeline=nlp.pipeline) tagger.begin_training(pipeline=nlp.pipeline)
return tagger
with make_tempdir() as d:
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings("always", category=ResourceWarning)
tagger.to_disk(d)
assert len(w) == 0
def test_entity_linker_to_disk_resource_warning(): def entity_linker():
nlp = Language() nlp = Language()
nlp.add_pipe(nlp.create_pipe("entity_linker")) nlp.add_pipe(nlp.create_pipe("entity_linker"))
entity_linker = nlp.get_pipe("entity_linker") entity_linker = nlp.get_pipe("entity_linker")
@ -100,9 +77,17 @@ def test_entity_linker_to_disk_resource_warning():
kb = KnowledgeBase(nlp.vocab, entity_vector_length=1) kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
entity_linker.set_kb(kb) entity_linker.set_kb(kb)
entity_linker.begin_training(pipeline=nlp.pipeline) entity_linker.begin_training(pipeline=nlp.pipeline)
return entity_linker
@pytest.mark.parametrize(
"obj",
[nlp(), vectors(), custom_pipe(), tagger(), entity_linker()],
ids=["nlp", "vectors", "custom_pipe", "tagger", "entity_linker"],
)
def test_to_disk_resource_warning(obj):
with make_tempdir() as d: with make_tempdir() as d:
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as warnings_list:
warnings.filterwarnings("always", category=ResourceWarning) warnings.filterwarnings("always", category=ResourceWarning)
entity_linker.to_disk(d) obj.to_disk(d)
assert len(w) == 0 assert len(warnings_list) == 0