diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py
index 299b6bb52..7332ca199 100644
--- a/spacy/ml/models/entity_linker.py
+++ b/spacy/ml/models/entity_linker.py
@@ -89,6 +89,14 @@ def load_kb(
return kb_from_file
+@registry.misc("spacy.EmptyKB.v2")
+def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]:
+ def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
+ return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length)
+
+ return empty_kb_factory
+
+
@registry.misc("spacy.EmptyKB.v1")
def empty_kb(
entity_vector_length: int,
diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py
index a11964117..f2dae0529 100644
--- a/spacy/pipeline/entity_linker.py
+++ b/spacy/pipeline/entity_linker.py
@@ -54,6 +54,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
"entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
+ "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
"overwrite": True,
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
"use_gold_ents": True,
@@ -80,6 +81,7 @@ def make_entity_linker(
get_candidates_batch: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
],
+ generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool,
scorer: Optional[Callable],
use_gold_ents: bool,
@@ -101,6 +103,7 @@ def make_entity_linker(
get_candidates_batch (
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
+ generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable that returns an empty KnowledgeBase, given a Vocab and the entity vector length.
scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations.
@@ -135,6 +138,7 @@ def make_entity_linker(
entity_vector_length=entity_vector_length,
get_candidates=get_candidates,
get_candidates_batch=get_candidates_batch,
+ generate_empty_kb=generate_empty_kb,
overwrite=overwrite,
scorer=scorer,
use_gold_ents=use_gold_ents,
@@ -175,6 +179,7 @@ class EntityLinker(TrainablePipe):
get_candidates_batch: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
],
+ generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool = BACKWARD_OVERWRITE,
scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool,
@@ -198,6 +203,7 @@ class EntityLinker(TrainablePipe):
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]],
Iterable[Candidate]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
+ generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable that returns an empty KnowledgeBase, given a Vocab and the entity vector length.
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations.
@@ -220,6 +226,7 @@ class EntityLinker(TrainablePipe):
self.model = model
self.name = name
self.labels_discard = list(labels_discard)
+ # how many neighbour sentences to take into account
self.n_sents = n_sents
self.incl_prior = incl_prior
self.incl_context = incl_context
@@ -227,9 +234,7 @@ class EntityLinker(TrainablePipe):
self.get_candidates_batch = get_candidates_batch
self.cfg: Dict[str, Any] = {"overwrite": overwrite}
self.distance = CosineDistance(normalize=False)
- # how many neighbour sentences to take into account
- # create an empty KB by default
- self.kb = empty_kb(entity_vector_length)(self.vocab)
+ self.kb = generate_empty_kb(self.vocab, entity_vector_length)
self.scorer = scorer
self.use_gold_ents = use_gold_ents
self.candidates_batch_size = candidates_batch_size
diff --git a/spacy/tests/serialize/test_serialize_kb.py b/spacy/tests/serialize/test_serialize_kb.py
index 8d3653ab1..f9d2e226b 100644
--- a/spacy/tests/serialize/test_serialize_kb.py
+++ b/spacy/tests/serialize/test_serialize_kb.py
@@ -1,7 +1,10 @@
-from typing import Callable
+from pathlib import Path
+from typing import Callable, Iterable, Any, Dict
-from spacy import util
-from spacy.util import ensure_path, registry, load_model_from_config
+import srsly
+
+from spacy import util, Errors
+from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList
from spacy.kb.kb_in_memory import InMemoryLookupKB
from spacy.vocab import Vocab
from thinc.api import Config
@@ -91,7 +94,10 @@ def test_serialize_subclassed_kb():
[components.entity_linker]
factory = "entity_linker"
-
+
+ [components.entity_linker.generate_empty_kb]
+ @misc = "kb_test.CustomEmptyKB.v1"
+
[initialize]
[initialize.components]
@@ -99,7 +105,7 @@ def test_serialize_subclassed_kb():
[initialize.components.entity_linker]
[initialize.components.entity_linker.kb_loader]
- @misc = "spacy.CustomKB.v1"
+ @misc = "kb_test.CustomKB.v1"
entity_vector_length = 342
custom_field = 666
"""
@@ -109,10 +115,57 @@ def test_serialize_subclassed_kb():
super().__init__(vocab, entity_vector_length)
self.custom_field = custom_field
- @registry.misc("spacy.CustomKB.v1")
+ def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
+ """We override InMemoryLookupKB.to_disk() to ensure that self.custom_field is stored as well."""
+ path = ensure_path(path)
+ if not path.exists():
+ path.mkdir(parents=True)
+ if not path.is_dir():
+ raise ValueError(Errors.E928.format(loc=path))
+
+ def serialize_custom_fields(file_path: Path) -> None:
+ srsly.write_json(file_path, {"custom_field": self.custom_field})
+
+ serialize = {
+ "contents": lambda p: self.write_contents(p),
+ "strings.json": lambda p: self.vocab.strings.to_disk(p),
+ "custom_fields": lambda p: serialize_custom_fields(p),
+ }
+ util.to_disk(path, serialize, exclude)
+
+ def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
+ """We override InMemoryLookupKB.from_disk() to ensure that self.custom_field is loaded as well."""
+ path = ensure_path(path)
+ if not path.exists():
+ raise ValueError(Errors.E929.format(loc=path))
+ if not path.is_dir():
+ raise ValueError(Errors.E928.format(loc=path))
+
+ def deserialize_custom_fields(file_path: Path) -> None:
+ self.custom_field = srsly.read_json(file_path)["custom_field"]
+
+ deserialize: Dict[str, Callable[[Any], Any]] = {
+ "contents": lambda p: self.read_contents(p),
+ "strings.json": lambda p: self.vocab.strings.from_disk(p),
+ "custom_fields": lambda p: deserialize_custom_fields(p),
+ }
+ util.from_disk(path, deserialize, exclude)
+
+ @registry.misc("kb_test.CustomEmptyKB.v1")
+ def empty_custom_kb() -> Callable[[Vocab, int], SubInMemoryLookupKB]:
+ def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
+ return SubInMemoryLookupKB(
+ vocab=vocab,
+ entity_vector_length=entity_vector_length,
+ custom_field=0,
+ )
+
+ return empty_kb_factory
+
+ @registry.misc("kb_test.CustomKB.v1")
def custom_kb(
entity_vector_length: int, custom_field: int
- ) -> Callable[[Vocab], InMemoryLookupKB]:
+ ) -> Callable[[Vocab], SubInMemoryLookupKB]:
def custom_kb_factory(vocab):
kb = SubInMemoryLookupKB(
vocab=vocab,
@@ -139,6 +192,6 @@ def test_serialize_subclassed_kb():
nlp2 = util.load_model_from_path(tmp_dir)
entity_linker2 = nlp2.get_pipe("entity_linker")
- # After IO, the KB is the standard one
+ # After IO, the KB is still the custom subclass, including its custom field
- assert type(entity_linker2.kb) == InMemoryLookupKB
+ assert type(entity_linker2.kb) == SubInMemoryLookupKB
assert entity_linker2.kb.entity_vector_length == 342
- assert not hasattr(entity_linker2.kb, "custom_field")
+ assert entity_linker2.kb.custom_field == 666
diff --git a/website/docs/api/architectures.mdx b/website/docs/api/architectures.mdx
index 966b5830a..268c04a07 100644
--- a/website/docs/api/architectures.mdx
+++ b/website/docs/api/architectures.mdx
@@ -899,15 +899,21 @@ The `EntityLinker` model architecture is a Thinc `Model` with a
| `nO` | Output dimension, determined by the length of the vectors encoding each entity in the KB. If the `nO` dimension is not set, the entity linking component will set it when `initialize` is called. ~~Optional[int]~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
-### spacy.EmptyKB.v1 {id="EmptyKB"}
+### spacy.EmptyKB.v1 {id="EmptyKB.v1"}
A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab)
-instance. This is the default when a new entity linker component is created.
+instance.
| Name | Description |
| ---------------------- | ----------------------------------------------------------------------------------- |
| `entity_vector_length` | The length of the vectors encoding each entity in the KB. Defaults to `64`. ~~int~~ |
+### spacy.EmptyKB.v2 {id="EmptyKB"}
+
+A function that creates an empty `KnowledgeBase` from a [`Vocab`](/api/vocab)
+instance. This is the default when a new entity linker component is created. It
+returns a `Callable[[Vocab, int], InMemoryLookupKB]`.
+
### spacy.KBFromFile.v1 {id="KBFromFile"}
A function that reads an existing `KnowledgeBase` from file.
diff --git a/website/docs/api/entitylinker.mdx b/website/docs/api/entitylinker.mdx
index bafb2f2da..d84dd3ca9 100644
--- a/website/docs/api/entitylinker.mdx
+++ b/website/docs/api/entitylinker.mdx
@@ -53,19 +53,21 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("entity_linker", config=config)
> ```
-| Setting | Description |
-| ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
-| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ |
-| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
-| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
-| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
-| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
-| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ |
-| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
-| `overwrite` 3.2 | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
-| `scorer` 3.2 | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
-| `threshold` 3.4 | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the treshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ |
+| Setting | Description |
+| --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
+| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ |
+| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
+| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
+| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
+| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
+| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~ |
+| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
+| `get_candidates_batch` 3.5 | Function that generates plausible candidates for a given batch of `Span` objects. Defaults to [CandidateBatchGenerator](/api/architectures#CandidateBatchGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]~~ |
+| `generate_empty_kb` 3.6 | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~ |
+| `overwrite` 3.2 | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
+| `scorer` 3.2 | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
+| `threshold` 3.4 | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the threshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ |
```python
%%GITHUB_SPACY/spacy/pipeline/entity_linker.py