mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-14 03:26:24 +03:00
add test for vocab after serializing KB
This commit is contained in:
parent
539b0c10da
commit
68d79796c6
|
@ -243,7 +243,7 @@ cdef class TrainablePipe(Pipe):
|
||||||
def _validate_serialization_attrs(self):
|
def _validate_serialization_attrs(self):
|
||||||
"""Check that the pipe implements the required attributes. If a subclass
|
"""Check that the pipe implements the required attributes. If a subclass
|
||||||
implements a custom __init__ method but doesn't set these attributes,
|
implements a custom __init__ method but doesn't set these attributes,
|
||||||
the currently default to None, so we need to perform additonal checks.
|
they currently default to None, so we need to perform additonal checks.
|
||||||
"""
|
"""
|
||||||
if not hasattr(self, "vocab") or self.vocab is None:
|
if not hasattr(self, "vocab") or self.vocab is None:
|
||||||
raise ValueError(Errors.E899.format(name=util.get_object_name(self)))
|
raise ValueError(Errors.E899.format(name=util.get_object_name(self)))
|
||||||
|
|
|
@ -5,6 +5,7 @@ from spacy.kb import KnowledgeBase, get_candidates, Candidate
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
|
||||||
from spacy import util, registry
|
from spacy import util, registry
|
||||||
|
from spacy.ml import load_kb
|
||||||
from spacy.scorer import Scorer
|
from spacy.scorer import Scorer
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
|
@ -215,7 +216,7 @@ def test_el_pipe_configuration(nlp):
|
||||||
return kb
|
return kb
|
||||||
|
|
||||||
# run an EL pipe without a trained context encoder, to check the candidate generation step only
|
# run an EL pipe without a trained context encoder, to check the candidate generation step only
|
||||||
entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False},)
|
entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False})
|
||||||
entity_linker.set_kb(create_kb)
|
entity_linker.set_kb(create_kb)
|
||||||
# With the default get_candidates function, matching is case-sensitive
|
# With the default get_candidates function, matching is case-sensitive
|
||||||
text = "Douglas and douglas are not the same."
|
text = "Douglas and douglas are not the same."
|
||||||
|
@ -496,6 +497,31 @@ def test_overfitting_IO():
|
||||||
assert predictions == GOLD_entities
|
assert predictions == GOLD_entities
|
||||||
|
|
||||||
|
|
||||||
|
def test_kb_serialization():
|
||||||
|
# Test that the KB can be used in a pipeline with a different vocab
|
||||||
|
vector_length = 3
|
||||||
|
with make_tempdir() as tmp_dir:
|
||||||
|
kb_dir = tmp_dir / "kb"
|
||||||
|
nlp1 = English()
|
||||||
|
assert "Q2146908" not in nlp1.vocab.strings
|
||||||
|
mykb = KnowledgeBase(nlp1.vocab, entity_vector_length=vector_length)
|
||||||
|
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
|
||||||
|
mykb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
|
||||||
|
assert "Q2146908" in nlp1.vocab.strings
|
||||||
|
mykb.to_disk(kb_dir)
|
||||||
|
|
||||||
|
nlp2 = English()
|
||||||
|
nlp2.vocab.strings.add("RandomWord")
|
||||||
|
assert "RandomWord" in nlp2.vocab.strings
|
||||||
|
assert "Q2146908" not in nlp2.vocab.strings
|
||||||
|
|
||||||
|
# Create the Entity Linker component with the KB from file, and check the final vocab
|
||||||
|
entity_linker = nlp2.add_pipe("entity_linker", last=True)
|
||||||
|
entity_linker.set_kb(load_kb(kb_dir))
|
||||||
|
assert "Q2146908" in nlp2.vocab.strings
|
||||||
|
assert "RandomWord" in nlp2.vocab.strings
|
||||||
|
|
||||||
|
|
||||||
def test_scorer_links():
|
def test_scorer_links():
|
||||||
train_examples = []
|
train_examples = []
|
||||||
nlp = English()
|
nlp = English()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user