mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	Make generation of empty KnowledgeBase instances configurable in EntityLinker (#12320)
				
					
				
			* Make empty_kb() configurable. * Format. * Update docs. * Be more specific in KB serialization test. * Update KB serialization tests. Update docs. * Remove doc update for batched candidate generation. * Fix serialization of subclassed KB in tests. * Format. * Update docstring. * Update docstring. * Switch from pickle to json for custom field serialization.
This commit is contained in:
		
							parent
							
								
									e325de3ff8
								
							
						
					
					
						commit
						6f1632b3e9
					
				|  | @ -89,6 +89,14 @@ def load_kb( | |||
|     return kb_from_file | ||||
| 
 | ||||
| 
 | ||||
| @registry.misc("spacy.EmptyKB.v2") | ||||
| def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]: | ||||
|     def empty_kb_factory(vocab: Vocab, entity_vector_length: int): | ||||
|         return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length) | ||||
| 
 | ||||
|     return empty_kb_factory | ||||
| 
 | ||||
| 
 | ||||
| @registry.misc("spacy.EmptyKB.v1") | ||||
| def empty_kb( | ||||
|     entity_vector_length: int, | ||||
|  |  | |||
|  | @ -54,6 +54,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] | |||
|         "entity_vector_length": 64, | ||||
|         "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, | ||||
|         "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"}, | ||||
|         "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"}, | ||||
|         "overwrite": True, | ||||
|         "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, | ||||
|         "use_gold_ents": True, | ||||
|  | @ -80,6 +81,7 @@ def make_entity_linker( | |||
|     get_candidates_batch: Callable[ | ||||
|         [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] | ||||
|     ], | ||||
|     generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], | ||||
|     overwrite: bool, | ||||
|     scorer: Optional[Callable], | ||||
|     use_gold_ents: bool, | ||||
|  | @ -101,6 +103,7 @@ def make_entity_linker( | |||
|     get_candidates_batch ( | ||||
|         Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] | ||||
|         ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. | ||||
|     generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. | ||||
|     scorer (Optional[Callable]): The scoring method. | ||||
|     use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another | ||||
|         component must provide entity annotations. | ||||
|  | @ -135,6 +138,7 @@ def make_entity_linker( | |||
|         entity_vector_length=entity_vector_length, | ||||
|         get_candidates=get_candidates, | ||||
|         get_candidates_batch=get_candidates_batch, | ||||
|         generate_empty_kb=generate_empty_kb, | ||||
|         overwrite=overwrite, | ||||
|         scorer=scorer, | ||||
|         use_gold_ents=use_gold_ents, | ||||
|  | @ -175,6 +179,7 @@ class EntityLinker(TrainablePipe): | |||
|         get_candidates_batch: Callable[ | ||||
|             [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] | ||||
|         ], | ||||
|         generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], | ||||
|         overwrite: bool = BACKWARD_OVERWRITE, | ||||
|         scorer: Optional[Callable] = entity_linker_score, | ||||
|         use_gold_ents: bool, | ||||
|  | @ -198,6 +203,7 @@ class EntityLinker(TrainablePipe): | |||
|             Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], | ||||
|             Iterable[Candidate]] | ||||
|             ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. | ||||
|         generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. | ||||
|         scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. | ||||
|         use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another | ||||
|             component must provide entity annotations. | ||||
|  | @ -220,6 +226,7 @@ class EntityLinker(TrainablePipe): | |||
|         self.model = model | ||||
|         self.name = name | ||||
|         self.labels_discard = list(labels_discard) | ||||
|         # how many neighbour sentences to take into account | ||||
|         self.n_sents = n_sents | ||||
|         self.incl_prior = incl_prior | ||||
|         self.incl_context = incl_context | ||||
|  | @ -227,9 +234,7 @@ class EntityLinker(TrainablePipe): | |||
|         self.get_candidates_batch = get_candidates_batch | ||||
|         self.cfg: Dict[str, Any] = {"overwrite": overwrite} | ||||
|         self.distance = CosineDistance(normalize=False) | ||||
|         # how many neighbour sentences to take into account | ||||
|         # create an empty KB by default | ||||
|         self.kb = empty_kb(entity_vector_length)(self.vocab) | ||||
|         self.kb = generate_empty_kb(self.vocab, entity_vector_length) | ||||
|         self.scorer = scorer | ||||
|         self.use_gold_ents = use_gold_ents | ||||
|         self.candidates_batch_size = candidates_batch_size | ||||
|  |  | |||
|  | @ -1,7 +1,10 @@ | |||
| from typing import Callable | ||||
| from pathlib import Path | ||||
| from typing import Callable, Iterable, Any, Dict | ||||
| 
 | ||||
| from spacy import util | ||||
| from spacy.util import ensure_path, registry, load_model_from_config | ||||
| import srsly | ||||
| 
 | ||||
| from spacy import util, Errors | ||||
| from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList | ||||
| from spacy.kb.kb_in_memory import InMemoryLookupKB | ||||
| from spacy.vocab import Vocab | ||||
| from thinc.api import Config | ||||
|  | @ -91,7 +94,10 @@ def test_serialize_subclassed_kb(): | |||
| 
 | ||||
|     [components.entity_linker] | ||||
|     factory = "entity_linker" | ||||
| 
 | ||||
|      | ||||
|     [components.entity_linker.generate_empty_kb] | ||||
|     @misc = "kb_test.CustomEmptyKB.v1" | ||||
|      | ||||
|     [initialize] | ||||
| 
 | ||||
|     [initialize.components] | ||||
|  | @ -99,7 +105,7 @@ def test_serialize_subclassed_kb(): | |||
|     [initialize.components.entity_linker] | ||||
| 
 | ||||
|     [initialize.components.entity_linker.kb_loader] | ||||
|     @misc = "spacy.CustomKB.v1" | ||||
|     @misc = "kb_test.CustomKB.v1" | ||||
|     entity_vector_length = 342 | ||||
|     custom_field = 666 | ||||
|     """ | ||||
|  | @ -109,10 +115,57 @@ def test_serialize_subclassed_kb(): | |||
|             super().__init__(vocab, entity_vector_length) | ||||
|             self.custom_field = custom_field | ||||
| 
 | ||||
|     @registry.misc("spacy.CustomKB.v1") | ||||
|         def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()): | ||||
|             """We overwrite InMemoryLookupKB.to_disk() to ensure that self.custom_field is stored as well.""" | ||||
|             path = ensure_path(path) | ||||
|             if not path.exists(): | ||||
|                 path.mkdir(parents=True) | ||||
|             if not path.is_dir(): | ||||
|                 raise ValueError(Errors.E928.format(loc=path)) | ||||
| 
 | ||||
|             def serialize_custom_fields(file_path: Path) -> None: | ||||
|                 srsly.write_json(file_path, {"custom_field": self.custom_field}) | ||||
| 
 | ||||
|             serialize = { | ||||
|                 "contents": lambda p: self.write_contents(p), | ||||
|                 "strings.json": lambda p: self.vocab.strings.to_disk(p), | ||||
|                 "custom_fields": lambda p: serialize_custom_fields(p), | ||||
|             } | ||||
|             util.to_disk(path, serialize, exclude) | ||||
| 
 | ||||
|         def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()): | ||||
|             """We overwrite InMemoryLookupKB.from_disk() to ensure that self.custom_field is loaded as well.""" | ||||
|             path = ensure_path(path) | ||||
|             if not path.exists(): | ||||
|                 raise ValueError(Errors.E929.format(loc=path)) | ||||
|             if not path.is_dir(): | ||||
|                 raise ValueError(Errors.E928.format(loc=path)) | ||||
| 
 | ||||
|             def deserialize_custom_fields(file_path: Path) -> None: | ||||
|                 self.custom_field = srsly.read_json(file_path)["custom_field"] | ||||
| 
 | ||||
|             deserialize: Dict[str, Callable[[Any], Any]] = { | ||||
|                 "contents": lambda p: self.read_contents(p), | ||||
|                 "strings.json": lambda p: self.vocab.strings.from_disk(p), | ||||
|                 "custom_fields": lambda p: deserialize_custom_fields(p), | ||||
|             } | ||||
|             util.from_disk(path, deserialize, exclude) | ||||
| 
 | ||||
|     @registry.misc("kb_test.CustomEmptyKB.v1") | ||||
|     def empty_custom_kb() -> Callable[[Vocab, int], SubInMemoryLookupKB]: | ||||
|         def empty_kb_factory(vocab: Vocab, entity_vector_length: int): | ||||
|             return SubInMemoryLookupKB( | ||||
|                 vocab=vocab, | ||||
|                 entity_vector_length=entity_vector_length, | ||||
|                 custom_field=0, | ||||
|             ) | ||||
| 
 | ||||
|         return empty_kb_factory | ||||
| 
 | ||||
|     @registry.misc("kb_test.CustomKB.v1") | ||||
|     def custom_kb( | ||||
|         entity_vector_length: int, custom_field: int | ||||
|     ) -> Callable[[Vocab], InMemoryLookupKB]: | ||||
|     ) -> Callable[[Vocab], SubInMemoryLookupKB]: | ||||
|         def custom_kb_factory(vocab): | ||||
|             kb = SubInMemoryLookupKB( | ||||
|                 vocab=vocab, | ||||
|  | @ -139,6 +192,6 @@ def test_serialize_subclassed_kb(): | |||
|         nlp2 = util.load_model_from_path(tmp_dir) | ||||
|         entity_linker2 = nlp2.get_pipe("entity_linker") | ||||
|         # After IO, the KB is the standard one | ||||
|         assert type(entity_linker2.kb) == InMemoryLookupKB | ||||
|         assert type(entity_linker2.kb) == SubInMemoryLookupKB | ||||
|         assert entity_linker2.kb.entity_vector_length == 342 | ||||
|         assert not hasattr(entity_linker2.kb, "custom_field") | ||||
|         assert entity_linker2.kb.custom_field == 666 | ||||
|  |  | |||
|  | @ -899,15 +899,21 @@ The `EntityLinker` model architecture is a Thinc `Model` with a | |||
| | `nO`        | Output dimension, determined by the length of the vectors encoding each entity in the KB. If the `nO` dimension is not set, the entity linking component will set it when `initialize` is called. ~~Optional[int]~~ | | ||||
| | **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~                                                                                                                                                    | | ||||
| 
 | ||||
| ### spacy.EmptyKB.v1 {id="EmptyKB"} | ||||
| ### spacy.EmptyKB.v1 {id="EmptyKB.v1"} | ||||
| 
 | ||||
| A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab) | ||||
| instance. This is the default when a new entity linker component is created. | ||||
| instance. | ||||
| 
 | ||||
| | Name                   | Description                                                                         | | ||||
| | ---------------------- | ----------------------------------------------------------------------------------- | | ||||
| | `entity_vector_length` | The length of the vectors encoding each entity in the KB. Defaults to `64`. ~~int~~ | | ||||
| 
 | ||||
| ### spacy.EmptyKB.v2 {id="EmptyKB"} | ||||
| 
 | ||||
| A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab) | ||||
| instance. This is the default when a new entity linker component is created. It | ||||
| returns a `Callable[[Vocab, int], InMemoryLookupKB]`. | ||||
| 
 | ||||
| ### spacy.KBFromFile.v1 {id="KBFromFile"} | ||||
| 
 | ||||
| A function that reads an existing `KnowledgeBase` from file. | ||||
|  |  | |||
|  | @ -53,19 +53,21 @@ architectures and their arguments and hyperparameters. | |||
| > nlp.add_pipe("entity_linker", config=config) | ||||
| > ``` | ||||
| 
 | ||||
| | Setting                                  | Description                                                                                                                                                                                                                                                                                 | | ||||
| | ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||||
| | `labels_discard`                         | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~                                                                                                                                                                                              | | ||||
| | `n_sents`                                | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~                                                                                                                                                                                                           | | ||||
| | `incl_prior`                             | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~                                                                                                                                                                                        | | ||||
| | `incl_context`                           | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~                                                                                                                                                                                                      | | ||||
| | `model`                                  | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~                                                                                                                                      | | ||||
| | `entity_vector_length`                   | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~                                                                                                                                                                                                                               | | ||||
| | `use_gold_ents`                          | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~                                                                                                        | | ||||
| | `get_candidates`                         | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~                    | | ||||
| | `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~                                                                                                                                                                                                                    | | ||||
| | `scorer` <Tag variant="new">3.2</Tag>    | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~                                                                                                                                                                                     | | ||||
| | `threshold` <Tag variant="new">3.4</Tag> | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the treshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ | | ||||
| | Setting                                             | Description                                                                                                                                                                                                                                                                                                      | | ||||
| | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||||
| | `labels_discard`                                    | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~                                                                                                                                                                                                                   | | ||||
| | `n_sents`                                           | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~                                                                                                                                                                                                                                | | ||||
| | `incl_prior`                                        | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~                                                                                                                                                                                                             | | ||||
| | `incl_context`                                      | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~                                                                                                                                                                                                                           | | ||||
| | `model`                                             | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~                                                                                                                                                           | | ||||
| | `entity_vector_length`                              | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~                                                                                                                                                                                                                                                    | | ||||
| | `use_gold_ents`                                     | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~                                                                                                                             | | ||||
| | `get_candidates`                                    | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~                                         | | ||||
| | `get_candidates_batch` <Tag variant="new">3.5</Tag> | Function that generates plausible candidates for a given batch of `Span` objects. Defaults to [CandidateBatchGenerator](/api/architectures#CandidateBatchGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]~~ | | ||||
| | `generate_empty_kb` <Tag variant="new">3.6</Tag>    | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~                                                                           | | ||||
| | `overwrite` <Tag variant="new">3.2</Tag>            | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~                                                                                                                                                                                                                                         | | ||||
| | `scorer` <Tag variant="new">3.2</Tag>               | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~                                                                                                                                                                                                          | | ||||
| | `threshold` <Tag variant="new">3.4</Tag>            | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the treshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~                      | | ||||
| 
 | ||||
| ```python | ||||
| %%GITHUB_SPACY/spacy/pipeline/entity_linker.py | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user