Convert batched into doc-wise batched candidate generation.

Raphael Mitsch 2022-10-18 15:31:15 +02:00
parent 2ce6aadda2
commit 7c28424f47
7 changed files with 182 additions and 155 deletions
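The gist of the change: `get_candidates_batch` took one flat batch of mentions, while the new `get_candidates_all` consumes a generator of per-document mention groups and lazily yields one list of candidate lists per document. A minimal sketch of the two calling conventions, mirroring the base-class implementation in the diff below (the `ToyKB` class, its alias table, and string mentions are invented for illustration):

from typing import Generator, Iterable, List

class ToyKB:
    # Invented stand-in for a knowledge base: just the per-mention
    # lookup that both calling conventions are built on.
    def __init__(self, table):
        self.table = table

    def get_candidates(self, mention: str) -> List[str]:
        return self.table.get(mention.lower(), [])

def get_candidates_batch(kb: ToyKB, mentions: Iterable[str]) -> List[List[str]]:
    # Old shape: one flat batch of mentions in, one flat list of
    # candidate lists out.
    return [kb.get_candidates(m) for m in mentions]

def get_candidates_all(
    kb: ToyKB, mentions: Generator[Iterable[str], None, None]
) -> Generator[List[List[str]], None, None]:
    # New shape: mentions grouped per document in, candidate lists
    # grouped per document out, evaluated lazily one doc at a time.
    for doc_mentions in mentions:
        yield [kb.get_candidates(m) for m in doc_mentions]

kb = ToyKB({"boston": ["Q100"], "denver": ["Q16554"]})
docs = [["Boston", "Denver"], ["Boston"]]
print(list(get_candidates_all(kb, (doc for doc in docs))))
# -> [[['Q100'], ['Q16554']], [['Q100']]]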

View File

@@ -946,11 +946,10 @@ class Errors(metaclass=ErrorsWithCodes):
"case pass an empty list for the previously not specified argument to avoid this error.")
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
"{value}.")
E1044 = ("Expected `candidates_batch_size` to be >= 1, but got: {value}")
E1045 = ("Encountered {parent} subclass without `{parent}.{method}` "
E1044 = ("Encountered {parent} subclass without `{parent}.{method}` "
"method in '{name}'. If you want to use this method, make "
"sure it's overwritten on the subclass.")
E1046 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
E1045 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
"knowledge base, use `InMemoryLookupKB`.")

View File

@@ -1,3 +1,3 @@
from .kb import KnowledgeBase
from .kb_in_memory import InMemoryLookupKB
from .candidate import Candidate, get_candidates, get_candidates_batch
from .candidate import Candidate, get_candidates, get_candidates_all

View File

@@ -1,6 +1,6 @@
# cython: infer_types=True, profile=True
from typing import Iterable
from typing import Iterable, Generator
from .kb cimport KnowledgeBase
from ..tokens import Span
@@ -64,11 +64,13 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
return kb.get_candidates(mention)
def get_candidates_batch(kb: KnowledgeBase, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]:
def get_candidates_all(
kb: KnowledgeBase, mentions: Generator[Iterable[Span], None, None]
) -> Generator[Iterable[Iterable[Candidate]], None, None]:
"""
Return candidate entities for the given mentions by fetching appropriate entries from the index.
kb (KnowledgeBase): Knowledge base to query.
mention (Iterable[Span]): Entity mentions for which to identify candidates.
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
mentions (Generator[Iterable[Span]]): Entity mentions per document for which to identify candidates.
RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
"""
return kb.get_candidates_batch(mentions)
return kb.get_candidates_all(mentions)
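A hedged usage sketch against the new signature, assuming the InMemoryLookupKB exported above and spaCy's standard add_entity/add_alias KB API; `entity_` is the candidate accessor used elsewhere in this diff:

import spacy
from spacy.kb import InMemoryLookupKB, get_candidates_all

nlp = spacy.blank("en")
kb = InMemoryLookupKB(vocab=nlp.vocab, entity_vector_length=3)
kb.add_entity(entity="Q100", freq=12, entity_vector=[1.0, 0.0, 0.0])
kb.add_alias(alias="Boston", entities=["Q100"], probabilities=[1.0])

doc = nlp("Boston is a city.")
doc.ents = [doc.char_span(0, 6, label="GPE")]
# One iterable of mentions per document goes in ...
candidates_per_doc = get_candidates_all(kb, (d.ents for d in [doc]))
# ... and one list of candidate lists per document comes out.
for doc_candidates in candidates_per_doc:
    print([[c.entity_ for c in cands] for cands in doc_candidates])
# -> [['Q100']]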

View File

@@ -1,7 +1,7 @@
# cython: infer_types=True, profile=True
from pathlib import Path
from typing import Iterable, Tuple, Union
from typing import Iterable, Tuple, Union, Generator
from cymem.cymem cimport Pool
from .candidate import Candidate
@@ -23,22 +23,24 @@ cdef class KnowledgeBase:
# Make sure abstract KB is not instantiated.
if self.__class__ == KnowledgeBase:
raise TypeError(
Errors.E1046.format(cls_name=self.__class__.__name__)
Errors.E1045.format(cls_name=self.__class__.__name__)
)
self.vocab = vocab
self.entity_vector_length = entity_vector_length
self.mem = Pool()
def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]:
def get_candidates_all(self, mentions: Generator[Iterable[Span], None, None]) -> Generator[Iterable[Iterable[Candidate]], None, None]:
"""
Return candidate entities for specified texts. Each candidate defines the entity, the original alias,
and the prior probability of that alias resolving to that entity.
If no candidate is found for a given text, an empty list is returned.
mentions (Iterable[Span]): Mentions for which to get candidates.
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
mentions (Generator[Iterable[Span]]): Mentions per document for which to get candidates.
RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
"""
return [self.get_candidates(span) for span in mentions]
for doc_mentions in mentions:
yield [self.get_candidates(span) for span in doc_mentions]
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
"""
@@ -49,7 +51,7 @@ cdef class KnowledgeBase:
RETURNS (Iterable[Candidate]): Identified candidates.
"""
raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__)
Errors.E1044.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__)
)
def get_vectors(self, entities: Iterable[str]) -> Iterable[Iterable[float]]:
@@ -67,7 +69,7 @@ cdef class KnowledgeBase:
RETURNS (Iterable[float]): Vector for specified entity.
"""
raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="get_vector", name=self.__name__)
Errors.E1044.format(parent="KnowledgeBase", method="get_vector", name=self.__name__)
)
def to_bytes(self, **kwargs) -> bytes:
@@ -75,7 +77,7 @@ cdef class KnowledgeBase:
RETURNS (bytes): Current state as binary string.
"""
raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="to_bytes", name=self.__name__)
Errors.E1044.format(parent="KnowledgeBase", method="to_bytes", name=self.__name__)
)
def from_bytes(self, bytes_data: bytes, *, exclude: Tuple[str] = tuple()):
@@ -84,7 +86,7 @@ cdef class KnowledgeBase:
exclude (Tuple[str]): Properties to exclude when restoring KB.
"""
raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="from_bytes", name=self.__name__)
Errors.E1044.format(parent="KnowledgeBase", method="from_bytes", name=self.__name__)
)
def to_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None:
@@ -94,7 +96,7 @@ cdef class KnowledgeBase:
exclude (Iterable[str]): List of components to exclude.
"""
raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="to_disk", name=self.__name__)
Errors.E1044.format(parent="KnowledgeBase", method="to_disk", name=self.__name__)
)
def from_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None:
@@ -104,5 +106,5 @@ cdef class KnowledgeBase:
exclude (Iterable[str]): List of components to exclude.
"""
raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__)
Errors.E1044.format(parent="KnowledgeBase", method="from_disk", name=self.__name__)
)
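Because the doc-wise loop now lives on the abstract base class, a custom KB only has to override the per-mention hook. A sketch under that assumption (the subclass, its alias table, and the constructor pass-through are invented; the base __init__ arguments follow the hunk above):

from spacy.kb import KnowledgeBase

class StaticKB(KnowledgeBase):
    # Invented minimal subclass: only get_candidates is overridden; the
    # inherited get_candidates_all() yields these results per document.
    def __init__(self, vocab, entity_vector_length, table):
        super().__init__(vocab, entity_vector_length)
        self.table = table  # invented, e.g. {"boston": [<Candidate>, ...]}

    def get_candidates(self, mention):
        return self.table.get(mention.text.lower(), [])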

View File

@@ -1,12 +1,12 @@
from pathlib import Path
from typing import Optional, Callable, Iterable, List, Tuple
from typing import Optional, Callable, Iterable, List, Tuple, Generator
from thinc.types import Floats2d
from thinc.api import chain, list2ragged, reduce_mean, residual
from thinc.api import Model, Maxout, Linear, tuplify, Ragged
from ...util import registry
from ...kb import KnowledgeBase, InMemoryLookupKB
from ...kb import Candidate, get_candidates, get_candidates_batch
from ...kb import Candidate, get_candidates, get_candidates_all
from ...vocab import Vocab
from ...tokens import Span, Doc
from ..extract_spans import extract_spans
@@ -105,8 +105,8 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
return get_candidates
@registry.misc("spacy.CandidateBatchGenerator.v1")
def create_candidates_batch() -> Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
@registry.misc("spacy.CandidateAllGenerator.v1")
def create_candidates_all() -> Callable[
[KnowledgeBase, Generator[Iterable[Span], None, None]], Generator[Iterable[Iterable[Candidate]], None, None]
]:
return get_candidates_batch
return get_candidates_all
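To swap in a custom generator, register a callable with the same shape under @registry.misc. A sketch assuming only the types above (the registry name and the body are invented; it simply defers to the KB's own doc-wise default):

from typing import Callable, Generator, Iterable
from spacy.kb import Candidate, KnowledgeBase
from spacy.tokens import Span
from spacy.util import registry

@registry.misc("my.CustomCandidateAllGenerator.v1")  # invented name
def create_custom_candidates_all() -> Callable[
    [KnowledgeBase, Generator[Iterable[Span], None, None]],
    Generator[Iterable[Iterable[Candidate]], None, None],
]:
    def get_all(kb: KnowledgeBase, mentions_per_doc):
        # Invented body: fall back to the KB's inherited doc-wise loop.
        yield from kb.get_candidates_all(mentions_per_doc)
    return get_all

The pipe config can then reference it via {"@misc": "my.CustomCandidateAllGenerator.v1"}.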

View File

@@ -1,4 +1,4 @@
from typing import Optional, Iterable, Callable, Dict, Union, List, Any
from typing import Optional, Iterable, Callable, Dict, Union, List, Any, Generator
from thinc.types import Floats2d
from pathlib import Path
from itertools import islice
@@ -53,11 +53,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
"incl_context": True,
"entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
"get_candidates_all": {"@misc": "spacy.CandidateAllGenerator.v1"},
"overwrite": True,
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
"use_gold_ents": True,
"candidates_batch_size": 1,
"candidates_doc_mode": False,
"threshold": None,
},
default_score_weights={
@@ -77,13 +77,14 @@ def make_entity_linker(
incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_batch: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
get_candidates_all: Callable[
[KnowledgeBase, Generator[Iterable[Span], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
],
overwrite: bool,
scorer: Optional[Callable],
use_gold_ents: bool,
candidates_batch_size: int,
candidates_doc_mode: bool,
threshold: Optional[float] = None,
):
"""Construct an EntityLinker component.
@@ -98,13 +99,18 @@
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
get_candidates_batch (
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
get_candidates_all (
Callable[
[KnowledgeBase, Generator[Iterable[Span], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
]): Function that produces candidate lists per document, given a certain knowledge base and the entity
mentions grouped per document.
scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations.
candidates_batch_size (int): Size of batches for entity candidate generation.
candidates_doc_mode (bool): Whether to run candidate generation in doc mode, i.e. to pass a generator yielding all
entity mentions per document to the candidate generator, which is then invoked only once. If False, the
candidate generator is called once per entity.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
prediction is discarded. If None, predictions are not filtered by any threshold.
"""
@@ -134,11 +140,11 @@
incl_context=incl_context,
entity_vector_length=entity_vector_length,
get_candidates=get_candidates,
get_candidates_batch=get_candidates_batch,
get_candidates_all=get_candidates_all,
overwrite=overwrite,
scorer=scorer,
use_gold_ents=use_gold_ents,
candidates_batch_size=candidates_batch_size,
candidates_doc_mode=candidates_doc_mode,
threshold=threshold,
)
@@ -172,13 +178,14 @@ class EntityLinker(TrainablePipe):
incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_batch: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
get_candidates_all: Callable[
[KnowledgeBase, Generator[Iterable[Span], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
],
overwrite: bool = BACKWARD_OVERWRITE,
scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool,
candidates_batch_size: int,
candidates_doc_mode: bool,
threshold: Optional[float] = None,
) -> None:
"""Initialize an entity linker.
@@ -194,14 +201,18 @@
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
get_candidates_batch (
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]],
Iterable[Candidate]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
get_candidates_all (
Callable[
[KnowledgeBase, Generator[Iterable[Span], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
]): Function that produces candidate lists per document, given a certain knowledge base and the entity
mentions grouped per document.
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations.
candidates_batch_size (int): Size of batches for entity candidate generation.
candidates_doc_mode (bool): Whether to run candidate generation in doc mode, i.e. to pass a generator yielding all
entity mentions per document to the candidate generator, which is then invoked only once. If False, the
candidate generator is called once per entity.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the
threshold, prediction is discarded. If None, predictions are not filtered by any threshold.
DOCS: https://spacy.io/api/entitylinker#init
@@ -224,7 +235,7 @@ class EntityLinker(TrainablePipe):
self.incl_prior = incl_prior
self.incl_context = incl_context
self.get_candidates = get_candidates
self.get_candidates_batch = get_candidates_batch
self.get_candidates_all = get_candidates_all
self.cfg: Dict[str, Any] = {"overwrite": overwrite}
self.distance = CosineDistance(normalize=False)
# how many neighbour sentences to take into account
@@ -232,12 +243,9 @@
self.kb = empty_kb(entity_vector_length)(self.vocab)
self.scorer = scorer
self.use_gold_ents = use_gold_ents
self.candidates_batch_size = candidates_batch_size
self.candidates_doc_mode = candidates_doc_mode
self.threshold = threshold
if candidates_batch_size < 1:
raise ValueError(Errors.E1044)
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
"""Define the KB of this pipe by providing a function that will
create it using this object's vocab."""
@@ -440,96 +448,98 @@
return final_kb_ids
if isinstance(docs, Doc):
docs = [docs]
for i, doc in enumerate(docs):
# Determine which entities are to be ignored due to labels_discard.
valid_ent_idx_per_doc = (
[
idx
for idx in range(len(doc.ents))
if doc.ents[idx].label_ not in self.labels_discard
]
for doc in docs if len(doc.ents)
)
# Call candidate generator.
if self.candidates_doc_mode:
all_ent_cands = self.get_candidates_all(
self.kb,
([doc.ents[idx] for idx in next(valid_ent_idx_per_doc)] for doc in docs if len(doc.ents))
)
else:
# Alternative: collect entities the old-fashioned way - by retrieving entities individually.
all_ent_cands = (
[self.get_candidates(self.kb, doc.ents[idx]) for idx in next(valid_ent_idx_per_doc)]
for doc in docs if len(doc.ents)
)
for doc_idx, doc in enumerate(docs):
if len(doc) == 0:
continue
sentences = [s for s in doc.sents]
doc_ent_cands = list(next(all_ent_cands)) if len(doc.ents) else []
# Loop over entities in batches.
for ent_idx in range(0, len(doc.ents), self.candidates_batch_size):
ent_batch = doc.ents[ent_idx : ent_idx + self.candidates_batch_size]
# Looping over candidate entities for this doc. (TODO: rewrite)
for ent_cand_idx, ent in enumerate(doc.ents):
sent_index = sentences.index(ent.sent)
assert sent_index >= 0
# Look up candidate entities.
valid_ent_idx = [
idx
for idx in range(len(ent_batch))
if ent_batch[idx].label_ not in self.labels_discard
]
batch_candidates = list(
self.get_candidates_batch(
self.kb, [ent_batch[idx] for idx in valid_ent_idx]
if self.incl_context:
# get n_neighbour sentences, clipped to the length of the document
start_sentence = max(0, sent_index - self.n_sents)
end_sentence = min(
len(sentences) - 1, sent_index + self.n_sents
)
if self.candidates_batch_size > 1
else [
self.get_candidates(self.kb, ent_batch[idx])
for idx in valid_ent_idx
]
)
# Looping through each entity in batch (TODO: rewrite)
for j, ent in enumerate(ent_batch):
sent_index = sentences.index(ent.sent)
assert sent_index >= 0
if self.incl_context:
# get n_neighbour sentences, clipped to the length of the document
start_sentence = max(0, sent_index - self.n_sents)
end_sentence = min(
len(sentences) - 1, sent_index + self.n_sents
)
start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc()
# currently, the context is the same for each entity in a sentence (should be refined)
sentence_encoding = self.model.predict([sent_doc])[0]
sentence_encoding_t = sentence_encoding.T
sentence_norm = xp.linalg.norm(sentence_encoding_t)
entity_count += 1
if ent.label_ in self.labels_discard:
# ignoring this entity - setting to NIL
start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc()
# currently, the context is the same for each entity in a sentence (should be refined)
sentence_encoding = self.model.predict([sent_doc])[0]
sentence_encoding_t = sentence_encoding.T
sentence_norm = xp.linalg.norm(sentence_encoding_t)
entity_count += 1
if ent.label_ in self.labels_discard:
# ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL)
else:
candidates = list(doc_ent_cands[ent_cand_idx])
if not candidates:
# no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL)
elif len(candidates) == 1 and self.threshold is None:
# shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_)
else:
candidates = list(batch_candidates[j])
if not candidates:
# no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL)
elif len(candidates) == 1 and self.threshold is None:
# shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_)
else:
random.shuffle(candidates)
# set all prior probabilities to 0 if incl_prior=False
prior_probs = xp.asarray([c.prior_prob for c in candidates])
if not self.incl_prior:
prior_probs = xp.asarray([0.0 for _ in candidates])
scores = prior_probs
# add in similarity from the context
if self.incl_context:
entity_encodings = xp.asarray(
[c.entity_vector for c in candidates]
)
entity_norm = xp.linalg.norm(entity_encodings, axis=1)
if len(entity_encodings) != len(prior_probs):
raise RuntimeError(
Errors.E147.format(
method="predict",
msg="vectors not of equal length",
)
)
# cosine similarity
sims = xp.dot(entity_encodings, sentence_encoding_t) / (
sentence_norm * entity_norm
)
if sims.shape != prior_probs.shape:
raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs * sims)
final_kb_ids.append(
candidates[scores.argmax().item()].entity_
if self.threshold is None
or scores.max() >= self.threshold
else EntityLinker.NIL
random.shuffle(candidates)
# set all prior probabilities to 0 if incl_prior=False
prior_probs = xp.asarray([c.prior_prob for c in candidates])
if not self.incl_prior:
prior_probs = xp.asarray([0.0 for _ in candidates])
scores = prior_probs
# add in similarity from the context
if self.incl_context:
entity_encodings = xp.asarray(
[c.entity_vector for c in candidates]
)
entity_norm = xp.linalg.norm(entity_encodings, axis=1)
if len(entity_encodings) != len(prior_probs):
raise RuntimeError(
Errors.E147.format(
method="predict",
msg="vectors not of equal length",
)
)
# cosine similarity
sims = xp.dot(entity_encodings, sentence_encoding_t) / (
sentence_norm * entity_norm
)
if sims.shape != prior_probs.shape:
raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs * sims)
final_kb_ids.append(
candidates[scores.argmax().item()].entity_
if self.threshold is None
or scores.max() >= self.threshold
else EntityLinker.NIL
)
if not (len(final_kb_ids) == entity_count):
err = Errors.E147.format(
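The score combination `prior_probs + sims - (prior_probs * sims)` in the loop above is a probabilistic OR of the alias prior and the context similarity; a toy numeric check with invented values:

import numpy as np

prior_probs = np.asarray([0.6, 0.1])  # invented alias priors
sims = np.asarray([0.2, 0.9])         # invented cosine similarities
scores = prior_probs + sims - (prior_probs * sims)
print(scores)  # [0.68 0.91] -> argmax picks the second candidate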

View File

@@ -1,4 +1,4 @@
from typing import Callable, Iterable, Dict, Any
from typing import Callable, Iterable, Dict, Any, Generator
import pytest
from numpy.testing import assert_equal
@@ -497,11 +497,14 @@ def test_el_pipe_configuration(nlp):
assert doc[1].ent_kb_id_ == ""
assert doc[2].ent_kb_id_ == "Q2"
# Replace the pipe with a new one with a different candidate generator.
def get_lowercased_candidates(kb, span):
return kb.get_alias_candidates(span.text.lower())
def get_lowercased_candidates_batch(kb, spans):
return [get_lowercased_candidates(kb, span) for span in spans]
def get_lowercased_candidates_all(kb, spans_per_doc):
for doc_spans in spans_per_doc:
yield [get_lowercased_candidates(kb, span) for span in doc_spans]
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
def create_candidates() -> Callable[
@@ -509,29 +512,39 @@ def test_el_pipe_configuration(nlp):
]:
return get_lowercased_candidates
@registry.misc("spacy.LowercaseCandidateBatchGenerator.v1")
@registry.misc("spacy.LowercaseCandidateAllGenerator.v1")
def create_candidates_batch() -> Callable[
[InMemoryLookupKB, Iterable["Span"]], Iterable[Iterable[Candidate]]
[InMemoryLookupKB, Generator[Iterable["Span"], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
]:
return get_lowercased_candidates_batch
return get_lowercased_candidates_all
# replace the pipe with a new one with a different candidate generator
entity_linker = nlp.replace_pipe(
"entity_linker",
"entity_linker",
config={
"incl_context": False,
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
"get_candidates_batch": {
"@misc": "spacy.LowercaseCandidateBatchGenerator.v1"
def test_reconfigured_el(candidates_doc_mode: bool, doc_text: str) -> None:
"""Test reconfigured EL for correct results.
candidates_doc_mode (bool): Whether to enable doc-wise candidate generation in the pipe config.
doc_text (str): Text to infer.
"""
_entity_linker = nlp.replace_pipe(
"entity_linker",
"entity_linker",
config={
"incl_context": False,
"candidates_doc_mode": candidates_doc_mode,
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
"get_candidates_all": {
"@misc": "spacy.LowercaseCandidateAllGenerator.v1"
},
},
},
)
entity_linker.set_kb(create_kb)
doc = nlp(text)
assert doc[0].ent_kb_id_ == "Q2"
assert doc[1].ent_kb_id_ == ""
assert doc[2].ent_kb_id_ == "Q2"
)
_entity_linker.set_kb(create_kb)
_doc = nlp(doc_text)
assert _doc[0].ent_kb_id_ == "Q2"
assert _doc[1].ent_kb_id_ == ""
assert _doc[2].ent_kb_id_ == "Q2"
# Test individual and doc-wise candidate generation.
test_reconfigured_el(False, text)
test_reconfigured_el(True, text)
def test_nel_nsents(nlp):
@@ -670,6 +683,7 @@ def test_preserving_links_asdoc(nlp):
assert s_ent.kb_id_ == orig_kb_id
def test_preserving_links_ents(nlp):
"""Test that doc.ents preserves KB annotations"""
text = "She lives in Boston. He lives in Denver."