From efedccea8da3a71b2383e53feba552628d3ad770 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 7 Oct 2020 15:29:52 +0200 Subject: [PATCH] fix tests --- spacy/tests/regression/test_issue5230.py | 4 ++-- spacy/tests/serialize/test_serialize_kb.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/spacy/tests/regression/test_issue5230.py b/spacy/tests/regression/test_issue5230.py index 5e320996a..aa4cc9be1 100644 --- a/spacy/tests/regression/test_issue5230.py +++ b/spacy/tests/regression/test_issue5230.py @@ -80,8 +80,8 @@ def entity_linker(): return create_kb - config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}} - entity_linker = nlp.add_pipe("entity_linker", config=config) + init_config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}} + entity_linker = nlp.add_pipe("entity_linker", init_config=init_config) # need to add model for two reasons: # 1. no model leads to error in serialization, # 2. the affected line is the one for model serialization diff --git a/spacy/tests/serialize/test_serialize_kb.py b/spacy/tests/serialize/test_serialize_kb.py index 63736418b..84e7c8ec2 100644 --- a/spacy/tests/serialize/test_serialize_kb.py +++ b/spacy/tests/serialize/test_serialize_kb.py @@ -6,6 +6,7 @@ from spacy.util import ensure_path, registry from spacy.kb import KnowledgeBase from ..util import make_tempdir +from numpy import zeros def test_serialize_kb_disk(en_vocab): @@ -90,11 +91,13 @@ def test_serialize_subclassed_kb(): entity_vector_length: int, custom_field: int ) -> Callable[["Vocab"], KnowledgeBase]: def custom_kb_factory(vocab): - return SubKnowledgeBase( + kb = SubKnowledgeBase( vocab=vocab, entity_vector_length=entity_vector_length, custom_field=custom_field, ) + kb.add_entity("random_entity", 0.0, zeros(entity_vector_length)) + return kb return custom_kb_factory @@ -106,7 +109,8 @@ def test_serialize_subclassed_kb(): "custom_field": 666, } } - entity_linker = nlp.add_pipe("entity_linker", config=config) + entity_linker = nlp.add_pipe("entity_linker", init_config=config) + nlp.initialize() assert type(entity_linker.kb) == SubKnowledgeBase assert entity_linker.kb.entity_vector_length == 342 assert entity_linker.kb.custom_field == 666 @@ -116,6 +120,7 @@ def test_serialize_subclassed_kb(): nlp.to_disk(tmp_dir) nlp2 = util.load_model_from_path(tmp_dir) entity_linker2 = nlp2.get_pipe("entity_linker") - assert type(entity_linker2.kb) == SubKnowledgeBase + # After IO, the KB is the standard one + assert type(entity_linker2.kb) == KnowledgeBase assert entity_linker2.kb.entity_vector_length == 342 - assert entity_linker2.kb.custom_field == 666 + assert not hasattr(entity_linker2.kb, "custom_field")