mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Make generation of empty KnowledgeBase
instances configurable in EntityLinker
(#12320)
* Make empty_kb() configurable. * Format. * Update docs. * Be more specific in KB serialization test. * Update KB serialization tests. Update docs. * Remove doc update for batched candidate generation. * Fix serialization of subclassed KB in tests. * Format. * Update docstring. * Update docstring. * Switch from pickle to json for custom field serialization.
This commit is contained in:
parent
56aa0cc75f
commit
6aa6b86d49
|
@ -89,6 +89,14 @@ def load_kb(
|
||||||
return kb_from_file
|
return kb_from_file
|
||||||
|
|
||||||
|
|
||||||
|
@registry.misc("spacy.EmptyKB.v2")
|
||||||
|
def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]:
|
||||||
|
def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
|
||||||
|
return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length)
|
||||||
|
|
||||||
|
return empty_kb_factory
|
||||||
|
|
||||||
|
|
||||||
@registry.misc("spacy.EmptyKB.v1")
|
@registry.misc("spacy.EmptyKB.v1")
|
||||||
def empty_kb(
|
def empty_kb(
|
||||||
entity_vector_length: int,
|
entity_vector_length: int,
|
||||||
|
|
|
@ -54,6 +54,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
||||||
"entity_vector_length": 64,
|
"entity_vector_length": 64,
|
||||||
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
|
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
|
||||||
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
|
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
|
||||||
|
"generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
|
||||||
"overwrite": True,
|
"overwrite": True,
|
||||||
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
|
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
|
||||||
"use_gold_ents": True,
|
"use_gold_ents": True,
|
||||||
|
@ -80,6 +81,7 @@ def make_entity_linker(
|
||||||
get_candidates_batch: Callable[
|
get_candidates_batch: Callable[
|
||||||
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
||||||
],
|
],
|
||||||
|
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||||
overwrite: bool,
|
overwrite: bool,
|
||||||
scorer: Optional[Callable],
|
scorer: Optional[Callable],
|
||||||
use_gold_ents: bool,
|
use_gold_ents: bool,
|
||||||
|
@ -101,6 +103,7 @@ def make_entity_linker(
|
||||||
get_candidates_batch (
|
get_candidates_batch (
|
||||||
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
|
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
|
||||||
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
||||||
|
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
||||||
scorer (Optional[Callable]): The scoring method.
|
scorer (Optional[Callable]): The scoring method.
|
||||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||||
component must provide entity annotations.
|
component must provide entity annotations.
|
||||||
|
@ -135,6 +138,7 @@ def make_entity_linker(
|
||||||
entity_vector_length=entity_vector_length,
|
entity_vector_length=entity_vector_length,
|
||||||
get_candidates=get_candidates,
|
get_candidates=get_candidates,
|
||||||
get_candidates_batch=get_candidates_batch,
|
get_candidates_batch=get_candidates_batch,
|
||||||
|
generate_empty_kb=generate_empty_kb,
|
||||||
overwrite=overwrite,
|
overwrite=overwrite,
|
||||||
scorer=scorer,
|
scorer=scorer,
|
||||||
use_gold_ents=use_gold_ents,
|
use_gold_ents=use_gold_ents,
|
||||||
|
@ -175,6 +179,7 @@ class EntityLinker(TrainablePipe):
|
||||||
get_candidates_batch: Callable[
|
get_candidates_batch: Callable[
|
||||||
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
||||||
],
|
],
|
||||||
|
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||||
overwrite: bool = BACKWARD_OVERWRITE,
|
overwrite: bool = BACKWARD_OVERWRITE,
|
||||||
scorer: Optional[Callable] = entity_linker_score,
|
scorer: Optional[Callable] = entity_linker_score,
|
||||||
use_gold_ents: bool,
|
use_gold_ents: bool,
|
||||||
|
@ -198,6 +203,7 @@ class EntityLinker(TrainablePipe):
|
||||||
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]],
|
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]],
|
||||||
Iterable[Candidate]]
|
Iterable[Candidate]]
|
||||||
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
||||||
|
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
||||||
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
|
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
|
||||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||||
component must provide entity annotations.
|
component must provide entity annotations.
|
||||||
|
@ -220,6 +226,7 @@ class EntityLinker(TrainablePipe):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = name
|
self.name = name
|
||||||
self.labels_discard = list(labels_discard)
|
self.labels_discard = list(labels_discard)
|
||||||
|
# how many neighbour sentences to take into account
|
||||||
self.n_sents = n_sents
|
self.n_sents = n_sents
|
||||||
self.incl_prior = incl_prior
|
self.incl_prior = incl_prior
|
||||||
self.incl_context = incl_context
|
self.incl_context = incl_context
|
||||||
|
@ -227,9 +234,7 @@ class EntityLinker(TrainablePipe):
|
||||||
self.get_candidates_batch = get_candidates_batch
|
self.get_candidates_batch = get_candidates_batch
|
||||||
self.cfg: Dict[str, Any] = {"overwrite": overwrite}
|
self.cfg: Dict[str, Any] = {"overwrite": overwrite}
|
||||||
self.distance = CosineDistance(normalize=False)
|
self.distance = CosineDistance(normalize=False)
|
||||||
# how many neighbour sentences to take into account
|
self.kb = generate_empty_kb(self.vocab, entity_vector_length)
|
||||||
# create an empty KB by default
|
|
||||||
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
self.use_gold_ents = use_gold_ents
|
self.use_gold_ents = use_gold_ents
|
||||||
self.candidates_batch_size = candidates_batch_size
|
self.candidates_batch_size = candidates_batch_size
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
from typing import Callable
|
from pathlib import Path
|
||||||
|
from typing import Callable, Iterable, Any, Dict
|
||||||
|
|
||||||
from spacy import util
|
import srsly
|
||||||
from spacy.util import ensure_path, registry, load_model_from_config
|
|
||||||
|
from spacy import util, Errors
|
||||||
|
from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList
|
||||||
from spacy.kb.kb_in_memory import InMemoryLookupKB
|
from spacy.kb.kb_in_memory import InMemoryLookupKB
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from thinc.api import Config
|
from thinc.api import Config
|
||||||
|
@ -92,6 +95,9 @@ def test_serialize_subclassed_kb():
|
||||||
[components.entity_linker]
|
[components.entity_linker]
|
||||||
factory = "entity_linker"
|
factory = "entity_linker"
|
||||||
|
|
||||||
|
[components.entity_linker.generate_empty_kb]
|
||||||
|
@misc = "kb_test.CustomEmptyKB.v1"
|
||||||
|
|
||||||
[initialize]
|
[initialize]
|
||||||
|
|
||||||
[initialize.components]
|
[initialize.components]
|
||||||
|
@ -99,7 +105,7 @@ def test_serialize_subclassed_kb():
|
||||||
[initialize.components.entity_linker]
|
[initialize.components.entity_linker]
|
||||||
|
|
||||||
[initialize.components.entity_linker.kb_loader]
|
[initialize.components.entity_linker.kb_loader]
|
||||||
@misc = "spacy.CustomKB.v1"
|
@misc = "kb_test.CustomKB.v1"
|
||||||
entity_vector_length = 342
|
entity_vector_length = 342
|
||||||
custom_field = 666
|
custom_field = 666
|
||||||
"""
|
"""
|
||||||
|
@ -109,10 +115,57 @@ def test_serialize_subclassed_kb():
|
||||||
super().__init__(vocab, entity_vector_length)
|
super().__init__(vocab, entity_vector_length)
|
||||||
self.custom_field = custom_field
|
self.custom_field = custom_field
|
||||||
|
|
||||||
@registry.misc("spacy.CustomKB.v1")
|
def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
|
||||||
|
"""We overwrite InMemoryLookupKB.to_disk() to ensure that self.custom_field is stored as well."""
|
||||||
|
path = ensure_path(path)
|
||||||
|
if not path.exists():
|
||||||
|
path.mkdir(parents=True)
|
||||||
|
if not path.is_dir():
|
||||||
|
raise ValueError(Errors.E928.format(loc=path))
|
||||||
|
|
||||||
|
def serialize_custom_fields(file_path: Path) -> None:
|
||||||
|
srsly.write_json(file_path, {"custom_field": self.custom_field})
|
||||||
|
|
||||||
|
serialize = {
|
||||||
|
"contents": lambda p: self.write_contents(p),
|
||||||
|
"strings.json": lambda p: self.vocab.strings.to_disk(p),
|
||||||
|
"custom_fields": lambda p: serialize_custom_fields(p),
|
||||||
|
}
|
||||||
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
|
def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
|
||||||
|
"""We overwrite InMemoryLookupKB.from_disk() to ensure that self.custom_field is loaded as well."""
|
||||||
|
path = ensure_path(path)
|
||||||
|
if not path.exists():
|
||||||
|
raise ValueError(Errors.E929.format(loc=path))
|
||||||
|
if not path.is_dir():
|
||||||
|
raise ValueError(Errors.E928.format(loc=path))
|
||||||
|
|
||||||
|
def deserialize_custom_fields(file_path: Path) -> None:
|
||||||
|
self.custom_field = srsly.read_json(file_path)["custom_field"]
|
||||||
|
|
||||||
|
deserialize: Dict[str, Callable[[Any], Any]] = {
|
||||||
|
"contents": lambda p: self.read_contents(p),
|
||||||
|
"strings.json": lambda p: self.vocab.strings.from_disk(p),
|
||||||
|
"custom_fields": lambda p: deserialize_custom_fields(p),
|
||||||
|
}
|
||||||
|
util.from_disk(path, deserialize, exclude)
|
||||||
|
|
||||||
|
@registry.misc("kb_test.CustomEmptyKB.v1")
|
||||||
|
def empty_custom_kb() -> Callable[[Vocab, int], SubInMemoryLookupKB]:
|
||||||
|
def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
|
||||||
|
return SubInMemoryLookupKB(
|
||||||
|
vocab=vocab,
|
||||||
|
entity_vector_length=entity_vector_length,
|
||||||
|
custom_field=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return empty_kb_factory
|
||||||
|
|
||||||
|
@registry.misc("kb_test.CustomKB.v1")
|
||||||
def custom_kb(
|
def custom_kb(
|
||||||
entity_vector_length: int, custom_field: int
|
entity_vector_length: int, custom_field: int
|
||||||
) -> Callable[[Vocab], InMemoryLookupKB]:
|
) -> Callable[[Vocab], SubInMemoryLookupKB]:
|
||||||
def custom_kb_factory(vocab):
|
def custom_kb_factory(vocab):
|
||||||
kb = SubInMemoryLookupKB(
|
kb = SubInMemoryLookupKB(
|
||||||
vocab=vocab,
|
vocab=vocab,
|
||||||
|
@ -139,6 +192,6 @@ def test_serialize_subclassed_kb():
|
||||||
nlp2 = util.load_model_from_path(tmp_dir)
|
nlp2 = util.load_model_from_path(tmp_dir)
|
||||||
entity_linker2 = nlp2.get_pipe("entity_linker")
|
entity_linker2 = nlp2.get_pipe("entity_linker")
|
||||||
# After IO, the KB is the standard one
|
# After IO, the KB is the standard one
|
||||||
assert type(entity_linker2.kb) == InMemoryLookupKB
|
assert type(entity_linker2.kb) == SubInMemoryLookupKB
|
||||||
assert entity_linker2.kb.entity_vector_length == 342
|
assert entity_linker2.kb.entity_vector_length == 342
|
||||||
assert not hasattr(entity_linker2.kb, "custom_field")
|
assert entity_linker2.kb.custom_field == 666
|
||||||
|
|
|
@ -899,15 +899,21 @@ The `EntityLinker` model architecture is a Thinc `Model` with a
|
||||||
| `nO` | Output dimension, determined by the length of the vectors encoding each entity in the KB. If the `nO` dimension is not set, the entity linking component will set it when `initialize` is called. ~~Optional[int]~~ |
|
| `nO` | Output dimension, determined by the length of the vectors encoding each entity in the KB. If the `nO` dimension is not set, the entity linking component will set it when `initialize` is called. ~~Optional[int]~~ |
|
||||||
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
|
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
|
||||||
|
|
||||||
### spacy.EmptyKB.v1 {id="EmptyKB"}
|
### spacy.EmptyKB.v1 {id="EmptyKB.v1"}
|
||||||
|
|
||||||
A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab)
|
A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab)
|
||||||
instance. This is the default when a new entity linker component is created.
|
instance.
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ---------------------- | ----------------------------------------------------------------------------------- |
|
| ---------------------- | ----------------------------------------------------------------------------------- |
|
||||||
| `entity_vector_length` | The length of the vectors encoding each entity in the KB. Defaults to `64`. ~~int~~ |
|
| `entity_vector_length` | The length of the vectors encoding each entity in the KB. Defaults to `64`. ~~int~~ |
|
||||||
|
|
||||||
|
### spacy.EmptyKB.v2 {id="EmptyKB"}
|
||||||
|
|
||||||
|
A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab)
|
||||||
|
instance. This is the default when a new entity linker component is created. It
|
||||||
|
returns a `Callable[[Vocab, int], InMemoryLookupKB]`.
|
||||||
|
|
||||||
### spacy.KBFromFile.v1 {id="KBFromFile"}
|
### spacy.KBFromFile.v1 {id="KBFromFile"}
|
||||||
|
|
||||||
A function that reads an existing `KnowledgeBase` from file.
|
A function that reads an existing `KnowledgeBase` from file.
|
||||||
|
|
|
@ -54,7 +54,7 @@ architectures and their arguments and hyperparameters.
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Setting | Description |
|
| Setting | Description |
|
||||||
| ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
|
| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
|
||||||
| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ |
|
| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ |
|
||||||
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
|
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
|
||||||
|
@ -63,6 +63,8 @@ architectures and their arguments and hyperparameters.
|
||||||
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
|
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
|
||||||
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ |
|
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ |
|
||||||
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
|
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
|
||||||
|
| `get_candidates_batch` <Tag variant="new">3.5</Tag> | Function that generates plausible candidates for a given batch of `Span` objects. Defaults to [CandidateBatchGenerator](/api/architectures#CandidateBatchGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]~~ |
|
||||||
|
| `generate_empty_kb` <Tag variant="new">3.6</Tag> | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~ |
|
||||||
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
|
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
|
||||||
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
|
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
|
||||||
| `threshold` <Tag variant="new">3.4</Tag> | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the treshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ |
|
| `threshold` <Tag variant="new">3.4</Tag> | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the treshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ |
|
||||||
|
|
Loading…
Reference in New Issue
Block a user