From 705c4b976e9f07b4a77fb5765bddc13d454033f8 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Thu, 23 Feb 2023 12:03:10 +0100 Subject: [PATCH] Make empty_kb() configurable. --- spacy/ml/models/entity_linker.py | 8 +++++++ spacy/pipeline/entity_linker.py | 11 ++++++--- spacy/tests/serialize/test_serialize_kb.py | 26 ++++++++++++++++++---- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index 299b6bb52..7332ca199 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -89,6 +89,14 @@ def load_kb( return kb_from_file +@registry.misc("spacy.EmptyKB.v2") +def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]: + def empty_kb_factory(vocab: Vocab, entity_vector_length: int): + return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length) + + return empty_kb_factory + + @registry.misc("spacy.EmptyKB.v1") def empty_kb( entity_vector_length: int, diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 62845287b..759c35f6f 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -54,6 +54,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] "entity_vector_length": 64, "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"}, + "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"}, "overwrite": True, "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, "use_gold_ents": True, @@ -80,6 +81,7 @@ def make_entity_linker( get_candidates_batch: Callable[ [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] ], + generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool, scorer: Optional[Callable], use_gold_ents: bool, @@ -101,6 +103,7 @@ def make_entity_linker( get_candidates_batch ( Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. + generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. scorer (Optional[Callable]): The scoring method. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another component must provide entity annotations. @@ -135,6 +138,7 @@ def make_entity_linker( entity_vector_length=entity_vector_length, get_candidates=get_candidates, get_candidates_batch=get_candidates_batch, + generate_empty_kb=generate_empty_kb, overwrite=overwrite, scorer=scorer, use_gold_ents=use_gold_ents, @@ -175,6 +179,7 @@ class EntityLinker(TrainablePipe): get_candidates_batch: Callable[ [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] ], + generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool = BACKWARD_OVERWRITE, scorer: Optional[Callable] = entity_linker_score, use_gold_ents: bool, @@ -198,6 +203,7 @@ class EntityLinker(TrainablePipe): Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. + generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another component must provide entity annotations. @@ -220,6 +226,7 @@ class EntityLinker(TrainablePipe): self.model = model self.name = name self.labels_discard = list(labels_discard) + # how many neighbour sentences to take into account self.n_sents = n_sents self.incl_prior = incl_prior self.incl_context = incl_context @@ -227,9 +234,7 @@ class EntityLinker(TrainablePipe): self.get_candidates_batch = get_candidates_batch self.cfg: Dict[str, Any] = {"overwrite": overwrite} self.distance = CosineDistance(normalize=False) - # how many neighbour sentences to take into account - # create an empty KB by default - self.kb = empty_kb(entity_vector_length)(self.vocab) + self.kb = generate_empty_kb(self.vocab, entity_vector_length) self.scorer = scorer self.use_gold_ents = use_gold_ents self.candidates_batch_size = candidates_batch_size diff --git a/spacy/tests/serialize/test_serialize_kb.py b/spacy/tests/serialize/test_serialize_kb.py index 8d3653ab1..7736ad6b1 100644 --- a/spacy/tests/serialize/test_serialize_kb.py +++ b/spacy/tests/serialize/test_serialize_kb.py @@ -86,12 +86,19 @@ def test_serialize_subclassed_kb(): [nlp] lang = "en" pipeline = ["entity_linker"] + + [default_values] + custom_field = 666 [components] [components.entity_linker] factory = "entity_linker" - + + [components.entity_linker.generate_empty_kb] + @misc = "spacy.CustomEmptyKB.v1" + custom_field = ${default_values.custom_field} + [initialize] [initialize.components] @@ -101,7 +108,7 @@ def test_serialize_subclassed_kb(): [initialize.components.entity_linker.kb_loader] @misc = "spacy.CustomKB.v1" entity_vector_length = 342 - custom_field = 666 + custom_field = ${default_values.custom_field} """ class SubInMemoryLookupKB(InMemoryLookupKB): @@ -109,6 +116,17 @@ def test_serialize_subclassed_kb(): super().__init__(vocab, entity_vector_length) self.custom_field = custom_field + @registry.misc("spacy.CustomEmptyKB.v1") + def empty_custom_kb(custom_field: int) -> Callable[[Vocab, int], SubInMemoryLookupKB]: + def empty_kb_factory(vocab: Vocab, entity_vector_length: int): + return SubInMemoryLookupKB( + vocab=vocab, + entity_vector_length=entity_vector_length, + custom_field=custom_field + ) + + return empty_kb_factory + @registry.misc("spacy.CustomKB.v1") def custom_kb( entity_vector_length: int, custom_field: int @@ -139,6 +157,6 @@ def test_serialize_subclassed_kb(): nlp2 = util.load_model_from_path(tmp_dir) entity_linker2 = nlp2.get_pipe("entity_linker") # After IO, the KB is the standard one - assert type(entity_linker2.kb) == InMemoryLookupKB + assert type(entity_linker2.kb) == SubInMemoryLookupKB assert entity_linker2.kb.entity_vector_length == 342 - assert not hasattr(entity_linker2.kb, "custom_field") + assert hasattr(entity_linker2.kb, "custom_field")