mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 13:41:21 +03:00 
			
		
		
		
	Make generation of empty KnowledgeBase instances configurable in EntityLinker (#12320)
				
					
				
			* Make empty_kb() configurable. * Format. * Update docs. * Be more specific in KB serialization test. * Update KB serialization tests. Update docs. * Remove doc update for batched candidate generation. * Fix serialization of subclassed KB in tests. * Format. * Update docstring. * Update docstring. * Switch from pickle to json for custom field serialization.
This commit is contained in:
		
							parent
							
								
									e325de3ff8
								
							
						
					
					
						commit
						6f1632b3e9
					
				|  | @ -89,6 +89,14 @@ def load_kb( | ||||||
|     return kb_from_file |     return kb_from_file | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @registry.misc("spacy.EmptyKB.v2") | ||||||
|  | def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]: | ||||||
|  |     def empty_kb_factory(vocab: Vocab, entity_vector_length: int): | ||||||
|  |         return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length) | ||||||
|  | 
 | ||||||
|  |     return empty_kb_factory | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @registry.misc("spacy.EmptyKB.v1") | @registry.misc("spacy.EmptyKB.v1") | ||||||
| def empty_kb( | def empty_kb( | ||||||
|     entity_vector_length: int, |     entity_vector_length: int, | ||||||
|  |  | ||||||
|  | @ -54,6 +54,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] | ||||||
|         "entity_vector_length": 64, |         "entity_vector_length": 64, | ||||||
|         "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, |         "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, | ||||||
|         "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"}, |         "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"}, | ||||||
|  |         "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"}, | ||||||
|         "overwrite": True, |         "overwrite": True, | ||||||
|         "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, |         "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, | ||||||
|         "use_gold_ents": True, |         "use_gold_ents": True, | ||||||
|  | @ -80,6 +81,7 @@ def make_entity_linker( | ||||||
|     get_candidates_batch: Callable[ |     get_candidates_batch: Callable[ | ||||||
|         [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] |         [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] | ||||||
|     ], |     ], | ||||||
|  |     generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], | ||||||
|     overwrite: bool, |     overwrite: bool, | ||||||
|     scorer: Optional[Callable], |     scorer: Optional[Callable], | ||||||
|     use_gold_ents: bool, |     use_gold_ents: bool, | ||||||
|  | @ -101,6 +103,7 @@ def make_entity_linker( | ||||||
|     get_candidates_batch ( |     get_candidates_batch ( | ||||||
|         Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] |         Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] | ||||||
|         ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. |         ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. | ||||||
|  |     generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. | ||||||
|     scorer (Optional[Callable]): The scoring method. |     scorer (Optional[Callable]): The scoring method. | ||||||
|     use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another |     use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another | ||||||
|         component must provide entity annotations. |         component must provide entity annotations. | ||||||
|  | @ -135,6 +138,7 @@ def make_entity_linker( | ||||||
|         entity_vector_length=entity_vector_length, |         entity_vector_length=entity_vector_length, | ||||||
|         get_candidates=get_candidates, |         get_candidates=get_candidates, | ||||||
|         get_candidates_batch=get_candidates_batch, |         get_candidates_batch=get_candidates_batch, | ||||||
|  |         generate_empty_kb=generate_empty_kb, | ||||||
|         overwrite=overwrite, |         overwrite=overwrite, | ||||||
|         scorer=scorer, |         scorer=scorer, | ||||||
|         use_gold_ents=use_gold_ents, |         use_gold_ents=use_gold_ents, | ||||||
|  | @ -175,6 +179,7 @@ class EntityLinker(TrainablePipe): | ||||||
|         get_candidates_batch: Callable[ |         get_candidates_batch: Callable[ | ||||||
|             [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] |             [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] | ||||||
|         ], |         ], | ||||||
|  |         generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], | ||||||
|         overwrite: bool = BACKWARD_OVERWRITE, |         overwrite: bool = BACKWARD_OVERWRITE, | ||||||
|         scorer: Optional[Callable] = entity_linker_score, |         scorer: Optional[Callable] = entity_linker_score, | ||||||
|         use_gold_ents: bool, |         use_gold_ents: bool, | ||||||
|  | @ -198,6 +203,7 @@ class EntityLinker(TrainablePipe): | ||||||
|             Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], |             Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], | ||||||
|             Iterable[Candidate]] |             Iterable[Candidate]] | ||||||
|             ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. |             ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. | ||||||
|  |         generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. | ||||||
|         scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. |         scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. | ||||||
|         use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another |         use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another | ||||||
|             component must provide entity annotations. |             component must provide entity annotations. | ||||||
|  | @ -220,6 +226,7 @@ class EntityLinker(TrainablePipe): | ||||||
|         self.model = model |         self.model = model | ||||||
|         self.name = name |         self.name = name | ||||||
|         self.labels_discard = list(labels_discard) |         self.labels_discard = list(labels_discard) | ||||||
|  |         # how many neighbour sentences to take into account | ||||||
|         self.n_sents = n_sents |         self.n_sents = n_sents | ||||||
|         self.incl_prior = incl_prior |         self.incl_prior = incl_prior | ||||||
|         self.incl_context = incl_context |         self.incl_context = incl_context | ||||||
|  | @ -227,9 +234,7 @@ class EntityLinker(TrainablePipe): | ||||||
|         self.get_candidates_batch = get_candidates_batch |         self.get_candidates_batch = get_candidates_batch | ||||||
|         self.cfg: Dict[str, Any] = {"overwrite": overwrite} |         self.cfg: Dict[str, Any] = {"overwrite": overwrite} | ||||||
|         self.distance = CosineDistance(normalize=False) |         self.distance = CosineDistance(normalize=False) | ||||||
|         # how many neighbour sentences to take into account |         self.kb = generate_empty_kb(self.vocab, entity_vector_length) | ||||||
|         # create an empty KB by default |  | ||||||
|         self.kb = empty_kb(entity_vector_length)(self.vocab) |  | ||||||
|         self.scorer = scorer |         self.scorer = scorer | ||||||
|         self.use_gold_ents = use_gold_ents |         self.use_gold_ents = use_gold_ents | ||||||
|         self.candidates_batch_size = candidates_batch_size |         self.candidates_batch_size = candidates_batch_size | ||||||
|  |  | ||||||
|  | @ -1,7 +1,10 @@ | ||||||
| from typing import Callable | from pathlib import Path | ||||||
|  | from typing import Callable, Iterable, Any, Dict | ||||||
| 
 | 
 | ||||||
| from spacy import util | import srsly | ||||||
| from spacy.util import ensure_path, registry, load_model_from_config | 
 | ||||||
|  | from spacy import util, Errors | ||||||
|  | from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList | ||||||
| from spacy.kb.kb_in_memory import InMemoryLookupKB | from spacy.kb.kb_in_memory import InMemoryLookupKB | ||||||
| from spacy.vocab import Vocab | from spacy.vocab import Vocab | ||||||
| from thinc.api import Config | from thinc.api import Config | ||||||
|  | @ -92,6 +95,9 @@ def test_serialize_subclassed_kb(): | ||||||
|     [components.entity_linker] |     [components.entity_linker] | ||||||
|     factory = "entity_linker" |     factory = "entity_linker" | ||||||
|      |      | ||||||
|  |     [components.entity_linker.generate_empty_kb] | ||||||
|  |     @misc = "kb_test.CustomEmptyKB.v1" | ||||||
|  |      | ||||||
|     [initialize] |     [initialize] | ||||||
| 
 | 
 | ||||||
|     [initialize.components] |     [initialize.components] | ||||||
|  | @ -99,7 +105,7 @@ def test_serialize_subclassed_kb(): | ||||||
|     [initialize.components.entity_linker] |     [initialize.components.entity_linker] | ||||||
| 
 | 
 | ||||||
|     [initialize.components.entity_linker.kb_loader] |     [initialize.components.entity_linker.kb_loader] | ||||||
|     @misc = "spacy.CustomKB.v1" |     @misc = "kb_test.CustomKB.v1" | ||||||
|     entity_vector_length = 342 |     entity_vector_length = 342 | ||||||
|     custom_field = 666 |     custom_field = 666 | ||||||
|     """ |     """ | ||||||
|  | @ -109,10 +115,57 @@ def test_serialize_subclassed_kb(): | ||||||
|             super().__init__(vocab, entity_vector_length) |             super().__init__(vocab, entity_vector_length) | ||||||
|             self.custom_field = custom_field |             self.custom_field = custom_field | ||||||
| 
 | 
 | ||||||
|     @registry.misc("spacy.CustomKB.v1") |         def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()): | ||||||
|  |             """We overwrite InMemoryLookupKB.to_disk() to ensure that self.custom_field is stored as well.""" | ||||||
|  |             path = ensure_path(path) | ||||||
|  |             if not path.exists(): | ||||||
|  |                 path.mkdir(parents=True) | ||||||
|  |             if not path.is_dir(): | ||||||
|  |                 raise ValueError(Errors.E928.format(loc=path)) | ||||||
|  | 
 | ||||||
|  |             def serialize_custom_fields(file_path: Path) -> None: | ||||||
|  |                 srsly.write_json(file_path, {"custom_field": self.custom_field}) | ||||||
|  | 
 | ||||||
|  |             serialize = { | ||||||
|  |                 "contents": lambda p: self.write_contents(p), | ||||||
|  |                 "strings.json": lambda p: self.vocab.strings.to_disk(p), | ||||||
|  |                 "custom_fields": lambda p: serialize_custom_fields(p), | ||||||
|  |             } | ||||||
|  |             util.to_disk(path, serialize, exclude) | ||||||
|  | 
 | ||||||
|  |         def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()): | ||||||
|  |             """We overwrite InMemoryLookupKB.from_disk() to ensure that self.custom_field is loaded as well.""" | ||||||
|  |             path = ensure_path(path) | ||||||
|  |             if not path.exists(): | ||||||
|  |                 raise ValueError(Errors.E929.format(loc=path)) | ||||||
|  |             if not path.is_dir(): | ||||||
|  |                 raise ValueError(Errors.E928.format(loc=path)) | ||||||
|  | 
 | ||||||
|  |             def deserialize_custom_fields(file_path: Path) -> None: | ||||||
|  |                 self.custom_field = srsly.read_json(file_path)["custom_field"] | ||||||
|  | 
 | ||||||
|  |             deserialize: Dict[str, Callable[[Any], Any]] = { | ||||||
|  |                 "contents": lambda p: self.read_contents(p), | ||||||
|  |                 "strings.json": lambda p: self.vocab.strings.from_disk(p), | ||||||
|  |                 "custom_fields": lambda p: deserialize_custom_fields(p), | ||||||
|  |             } | ||||||
|  |             util.from_disk(path, deserialize, exclude) | ||||||
|  | 
 | ||||||
|  |     @registry.misc("kb_test.CustomEmptyKB.v1") | ||||||
|  |     def empty_custom_kb() -> Callable[[Vocab, int], SubInMemoryLookupKB]: | ||||||
|  |         def empty_kb_factory(vocab: Vocab, entity_vector_length: int): | ||||||
|  |             return SubInMemoryLookupKB( | ||||||
|  |                 vocab=vocab, | ||||||
|  |                 entity_vector_length=entity_vector_length, | ||||||
|  |                 custom_field=0, | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         return empty_kb_factory | ||||||
|  | 
 | ||||||
|  |     @registry.misc("kb_test.CustomKB.v1") | ||||||
|     def custom_kb( |     def custom_kb( | ||||||
|         entity_vector_length: int, custom_field: int |         entity_vector_length: int, custom_field: int | ||||||
|     ) -> Callable[[Vocab], InMemoryLookupKB]: |     ) -> Callable[[Vocab], SubInMemoryLookupKB]: | ||||||
|         def custom_kb_factory(vocab): |         def custom_kb_factory(vocab): | ||||||
|             kb = SubInMemoryLookupKB( |             kb = SubInMemoryLookupKB( | ||||||
|                 vocab=vocab, |                 vocab=vocab, | ||||||
|  | @ -139,6 +192,6 @@ def test_serialize_subclassed_kb(): | ||||||
|         nlp2 = util.load_model_from_path(tmp_dir) |         nlp2 = util.load_model_from_path(tmp_dir) | ||||||
|         entity_linker2 = nlp2.get_pipe("entity_linker") |         entity_linker2 = nlp2.get_pipe("entity_linker") | ||||||
|         # After IO, the KB is the standard one |         # After IO, the KB is the standard one | ||||||
|         assert type(entity_linker2.kb) == InMemoryLookupKB |         assert type(entity_linker2.kb) == SubInMemoryLookupKB | ||||||
|         assert entity_linker2.kb.entity_vector_length == 342 |         assert entity_linker2.kb.entity_vector_length == 342 | ||||||
|         assert not hasattr(entity_linker2.kb, "custom_field") |         assert entity_linker2.kb.custom_field == 666 | ||||||
|  |  | ||||||
|  | @ -899,15 +899,21 @@ The `EntityLinker` model architecture is a Thinc `Model` with a | ||||||
| | `nO`        | Output dimension, determined by the length of the vectors encoding each entity in the KB. If the `nO` dimension is not set, the entity linking component will set it when `initialize` is called. ~~Optional[int]~~ | | | `nO`        | Output dimension, determined by the length of the vectors encoding each entity in the KB. If the `nO` dimension is not set, the entity linking component will set it when `initialize` is called. ~~Optional[int]~~ | | ||||||
| | **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~                                                                                                                                                    | | | **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~                                                                                                                                                    | | ||||||
| 
 | 
 | ||||||
| ### spacy.EmptyKB.v1 {id="EmptyKB"} | ### spacy.EmptyKB.v1 {id="EmptyKB.v1"} | ||||||
| 
 | 
 | ||||||
| A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab) | A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab) | ||||||
| instance. This is the default when a new entity linker component is created. | instance. | ||||||
| 
 | 
 | ||||||
| | Name                   | Description                                                                         | | | Name                   | Description                                                                         | | ||||||
| | ---------------------- | ----------------------------------------------------------------------------------- | | | ---------------------- | ----------------------------------------------------------------------------------- | | ||||||
| | `entity_vector_length` | The length of the vectors encoding each entity in the KB. Defaults to `64`. ~~int~~ | | | `entity_vector_length` | The length of the vectors encoding each entity in the KB. Defaults to `64`. ~~int~~ | | ||||||
| 
 | 
 | ||||||
|  | ### spacy.EmptyKB.v2 {id="EmptyKB"} | ||||||
|  | 
 | ||||||
|  | A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab) | ||||||
|  | instance. This is the default when a new entity linker component is created. It | ||||||
|  | returns a `Callable[[Vocab, int], InMemoryLookupKB]`. | ||||||
|  | 
 | ||||||
| ### spacy.KBFromFile.v1 {id="KBFromFile"} | ### spacy.KBFromFile.v1 {id="KBFromFile"} | ||||||
| 
 | 
 | ||||||
| A function that reads an existing `KnowledgeBase` from file. | A function that reads an existing `KnowledgeBase` from file. | ||||||
|  |  | ||||||
|  | @ -54,7 +54,7 @@ architectures and their arguments and hyperparameters. | ||||||
| > ``` | > ``` | ||||||
| 
 | 
 | ||||||
| | Setting                                             | Description                                                                                                                                                                                                                                                                                                      | | | Setting                                             | Description                                                                                                                                                                                                                                                                                                      | | ||||||
| | ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||||||
| | `labels_discard`                                    | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~                                                                                                                                                                                                                   | | | `labels_discard`                                    | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~                                                                                                                                                                                                                   | | ||||||
| | `n_sents`                                           | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~                                                                                                                                                                                                                                | | | `n_sents`                                           | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~                                                                                                                                                                                                                                | | ||||||
| | `incl_prior`                                        | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~                                                                                                                                                                                                             | | | `incl_prior`                                        | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~                                                                                                                                                                                                             | | ||||||
|  | @ -63,6 +63,8 @@ architectures and their arguments and hyperparameters. | ||||||
| | `entity_vector_length`                              | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~                                                                                                                                                                                                                                                    | | | `entity_vector_length`                              | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~                                                                                                                                                                                                                                                    | | ||||||
| `use_gold_ents`                                     | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~                                                                                                                            | | | `use_gold_ents`                                     | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~                                                                                                                            | | ||||||
| | `get_candidates`                                    | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~                                         | | | `get_candidates`                                    | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~                                         | | ||||||
|  | | `get_candidates_batch` <Tag variant="new">3.5</Tag> | Function that generates plausible candidates for a given batch of `Span` objects. Defaults to [CandidateBatchGenerator](/api/architectures#CandidateBatchGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]~~ | | ||||||
|  | | `generate_empty_kb` <Tag variant="new">3.6</Tag>    | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~                                                                           | | ||||||
| | `overwrite` <Tag variant="new">3.2</Tag>            | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~                                                                                                                                                                                                                                         | | | `overwrite` <Tag variant="new">3.2</Tag>            | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~                                                                                                                                                                                                                                         | | ||||||
| | `scorer` <Tag variant="new">3.2</Tag>               | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~                                                                                                                                                                                                          | | | `scorer` <Tag variant="new">3.2</Tag>               | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~                                                                                                                                                                                                          | | ||||||
| `threshold` <Tag variant="new">3.4</Tag>            | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the threshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~                     | | | `threshold` <Tag variant="new">3.4</Tag>            | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the threshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~                     | | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user