spaCy/spacy/ml/models/entity_linker.py
Raphael Mitsch 1f23c615d7
Refactor KB for easier customization (#11268)
* Add implementation of batching + backwards compatibility fixes. Tests indicate issue with batch disambiguation for custom singular entity lookups.

* Fix tests. Add distinction w.r.t. batch size.

* Remove redundant and add new comments.

* Adjust comments. Fix variable naming in EL prediction.

* Fix mypy errors.

* Remove KB entity type config option. Change return types of candidate retrieval functions to Iterable from Iterator. Fix various other issues.

* Update spacy/pipeline/entity_linker.py

Co-authored-by: Paul O'Leary McCann <polm@dampfkraft.com>

* Update spacy/pipeline/entity_linker.py

Co-authored-by: Paul O'Leary McCann <polm@dampfkraft.com>

* Update spacy/kb_base.pyx

Co-authored-by: Paul O'Leary McCann <polm@dampfkraft.com>

* Update spacy/kb_base.pyx

Co-authored-by: Paul O'Leary McCann <polm@dampfkraft.com>

* Update spacy/pipeline/entity_linker.py

Co-authored-by: Paul O'Leary McCann <polm@dampfkraft.com>

* Add error messages to NotImplementedErrors. Remove redundant comment.

* Fix imports.

* Remove redundant comments.

* Rename KnowledgeBase to InMemoryLookupKB and BaseKnowledgeBase to KnowledgeBase.

* Fix tests.

* Update spacy/errors.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Move KB into subdirectory.

* Adjust imports after KB move to dedicated subdirectory.

* Fix config imports.

* Move Candidate + retrieval functions to separate module. Fix other, small issues.

* Fix docstrings and error message w.r.t. class names. Fix typing for candidate retrieval functions.

* Update spacy/kb/kb_in_memory.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/ml/models/entity_linker.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Fix typing.

* Change typing of mentions to be Span instead of Union[Span, str].

* Update docs.

* Update EntityLinker and _architecture docs.

* Update website/docs/api/entitylinker.md

Co-authored-by: Paul O'Leary McCann <polm@dampfkraft.com>

* Adjust message for E1046.

* Re-add section for Candidate in kb.md, add reference to dedicated page.

* Update docs and docstrings.

* Re-add section + reference for KnowledgeBase.get_alias_candidates() in docs.

* Update spacy/kb/candidate.pyx

* Update spacy/kb/kb_in_memory.pyx

* Update spacy/pipeline/legacy/entity_linker.py

* Remove candidate.md. Remove mistakenly added config snippet in entity_linker.py.

Co-authored-by: Paul O'Leary McCann <polm@dampfkraft.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
2022-09-08 10:38:07 +02:00

113 lines
3.9 KiB
Python

from pathlib import Path
from typing import Optional, Callable, Iterable, List, Tuple
from thinc.types import Floats2d
from thinc.api import chain, list2ragged, reduce_mean, residual
from thinc.api import Model, Maxout, Linear, tuplify, Ragged
from ...util import registry
from ...kb import KnowledgeBase, InMemoryLookupKB
from ...kb import Candidate, get_candidates, get_candidates_batch
from ...vocab import Vocab
from ...tokens import Span, Doc
from ..extract_spans import extract_spans
from ...errors import Errors
@registry.architectures("spacy.EntityLinker.v2")
def build_nel_encoder(
    tok2vec: Model, nO: Optional[int] = None
) -> Model[List[Doc], Floats2d]:
    """Assemble the encoder model registered as "spacy.EntityLinker.v2".

    The `tok2vec` output is converted with `list2ragged` and paired (via
    `tuplify`) with the output of the span maker. The pair is then passed
    through `extract_spans`, `reduce_mean`, a residual `Maxout` block and a
    final `Linear` layer.

    tok2vec (Model): The embedding layer; its "nO" dimension (if set) is used
        as the width of the Maxout block and as the input width of the
        output layer.
    nO (Optional[int]): Output width of the final Linear layer.
    RETURNS (Model[List[Doc], Floats2d]): The composed model, with the refs
        "output_layer" and "tok2vec" set and the attribute
        "include_span_maker" set to True.
    """
    with Model.define_operators({">>": chain, "&": tuplify}):
        width = tok2vec.maybe_get_dim("nO")
        projection = Linear(nO=nO, nI=width)
        # The layers are composed left to right, in the same order and with
        # the same nesting as a single chained expression would produce.
        vectors_and_spans = (tok2vec >> list2ragged()) & build_span_maker()
        span_vectors = vectors_and_spans >> extract_spans() >> reduce_mean()
        hidden = residual(Maxout(nO=width, nI=width, nP=2, dropout=0.0))  # type: ignore
        encoder = span_vectors >> hidden >> projection
        encoder.set_ref("output_layer", projection)
        encoder.set_ref("tok2vec", tok2vec)
    # flag to show this isn't legacy
    encoder.attrs["include_span_maker"] = True
    return encoder
def build_span_maker(n_sents: int = 0) -> Model:
    """Create the "span_maker" layer, whose forward pass is `span_maker_forward`.

    n_sents (int): Stored in the model's attrs under "n_sents"; the forward
        function reads it to decide how many neighbouring sentences to
        include on each side of an entity's sentence.
    RETURNS (Model): The span maker model.
    """
    return Model(
        "span_maker",
        forward=span_maker_forward,
        attrs={"n_sents": n_sents},
    )
def span_maker_forward(model, docs: List[Doc], is_train) -> Tuple[Ragged, Callable]:
    """Compute, for every entity in `docs`, the token span of its context window.

    The window of an entity is the sentence containing it plus up to
    `n_sents` sentences before and after it (`n_sents` is read from
    `model.attrs`). Each window is recorded as a `(start_token, end_token)`
    pair taken from the first and last sentence of the window.

    Docs without entities contribute zero rows, and a batch with no entities
    at all (including an empty batch) yields an empty `(0, 2)` array instead
    of failing while concatenating per-doc arrays of mismatched rank.

    model (Model): The span maker model; provides `ops` and `attrs["n_sents"]`.
    docs (List[Doc]): The batch of documents.
    is_train: Unused.
    RETURNS (Tuple[Ragged, Callable]): A Ragged holding all spans of the batch
        with one length per doc, and a backprop callback that returns an
        empty list.
    RAISES (RuntimeError): With `Errors.E030` if looking up an entity's
        sentence raises an AttributeError.
    """
    ops = model.ops
    n_sents = model.attrs["n_sents"]
    # All (start_token, end_token) pairs of the batch in doc order, together
    # with the number of pairs each doc contributed.
    spans: List[Tuple[int, int]] = []
    span_counts: List[int] = []
    for doc in docs:
        try:
            sentences = list(doc.sents)
        except ValueError:
            # no sentence info, normal in initialization: mark only the first
            # token as a sentence start and treat the whole doc as one sentence
            for tok in doc:
                tok.is_sent_start = tok.i == 0
            sentences = [doc[:]]
        last_sentence = len(sentences) - 1
        n_before = len(spans)
        for ent in doc.ents:
            try:
                # find the sentence in the list of sentences.
                sent_index = sentences.index(ent.sent)
            except AttributeError:
                # Catch the exception when ent.sent is None and provide a user-friendly warning
                raise RuntimeError(Errors.E030) from None
            # get n previous sentences, if there are any
            start_sentence = max(0, sent_index - n_sents)
            # get n posterior sentences, or as many < n as there are
            end_sentence = min(last_sentence, sent_index + n_sents)
            # save token positions for extraction
            spans.append(
                (sentences[start_sentence].start, sentences[end_sentence].end)
            )
        span_counts.append(len(spans) - n_before)
    if spans:
        flat_spans = ops.asarray2i(spans)
    else:
        # Converting an empty list would give a 1-D array; allocate the 2-D
        # shape explicitly so the result always has two columns.
        flat_spans = ops.alloc2i(0, 2)
    outputs = Ragged(flat_spans, ops.asarray1i(span_counts))
    # because this is just rearranging docs, the backprop does nothing
    return outputs, lambda x: []
@registry.misc("spacy.KBFromFile.v1")
def load_kb(
    kb_path: Path,
) -> Callable[[Vocab], KnowledgeBase]:
    """Return a factory that loads a knowledge base from `kb_path`.

    kb_path (Path): Location passed to the knowledge base's `from_disk`.
    RETURNS (Callable[[Vocab], KnowledgeBase]): A function that takes a Vocab,
        creates an InMemoryLookupKB and loads its contents from `kb_path`.
    """

    def kb_from_file(vocab: Vocab) -> KnowledgeBase:
        """Create an InMemoryLookupKB for `vocab` and load it from disk."""
        # NOTE(review): entity_vector_length=1 looks like a placeholder that
        # `from_disk` is expected to overwrite — confirm.
        knowledge_base = InMemoryLookupKB(vocab, entity_vector_length=1)
        knowledge_base.from_disk(kb_path)
        return knowledge_base

    return kb_from_file
@registry.misc("spacy.EmptyKB.v1")
def empty_kb(
    entity_vector_length: int,
) -> Callable[[Vocab], KnowledgeBase]:
    """Return a factory that creates a new InMemoryLookupKB.

    entity_vector_length (int): Passed to the InMemoryLookupKB constructor.
    RETURNS (Callable[[Vocab], KnowledgeBase]): A function that takes a Vocab
        and returns a newly constructed InMemoryLookupKB for it.
    """

    def empty_kb_factory(vocab: Vocab) -> KnowledgeBase:
        """Construct an InMemoryLookupKB for `vocab`."""
        knowledge_base = InMemoryLookupKB(
            vocab=vocab,
            entity_vector_length=entity_vector_length,
        )
        return knowledge_base

    return empty_kb_factory
@registry.misc("spacy.CandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
    """Return the candidate generator registered as "spacy.CandidateGenerator.v1".

    RETURNS (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): The
        `get_candidates` function.
    """
    candidate_generator = get_candidates
    return candidate_generator
@registry.misc("spacy.CandidateBatchGenerator.v1")
def create_candidates_batch() -> Callable[
    [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
]:
    """Return the batch generator registered as "spacy.CandidateBatchGenerator.v1".

    RETURNS (Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]):
        The `get_candidates_batch` function.
    """
    batch_candidate_generator = get_candidates_batch
    return batch_candidate_generator