fix tests

This commit is contained in:
svlandeg 2020-10-07 15:29:52 +02:00
parent 6b8bdb2d39
commit efedccea8d
2 changed files with 11 additions and 6 deletions

View File

@ -80,8 +80,8 @@ def entity_linker():
return create_kb
config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}}
entity_linker = nlp.add_pipe("entity_linker", config=config)
init_config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}}
entity_linker = nlp.add_pipe("entity_linker", init_config=init_config)
# need to add model for two reasons:
# 1. no model leads to error in serialization,
# 2. the affected line is the one for model serialization

View File

@ -6,6 +6,7 @@ from spacy.util import ensure_path, registry
from spacy.kb import KnowledgeBase
from ..util import make_tempdir
from numpy import zeros
def test_serialize_kb_disk(en_vocab):
@ -90,11 +91,13 @@ def test_serialize_subclassed_kb():
entity_vector_length: int, custom_field: int
) -> Callable[["Vocab"], KnowledgeBase]:
def custom_kb_factory(vocab):
return SubKnowledgeBase(
kb = SubKnowledgeBase(
vocab=vocab,
entity_vector_length=entity_vector_length,
custom_field=custom_field,
)
kb.add_entity("random_entity", 0.0, zeros(entity_vector_length))
return kb
return custom_kb_factory
@ -106,7 +109,8 @@ def test_serialize_subclassed_kb():
"custom_field": 666,
}
}
entity_linker = nlp.add_pipe("entity_linker", config=config)
entity_linker = nlp.add_pipe("entity_linker", init_config=config)
nlp.initialize()
assert type(entity_linker.kb) == SubKnowledgeBase
assert entity_linker.kb.entity_vector_length == 342
assert entity_linker.kb.custom_field == 666
@ -116,6 +120,7 @@ def test_serialize_subclassed_kb():
nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir)
entity_linker2 = nlp2.get_pipe("entity_linker")
assert type(entity_linker2.kb) == SubKnowledgeBase
# After IO, the KB is the standard one
assert type(entity_linker2.kb) == KnowledgeBase
assert entity_linker2.kb.entity_vector_length == 342
assert entity_linker2.kb.custom_field == 666
assert not hasattr(entity_linker2.kb, "custom_field")