mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
fix tests
This commit is contained in:
parent
6b8bdb2d39
commit
efedccea8d
|
@ -80,8 +80,8 @@ def entity_linker():
|
||||||
|
|
||||||
return create_kb
|
return create_kb
|
||||||
|
|
||||||
config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}}
|
init_config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}}
|
||||||
entity_linker = nlp.add_pipe("entity_linker", config=config)
|
entity_linker = nlp.add_pipe("entity_linker", init_config=init_config)
|
||||||
# need to add model for two reasons:
|
# need to add model for two reasons:
|
||||||
# 1. no model leads to error in serialization,
|
# 1. no model leads to error in serialization,
|
||||||
# 2. the affected line is the one for model serialization
|
# 2. the affected line is the one for model serialization
|
||||||
|
|
|
@ -6,6 +6,7 @@ from spacy.util import ensure_path, registry
|
||||||
from spacy.kb import KnowledgeBase
|
from spacy.kb import KnowledgeBase
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
from numpy import zeros
|
||||||
|
|
||||||
|
|
||||||
def test_serialize_kb_disk(en_vocab):
|
def test_serialize_kb_disk(en_vocab):
|
||||||
|
@ -90,11 +91,13 @@ def test_serialize_subclassed_kb():
|
||||||
entity_vector_length: int, custom_field: int
|
entity_vector_length: int, custom_field: int
|
||||||
) -> Callable[["Vocab"], KnowledgeBase]:
|
) -> Callable[["Vocab"], KnowledgeBase]:
|
||||||
def custom_kb_factory(vocab):
|
def custom_kb_factory(vocab):
|
||||||
return SubKnowledgeBase(
|
kb = SubKnowledgeBase(
|
||||||
vocab=vocab,
|
vocab=vocab,
|
||||||
entity_vector_length=entity_vector_length,
|
entity_vector_length=entity_vector_length,
|
||||||
custom_field=custom_field,
|
custom_field=custom_field,
|
||||||
)
|
)
|
||||||
|
kb.add_entity("random_entity", 0.0, zeros(entity_vector_length))
|
||||||
|
return kb
|
||||||
|
|
||||||
return custom_kb_factory
|
return custom_kb_factory
|
||||||
|
|
||||||
|
@ -106,7 +109,8 @@ def test_serialize_subclassed_kb():
|
||||||
"custom_field": 666,
|
"custom_field": 666,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
entity_linker = nlp.add_pipe("entity_linker", config=config)
|
entity_linker = nlp.add_pipe("entity_linker", init_config=config)
|
||||||
|
nlp.initialize()
|
||||||
assert type(entity_linker.kb) == SubKnowledgeBase
|
assert type(entity_linker.kb) == SubKnowledgeBase
|
||||||
assert entity_linker.kb.entity_vector_length == 342
|
assert entity_linker.kb.entity_vector_length == 342
|
||||||
assert entity_linker.kb.custom_field == 666
|
assert entity_linker.kb.custom_field == 666
|
||||||
|
@ -116,6 +120,7 @@ def test_serialize_subclassed_kb():
|
||||||
nlp.to_disk(tmp_dir)
|
nlp.to_disk(tmp_dir)
|
||||||
nlp2 = util.load_model_from_path(tmp_dir)
|
nlp2 = util.load_model_from_path(tmp_dir)
|
||||||
entity_linker2 = nlp2.get_pipe("entity_linker")
|
entity_linker2 = nlp2.get_pipe("entity_linker")
|
||||||
assert type(entity_linker2.kb) == SubKnowledgeBase
|
# After IO, the KB is the standard one
|
||||||
|
assert type(entity_linker2.kb) == KnowledgeBase
|
||||||
assert entity_linker2.kb.entity_vector_length == 342
|
assert entity_linker2.kb.entity_vector_length == 342
|
||||||
assert entity_linker2.kb.custom_field == 666
|
assert not hasattr(entity_linker2.kb, "custom_field")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user