spaCy/spacy/tests/regression/test_issue5230.py
Adriane Boyd 9130094199
Prevent Tagger model init with 0 labels (#5984)
* Prevent Tagger model init with 0 labels

Raise an error before trying to initialize a tagger model with 0 labels.

* Add dummy tagger label for test

* Remove tagless tagger model initialization

* Fix error number after merge

* Add dummy tagger label to test

* Fix formatting

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
2020-08-31 21:24:33 +02:00

151 lines
4.4 KiB
Python

from typing import Callable
import warnings
from unittest import TestCase
import pytest
import srsly
from numpy import zeros
from spacy.kb import KnowledgeBase, Writer
from spacy.vectors import Vectors
from spacy.language import Language
from spacy.pipeline import Pipe
from spacy.util import registry
from ..util import make_tempdir
def nlp():
    """Build a ``Language`` with default arguments and return it."""
    default_pipeline = Language()
    return default_pipeline
def vectors():
    """Return a ``Vectors`` table with three keys, each mapped to a 1-d zero row."""
    words = ["cat", "dog", "rat"]
    # One all-zero float32 row of width 1 per key.
    table = zeros((3, 1), dtype="f")
    return Vectors(data=table, keys=words)
def custom_pipe():
    """Return a ``Pipe`` subclass instance backed by serializable stubs.

    The dummy pipe only partially implements the pipe interface; it exists
    solely so that ``to_disk`` can be exercised.
    """

    class SerializableDummy:
        """Stub exposing the four (de)serialization hooks."""

        def __init__(self, **cfg):
            # An empty config is stored as None rather than as an empty dict.
            self.cfg = cfg or None
            super().__init__()

        def to_bytes(self, exclude=(), disable=None, **kwargs):
            payload = {"dummy": srsly.json_dumps(None)}
            return srsly.msgpack_dumps(payload)

        def from_bytes(self, bytes_data, exclude):
            return self

        def to_disk(self, path, exclude=(), **kwargs):
            # Intentionally writes nothing.
            return None

        def from_disk(self, path, exclude=(), **kwargs):
            return self

    class MyPipe(Pipe):
        """Pipe whose ``model`` and ``vocab`` are both ``SerializableDummy``."""

        def __init__(self, vocab, model=True, **cfg):
            # The ``vocab`` and ``model`` arguments are ignored; stubs are
            # installed in their place.
            self.cfg = cfg if cfg else None
            self.model = SerializableDummy()
            self.vocab = SerializableDummy()

    return MyPipe(None)
def tagger():
    """Return a "tagger" component with one label, after ``begin_training``."""
    pipeline = Language()
    component = pipeline.add_pipe("tagger")
    # A model has to be present for two reasons:
    # 1. serialization errors out when there is no model,
    # 2. the line under test is the one that serializes the model.
    # The dummy label "A" is added before the model is set up.
    component.add_label("A")
    component.begin_training(lambda: [], pipeline=pipeline.pipeline)
    return component
def entity_linker():
    """Return an "entity_linker" component wired to a one-entity knowledge base.

    A KB loader is registered under the asset name ``TestIssue5230KB.v1`` and
    referenced from the component config; ``begin_training`` is then called
    on the component.
    """
    pipeline = Language()

    @registry.assets.register("TestIssue5230KB.v1")
    def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
        def build_kb(vocab):
            # Single entity "test" with frequency 0.0 and a zero vector.
            knowledge_base = KnowledgeBase(vocab, entity_vector_length=1)
            knowledge_base.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
            return knowledge_base

        return build_kb

    linker_config = {"kb_loader": {"@assets": "TestIssue5230KB.v1"}}
    component = pipeline.add_pipe("entity_linker", config=linker_config)
    # A model has to be present for two reasons:
    # 1. serialization errors out when there is no model,
    # 2. the line under test is the one that serializes the model.
    component.begin_training(lambda: [], pipeline=pipeline.pipeline)
    return component
# Pair of parallel lists: the objects whose ``to_disk`` is exercised, and the
# ids used to label each case. Both are derived from the same factory tuple so
# they cannot drift out of order.
_OBJECT_FACTORIES = (nlp, vectors, custom_pipe, tagger, entity_linker)
objects_to_test = (
    [factory() for factory in _OBJECT_FACTORIES],
    [factory.__name__ for factory in _OBJECT_FACTORIES],
)
def write_obj_and_catch_warnings(obj):
    """Write ``obj`` to a temporary directory and collect ResourceWarnings.

    Args:
        obj: Any object exposing ``to_disk(path)``.

    Returns:
        A list of the recorded ``warnings.WarningMessage`` entries whose
        category is ``ResourceWarning`` (or a subclass of it). The list is
        empty when ``to_disk`` emitted no such warning.
    """
    with make_tempdir() as d:
        with warnings.catch_warnings(record=True) as warnings_list:
            # Make sure every ResourceWarning is recorded rather than
            # suppressed by an existing filter.
            warnings.filterwarnings("always", category=ResourceWarning)
            obj.to_disk(d)
            # ``catch_warnings(record=True)`` records ``WarningMessage``
            # objects, not Warning instances, so ``isinstance(x,
            # ResourceWarning)`` never matches. Check ``category`` instead.
            # Other categories (e.g. deprecation warnings, which on
            # python3.5 seemed not to be filtered by filterwarnings) are
            # dropped here.
            return [
                record
                for record in warnings_list
                if issubclass(record.category, ResourceWarning)
            ]
@pytest.mark.parametrize("obj", objects_to_test[0], ids=objects_to_test[1])
def test_to_disk_resource_warning(obj):
    """Writing ``obj`` to disk must not produce any ResourceWarning."""
    caught = write_obj_and_catch_warnings(obj)
    assert not caught
def test_writer_with_path_py35():
    """Constructing ``Writer`` on a path inside the temp dir must not raise."""
    handle = None
    with make_tempdir() as tmp_dir:
        target = tmp_dir / "test"
        try:
            handle = Writer(target)
        except Exception as err:
            # Report the failure through pytest, using the error's message.
            pytest.fail(str(err))
        finally:
            # Close the writer only if construction succeeded.
            if handle:
                handle.close()
def test_save_and_load_knowledge_base():
    """A ``KnowledgeBase`` can be written to disk and read back without errors."""
    pipeline = Language()
    source_kb = KnowledgeBase(pipeline.vocab, entity_vector_length=1)
    with make_tempdir() as tmp_dir:
        kb_path = tmp_dir / "kb"

        # Step 1: saving must not raise.
        try:
            source_kb.to_disk(kb_path)
        except Exception as err:
            pytest.fail(str(err))

        # Step 2: loading into a fresh KB with the same vocab must not raise.
        try:
            restored_kb = KnowledgeBase(pipeline.vocab, entity_vector_length=1)
            restored_kb.from_disk(kb_path)
        except Exception as err:
            pytest.fail(str(err))
class TestToDiskResourceWarningUnittest(TestCase):
    """unittest-style variant of the ``to_disk`` ResourceWarning check."""

    def test_resource_warning(self):
        # ``objects_to_test`` holds (objects, ids); zip pairs each object
        # with its id so every case runs as a separately labelled sub-test.
        for candidate, label in zip(*objects_to_test):
            with self.subTest(msg=label):
                caught = write_obj_and_catch_warnings(candidate)
                self.assertEqual(len(caught), 0)