Merge pull request #6238 from svlandeg/feature/kb-vocab-test

This commit is contained in:
Ines Montani 2020-10-11 11:55:03 +02:00 committed by GitHub
commit 4430b4b809
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 2 deletions

View File

@@ -243,7 +243,7 @@ cdef class TrainablePipe(Pipe):
def _validate_serialization_attrs(self):
"""Check that the pipe implements the required attributes. If a subclass
implements a custom __init__ method but doesn't set these attributes,
the currently default to None, so we need to perform additonal checks.
they currently default to None, so we need to perform additonal checks.
"""
if not hasattr(self, "vocab") or self.vocab is None:
raise ValueError(Errors.E899.format(name=util.get_object_name(self)))

View File

@@ -5,6 +5,7 @@ from spacy.kb import KnowledgeBase, get_candidates, Candidate
from spacy.vocab import Vocab
from spacy import util, registry
from spacy.ml import load_kb
from spacy.scorer import Scorer
from spacy.training import Example
from spacy.lang.en import English
@@ -215,7 +216,7 @@ def test_el_pipe_configuration(nlp):
return kb
# run an EL pipe without a trained context encoder, to check the candidate generation step only
entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False},)
entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False})
entity_linker.set_kb(create_kb)
# With the default get_candidates function, matching is case-sensitive
text = "Douglas and douglas are not the same."
@@ -496,6 +497,32 @@ def test_overfitting_IO():
assert predictions == GOLD_entities
def test_kb_serialization():
    """A KB built under one pipeline's vocab should be loadable into a second
    pipeline with a different vocab, carrying its entity strings across."""
    entity_vector_length = 3
    with make_tempdir() as tmp_dir:
        kb_dir = tmp_dir / "kb"
        # Build and persist a KB under the first pipeline's vocab.
        source_nlp = English()
        assert "Q2146908" not in source_nlp.vocab.strings
        kb = KnowledgeBase(source_nlp.vocab, entity_vector_length=entity_vector_length)
        kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
        kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
        # Adding the entity registered its ID in the source vocab's strings.
        assert "Q2146908" in source_nlp.vocab.strings
        kb.to_disk(kb_dir)
        # A fresh pipeline knows neither a custom word nor the KB's entity.
        target_nlp = English()
        assert "RandomWord" not in target_nlp.vocab.strings
        target_nlp.vocab.strings.add("RandomWord")
        assert "RandomWord" in target_nlp.vocab.strings
        assert "Q2146908" not in target_nlp.vocab.strings
        # Attaching the KB loaded from disk pulls its strings into the new
        # vocab without dropping strings the target vocab already had.
        entity_linker = target_nlp.add_pipe("entity_linker", last=True)
        entity_linker.set_kb(load_kb(kb_dir))
        assert "Q2146908" in target_nlp.vocab.strings
        assert "RandomWord" in target_nlp.vocab.strings
def test_scorer_links():
train_examples = []
nlp = English()