mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
fix tests
This commit is contained in:
parent
6b8bdb2d39
commit
efedccea8d
|
@ -80,8 +80,8 @@ def entity_linker():
|
|||
|
||||
return create_kb
|
||||
|
||||
config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}}
|
||||
entity_linker = nlp.add_pipe("entity_linker", config=config)
|
||||
init_config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}}
|
||||
entity_linker = nlp.add_pipe("entity_linker", init_config=init_config)
|
||||
# need to add model for two reasons:
|
||||
# 1. no model leads to error in serialization,
|
||||
# 2. the affected line is the one for model serialization
|
||||
|
|
|
@ -6,6 +6,7 @@ from spacy.util import ensure_path, registry
|
|||
from spacy.kb import KnowledgeBase
|
||||
|
||||
from ..util import make_tempdir
|
||||
from numpy import zeros
|
||||
|
||||
|
||||
def test_serialize_kb_disk(en_vocab):
|
||||
|
@ -90,11 +91,13 @@ def test_serialize_subclassed_kb():
|
|||
entity_vector_length: int, custom_field: int
|
||||
) -> Callable[["Vocab"], KnowledgeBase]:
|
||||
def custom_kb_factory(vocab):
|
||||
return SubKnowledgeBase(
|
||||
kb = SubKnowledgeBase(
|
||||
vocab=vocab,
|
||||
entity_vector_length=entity_vector_length,
|
||||
custom_field=custom_field,
|
||||
)
|
||||
kb.add_entity("random_entity", 0.0, zeros(entity_vector_length))
|
||||
return kb
|
||||
|
||||
return custom_kb_factory
|
||||
|
||||
|
@ -106,7 +109,8 @@ def test_serialize_subclassed_kb():
|
|||
"custom_field": 666,
|
||||
}
|
||||
}
|
||||
entity_linker = nlp.add_pipe("entity_linker", config=config)
|
||||
entity_linker = nlp.add_pipe("entity_linker", init_config=config)
|
||||
nlp.initialize()
|
||||
assert type(entity_linker.kb) == SubKnowledgeBase
|
||||
assert entity_linker.kb.entity_vector_length == 342
|
||||
assert entity_linker.kb.custom_field == 666
|
||||
|
@ -116,6 +120,7 @@ def test_serialize_subclassed_kb():
|
|||
nlp.to_disk(tmp_dir)
|
||||
nlp2 = util.load_model_from_path(tmp_dir)
|
||||
entity_linker2 = nlp2.get_pipe("entity_linker")
|
||||
assert type(entity_linker2.kb) == SubKnowledgeBase
|
||||
# After IO, the KB is the standard one
|
||||
assert type(entity_linker2.kb) == KnowledgeBase
|
||||
assert entity_linker2.kb.entity_vector_length == 342
|
||||
assert entity_linker2.kb.custom_field == 666
|
||||
assert not hasattr(entity_linker2.kb, "custom_field")
|
||||
|
|
Loading…
Reference in New Issue
Block a user