add test for vocab after serializing KB

This commit is contained in:
svlandeg 2020-10-10 20:59:48 +02:00
parent 539b0c10da
commit 68d79796c6
2 changed files with 28 additions and 2 deletions

View File

@ -243,7 +243,7 @@ cdef class TrainablePipe(Pipe):
def _validate_serialization_attrs(self): def _validate_serialization_attrs(self):
"""Check that the pipe implements the required attributes. If a subclass """Check that the pipe implements the required attributes. If a subclass
implements a custom __init__ method but doesn't set these attributes, implements a custom __init__ method but doesn't set these attributes,
the currently default to None, so we need to perform additonal checks. they currently default to None, so we need to perform additonal checks.
""" """
if not hasattr(self, "vocab") or self.vocab is None: if not hasattr(self, "vocab") or self.vocab is None:
raise ValueError(Errors.E899.format(name=util.get_object_name(self))) raise ValueError(Errors.E899.format(name=util.get_object_name(self)))

View File

@ -5,6 +5,7 @@ from spacy.kb import KnowledgeBase, get_candidates, Candidate
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy import util, registry from spacy import util, registry
from spacy.ml import load_kb
from spacy.scorer import Scorer from spacy.scorer import Scorer
from spacy.training import Example from spacy.training import Example
from spacy.lang.en import English from spacy.lang.en import English
@ -215,7 +216,7 @@ def test_el_pipe_configuration(nlp):
return kb return kb
# run an EL pipe without a trained context encoder, to check the candidate generation step only # run an EL pipe without a trained context encoder, to check the candidate generation step only
entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False},) entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False})
entity_linker.set_kb(create_kb) entity_linker.set_kb(create_kb)
# With the default get_candidates function, matching is case-sensitive # With the default get_candidates function, matching is case-sensitive
text = "Douglas and douglas are not the same." text = "Douglas and douglas are not the same."
@ -496,6 +497,31 @@ def test_overfitting_IO():
assert predictions == GOLD_entities assert predictions == GOLD_entities
def test_kb_serialization():
    """Check that a KB saved to disk can be used in a pipeline with a different vocab.

    A KB is built on one pipeline's vocab and written to a temporary
    directory. It is then set on the entity linker of a second, independently
    created pipeline. Afterwards the second vocab is expected to contain the
    KB's entity ID while still holding the string that was added to it before
    the KB was set.
    """
    vector_length = 3
    with make_tempdir() as tmp_dir:
        kb_dir = tmp_dir / "kb"
        nlp1 = English()
        # Precondition: the entity ID is new to this vocab, so the membership
        # check after building the KB is meaningful.
        assert "Q2146908" not in nlp1.vocab.strings
        mykb = KnowledgeBase(nlp1.vocab, entity_vector_length=vector_length)
        mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
        mykb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
        assert "Q2146908" in nlp1.vocab.strings
        mykb.to_disk(kb_dir)

        nlp2 = English()
        # Precondition: without this check the final "RandomWord" assertion
        # could pass even if the string had been in the fresh vocab all along.
        assert "RandomWord" not in nlp2.vocab.strings
        nlp2.vocab.strings.add("RandomWord")
        assert "RandomWord" in nlp2.vocab.strings
        assert "Q2146908" not in nlp2.vocab.strings

        # Create the Entity Linker component with the KB from file, and check the final vocab
        entity_linker = nlp2.add_pipe("entity_linker", last=True)
        entity_linker.set_kb(load_kb(kb_dir))
        # The second vocab must now know the entity ID and must not have lost
        # the string that was added before the KB was set.
        assert "Q2146908" in nlp2.vocab.strings
        assert "RandomWord" in nlp2.vocab.strings
def test_scorer_links(): def test_scorer_links():
train_examples = [] train_examples = []
nlp = English() nlp = English()