mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-22 18:12:00 +03:00
Make empty_kb() configurable.
This commit is contained in:
parent
1e8bac99f3
commit
705c4b976e
|
@ -89,6 +89,14 @@ def load_kb(
|
|||
return kb_from_file
|
||||
|
||||
|
||||
@registry.misc("spacy.EmptyKB.v2")
|
||||
def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]:
|
||||
def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
|
||||
return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length)
|
||||
|
||||
return empty_kb_factory
|
||||
|
||||
|
||||
@registry.misc("spacy.EmptyKB.v1")
|
||||
def empty_kb(
|
||||
entity_vector_length: int,
|
||||
|
|
|
@ -54,6 +54,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
|||
"entity_vector_length": 64,
|
||||
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
|
||||
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
|
||||
"generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
|
||||
"overwrite": True,
|
||||
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
|
||||
"use_gold_ents": True,
|
||||
|
@ -80,6 +81,7 @@ def make_entity_linker(
|
|||
get_candidates_batch: Callable[
|
||||
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
||||
],
|
||||
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
use_gold_ents: bool,
|
||||
|
@ -101,6 +103,7 @@ def make_entity_linker(
|
|||
get_candidates_batch (
|
||||
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
|
||||
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
||||
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||
component must provide entity annotations.
|
||||
|
@ -135,6 +138,7 @@ def make_entity_linker(
|
|||
entity_vector_length=entity_vector_length,
|
||||
get_candidates=get_candidates,
|
||||
get_candidates_batch=get_candidates_batch,
|
||||
generate_empty_kb=generate_empty_kb,
|
||||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
use_gold_ents=use_gold_ents,
|
||||
|
@ -175,6 +179,7 @@ class EntityLinker(TrainablePipe):
|
|||
get_candidates_batch: Callable[
|
||||
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
||||
],
|
||||
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||
overwrite: bool = BACKWARD_OVERWRITE,
|
||||
scorer: Optional[Callable] = entity_linker_score,
|
||||
use_gold_ents: bool,
|
||||
|
@ -198,6 +203,7 @@ class EntityLinker(TrainablePipe):
|
|||
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]],
|
||||
Iterable[Candidate]]
|
||||
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
||||
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||
component must provide entity annotations.
|
||||
|
@ -220,6 +226,7 @@ class EntityLinker(TrainablePipe):
|
|||
self.model = model
|
||||
self.name = name
|
||||
self.labels_discard = list(labels_discard)
|
||||
# how many neighbour sentences to take into account
|
||||
self.n_sents = n_sents
|
||||
self.incl_prior = incl_prior
|
||||
self.incl_context = incl_context
|
||||
|
@ -227,9 +234,7 @@ class EntityLinker(TrainablePipe):
|
|||
self.get_candidates_batch = get_candidates_batch
|
||||
self.cfg: Dict[str, Any] = {"overwrite": overwrite}
|
||||
self.distance = CosineDistance(normalize=False)
|
||||
# how many neighbour sentences to take into account
|
||||
# create an empty KB by default
|
||||
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
||||
self.kb = generate_empty_kb(self.vocab, entity_vector_length)
|
||||
self.scorer = scorer
|
||||
self.use_gold_ents = use_gold_ents
|
||||
self.candidates_batch_size = candidates_batch_size
|
||||
|
|
|
@ -86,12 +86,19 @@ def test_serialize_subclassed_kb():
|
|||
[nlp]
|
||||
lang = "en"
|
||||
pipeline = ["entity_linker"]
|
||||
|
||||
[default_values]
|
||||
custom_field = 666
|
||||
|
||||
[components]
|
||||
|
||||
[components.entity_linker]
|
||||
factory = "entity_linker"
|
||||
|
||||
|
||||
[components.entity_linker.generate_empty_kb]
|
||||
@misc = "spacy.CustomEmptyKB.v1"
|
||||
custom_field = ${default_values.custom_field}
|
||||
|
||||
[initialize]
|
||||
|
||||
[initialize.components]
|
||||
|
@ -101,7 +108,7 @@ def test_serialize_subclassed_kb():
|
|||
[initialize.components.entity_linker.kb_loader]
|
||||
@misc = "spacy.CustomKB.v1"
|
||||
entity_vector_length = 342
|
||||
custom_field = 666
|
||||
custom_field = ${default_values.custom_field}
|
||||
"""
|
||||
|
||||
class SubInMemoryLookupKB(InMemoryLookupKB):
|
||||
|
@ -109,6 +116,17 @@ def test_serialize_subclassed_kb():
|
|||
super().__init__(vocab, entity_vector_length)
|
||||
self.custom_field = custom_field
|
||||
|
||||
@registry.misc("spacy.CustomEmptyKB.v1")
|
||||
def empty_custom_kb(custom_field: int) -> Callable[[Vocab, int], SubInMemoryLookupKB]:
|
||||
def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
|
||||
return SubInMemoryLookupKB(
|
||||
vocab=vocab,
|
||||
entity_vector_length=entity_vector_length,
|
||||
custom_field=custom_field
|
||||
)
|
||||
|
||||
return empty_kb_factory
|
||||
|
||||
@registry.misc("spacy.CustomKB.v1")
|
||||
def custom_kb(
|
||||
entity_vector_length: int, custom_field: int
|
||||
|
@ -139,6 +157,6 @@ def test_serialize_subclassed_kb():
|
|||
nlp2 = util.load_model_from_path(tmp_dir)
|
||||
entity_linker2 = nlp2.get_pipe("entity_linker")
|
||||
# After IO, the KB is the standard one
|
||||
assert type(entity_linker2.kb) == InMemoryLookupKB
|
||||
assert type(entity_linker2.kb) == SubInMemoryLookupKB
|
||||
assert entity_linker2.kb.entity_vector_length == 342
|
||||
assert not hasattr(entity_linker2.kb, "custom_field")
|
||||
assert hasattr(entity_linker2.kb, "custom_field")
|
||||
|
|
Loading…
Reference in New Issue
Block a user