diff --git a/setup.py b/setup.py index 243554c7a..79bdcba8d 100755 --- a/setup.py +++ b/setup.py @@ -30,7 +30,6 @@ MOD_NAMES = [ "spacy.lexeme", "spacy.vocab", "spacy.attrs", - "spacy.kb.candidate", "spacy.kb.kb", "spacy.kb.kb_in_memory", "spacy.ml.parser_model", diff --git a/spacy/errors.py b/spacy/errors.py index 40cfa8d92..3dce0ef2c 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -957,11 +957,10 @@ class Errors(metaclass=ErrorsWithCodes): "case pass an empty list for the previously not specified argument to avoid this error.") E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got " "{value}.") - E1044 = ("Expected `candidates_batch_size` to be >= 1, but got: {value}") - E1045 = ("Encountered {parent} subclass without `{parent}.{method}` " + E1044 = ("Encountered {parent} subclass without `{parent}.{method}` " "method in '{name}'. If you want to use this method, make " "sure it's overwritten on the subclass.") - E1046 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default " + E1045 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default " "knowledge base, use `InMemoryLookupKB`.") E1047 = ("`find_threshold()` only supports components with a `scorer` attribute.") E1048 = ("Got '{unexpected}' as console progress bar type, but expected one of the following: {expected}") diff --git a/spacy/kb/__init__.py b/spacy/kb/__init__.py index 1d70a9b34..6dd4a3222 100644 --- a/spacy/kb/__init__.py +++ b/spacy/kb/__init__.py @@ -1,3 +1,5 @@ from .kb import KnowledgeBase from .kb_in_memory import InMemoryLookupKB -from .candidate import Candidate, get_candidates, get_candidates_batch +from .candidate import Candidate + +__all__ = ["KnowledgeBase", "InMemoryLookupKB", "Candidate"] diff --git a/spacy/kb/candidate.pxd b/spacy/kb/candidate.pxd deleted file mode 100644 index 942ce9dd0..000000000 --- a/spacy/kb/candidate.pxd +++ /dev/null @@ -1,12 +0,0 @@ -from .kb cimport KnowledgeBase -from libcpp.vector cimport vector -from ..typedefs cimport hash_t - -# Object used by the Entity Linker that summarizes one entity-alias candidate combination. -cdef class Candidate: - cdef readonly KnowledgeBase kb - cdef hash_t entity_hash - cdef float entity_freq - cdef vector[float] entity_vector - cdef hash_t alias_hash - cdef float prior_prob diff --git a/spacy/kb/candidate.py b/spacy/kb/candidate.py new file mode 100644 index 000000000..190792fbe --- /dev/null +++ b/spacy/kb/candidate.py @@ -0,0 +1,109 @@ +import abc +from typing import List, Union, Callable + + +class BaseCandidate(abc.ABC): + """A `BaseCandidate` object refers to a textual mention (`alias`) that may or may not be resolved + to a specific `entity_id` from a Knowledge Base. This will be used as input for the entity_id linking + algorithm which will disambiguate the various candidates to the correct one. + Each candidate (alias, entity_id) pair is assigned a certain prior probability. + + DOCS: https://spacy.io/api/kb/#candidate-init + """ + + def __init__( + self, mention: str, entity_id: Union[int, str], entity_vector: List[float] + ): + """Create new instance of `Candidate`. Note: has to be a sub-class, otherwise error will be raised. + mention (str): Mention text for this candidate. + entity_id (Union[int, str]): Unique entity ID. + entity_vector (List[float]): Entity embedding. + """ + self._mention = mention + self._entity_id = entity_id + self._entity_vector = entity_vector + + @property + def entity(self) -> Union[int, str]: + """RETURNS (Union[int, str]): Entity ID.""" + return self._entity_id + + @property + @abc.abstractmethod + def entity_(self) -> str: + """RETURNS (str): Entity name.""" + + @property + def mention(self) -> str: + """RETURNS (str): Mention.""" + return self._mention + + @property + def entity_vector(self) -> List[float]: + """RETURNS (List[float]): Entity vector.""" + return self._entity_vector + + +class Candidate(BaseCandidate): + """`Candidate` for InMemoryLookupKBCandidate.""" + + def __init__( + self, + retrieve_string_from_hash: Callable[[int], str], + entity_hash: int, + entity_freq: int, + entity_vector: List[float], + alias_hash: int, + prior_prob: float, + ): + """ + retrieve_string_from_hash (Callable[[int], str]): Callable retrieveing entity name from provided entity/vocab + hash. + entity_hash (str): Hashed entity name /ID. + entity_freq (int): Entity frequency in KB corpus. + entity_vector (List[float]): Entity embedding. + alias_hash (int): Hashed alias. + prior_prob (float): Prior probability of entity for this mention - i.e. the probability that, independent of + the context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In + cases in which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus + doesn't) it might be better to eschew this information and always supply the same value. + """ + super().__init__( + mention=retrieve_string_from_hash(alias_hash), + entity_id=entity_hash, + entity_vector=entity_vector, + ) + self._retrieve_string_from_hash = retrieve_string_from_hash + self._entity_hash = entity_hash + self._entity_freq = entity_freq + self._alias_hash = alias_hash + self._prior_prob = prior_prob + + @property + def entity(self) -> int: + """RETURNS (int): hash of the entity_id's KB ID/name""" + return self._entity_hash + + @property + def entity_(self) -> str: + """RETURNS (str): ID/name of this entity_id in the KB""" + return self._retrieve_string_from_hash(self._entity_hash) + + @property + def alias(self) -> int: + """RETURNS (int): hash of the alias""" + return self._alias_hash + + @property + def alias_(self) -> str: + """RETURNS (str): ID of the original alias""" + return self._retrieve_string_from_hash(self._alias_hash) + + @property + def entity_freq(self) -> float: + return self._entity_freq + + @property + def prior_prob(self) -> float: + """RETURNS (List[float]): Entity vector.""" + return self._prior_prob diff --git a/spacy/kb/candidate.pyx b/spacy/kb/candidate.pyx deleted file mode 100644 index c89efeb03..000000000 --- a/spacy/kb/candidate.pyx +++ /dev/null @@ -1,74 +0,0 @@ -# cython: infer_types=True, profile=True - -from typing import Iterable -from .kb cimport KnowledgeBase -from ..tokens import Span - -cdef class Candidate: - """A `Candidate` object refers to a textual mention (`alias`) that may or may not be resolved - to a specific `entity` from a Knowledge Base. This will be used as input for the entity linking - algorithm which will disambiguate the various candidates to the correct one. - Each candidate (alias, entity) pair is assigned a certain prior probability. - - DOCS: https://spacy.io/api/kb/#candidate-init - """ - - def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob): - self.kb = kb - self.entity_hash = entity_hash - self.entity_freq = entity_freq - self.entity_vector = entity_vector - self.alias_hash = alias_hash - self.prior_prob = prior_prob - - @property - def entity(self) -> int: - """RETURNS (uint64): hash of the entity's KB ID/name""" - return self.entity_hash - - @property - def entity_(self) -> str: - """RETURNS (str): ID/name of this entity in the KB""" - return self.kb.vocab.strings[self.entity_hash] - - @property - def alias(self) -> int: - """RETURNS (uint64): hash of the alias""" - return self.alias_hash - - @property - def alias_(self) -> str: - """RETURNS (str): ID of the original alias""" - return self.kb.vocab.strings[self.alias_hash] - - @property - def entity_freq(self) -> float: - return self.entity_freq - - @property - def entity_vector(self) -> Iterable[float]: - return self.entity_vector - - @property - def prior_prob(self) -> float: - return self.prior_prob - - -def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: - """ - Return candidate entities for a given mention and fetching appropriate entries from the index. - kb (KnowledgeBase): Knowledge base to query. - mention (Span): Entity mention for which to identify candidates. - RETURNS (Iterable[Candidate]): Identified candidates. - """ - return kb.get_candidates(mention) - - -def get_candidates_batch(kb: KnowledgeBase, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]: - """ - Return candidate entities for the given mentions and fetching appropriate entries from the index. - kb (KnowledgeBase): Knowledge base to query. - mention (Iterable[Span]): Entity mentions for which to identify candidates. - RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. - """ - return kb.get_candidates_batch(mentions) diff --git a/spacy/kb/kb.pxd b/spacy/kb/kb.pxd index 1adeef8ae..7261287eb 100644 --- a/spacy/kb/kb.pxd +++ b/spacy/kb/kb.pxd @@ -7,4 +7,4 @@ from ..vocab cimport Vocab cdef class KnowledgeBase: cdef Pool mem cdef readonly Vocab vocab - cdef readonly int64_t entity_vector_length + cdef public int64_t entity_vector_length diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index ce4bc0138..bc8d54761 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -1,11 +1,11 @@ # cython: infer_types=True, profile=True from pathlib import Path -from typing import Iterable, Tuple, Union +from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Callable from cymem.cymem cimport Pool from .candidate import Candidate -from ..tokens import Span +from ..tokens import Span, SpanGroup, Doc from ..util import SimpleFrozenList from ..errors import Errors @@ -18,38 +18,52 @@ cdef class KnowledgeBase: DOCS: https://spacy.io/api/kb """ + _KBType = TypeVar("_KBType", bound=KnowledgeBase) + def __init__(self, vocab: Vocab, entity_vector_length: int): """Create a KnowledgeBase.""" # Make sure abstract KB is not instantiated. if self.__class__ == KnowledgeBase: raise TypeError( - Errors.E1046.format(cls_name=self.__class__.__name__) + Errors.E1045.format(cls_name=self.__class__.__name__) ) self.vocab = vocab self.entity_vector_length = entity_vector_length self.mem = Pool() - def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]: + def get_candidates_all(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[Iterable[Candidate]]]: """ - Return candidate entities for specified texts. Each candidate defines the entity, the original alias, - and the prior probability of that alias resolving to that entity. - If no candidate is found for a given text, an empty list is returned. - mentions (Iterable[Span]): Mentions for which to get candidates. - RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. + Return candidate entities for mentions stored in `ent` attribute in passed docs. Each candidate defines the + entity, the original alias, and the prior probability of that alias resolving to that entity. + If no candidate is found for a given mention, an empty list is returned. + mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance. + RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document. + """ + for doc_mentions in mentions: + yield [self.get_candidates(ent_span) for ent_span in doc_mentions] + + @staticmethod + def get_ents_as_spangroup(doc: Doc, extractor: Union[str, Callable[[Iterable[Span]], Doc]] = "ent") -> SpanGroup: + """ + Fetch entities from doc and returns them as a SpanGroup ready to be used in + `KnowledgeBase.get_candidates_all()`. + doc (Doc): Doc whose entities should be fetched. + extractor (Union[str, Callable[[Iterable[Span]], Doc]]): Defines how to retrieve object holding spans + used to describe entities. This can be a key referring to a property of the doc instance (e.g. " """ - return [self.get_candidates(span) for span in mentions] def get_candidates(self, mention: Span) -> Iterable[Candidate]: """ Return candidate entities for specified text. Each candidate defines the entity, the original alias, and the prior probability of that alias resolving to that entity. If the no candidate is found for a given text, an empty list is returned. + Note that doc is not utilized for further context in this implementation. mention (Span): Mention for which to get candidates. RETURNS (Iterable[Candidate]): Identified candidates. """ raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__) ) def get_vectors(self, entities: Iterable[str]) -> Iterable[Iterable[float]]: @@ -67,7 +81,7 @@ cdef class KnowledgeBase: RETURNS (Iterable[float]): Vector for specified entity. """ raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="get_vector", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="get_vector", name=self.__name__) ) def to_bytes(self, **kwargs) -> bytes: @@ -75,7 +89,7 @@ cdef class KnowledgeBase: RETURNS (bytes): Current state as binary string. """ raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="to_bytes", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="to_bytes", name=self.__name__) ) def from_bytes(self, bytes_data: bytes, *, exclude: Tuple[str] = tuple()): @@ -84,25 +98,45 @@ cdef class KnowledgeBase: exclude (Tuple[str]): Properties to exclude when restoring KB. """ raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="from_bytes", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="from_bytes", name=self.__name__) ) def to_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None: - """ - Write KnowledgeBase content to disk. + """Write KnowledgeBase content to disk. path (Union[str, Path]): Target file path. exclude (Iterable[str]): List of components to exclude. """ raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="to_disk", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="to_disk", name=self.__name__) ) def from_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None: - """ - Load KnowledgeBase content from disk. + """Load KnowledgeBase content from disk. path (Union[str, Path]): Target file path. exclude (Iterable[str]): List of components to exclude. """ raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) + ) + + @classmethod + def generate_from_disk( + cls: Type[_KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList() + ) -> _KBType: + """ + Factory method for generating KnowledgeBase subclass instance from file. + path (Union[str, Path]): Target file path. + exclude (Iterable[str]): List of components to exclude. + return (_KBType): Instance of KnowledgeBase subclass generated from file. + """ + raise NotImplementedError( + Errors.E1044.format(parent="KnowledgeBase", method="generate_from_disk", name=cls.__name__) + ) + + def __len__(self) -> int: + """Returns number of entities in the KnowledgeBase. + RETURNS (int): Number of entities in the KnowledgeBase. + """ + raise NotImplementedError( + Errors.E1044.format(parent="KnowledgeBase", method="__len__", name=self.__name__) ) diff --git a/spacy/kb/kb_in_memory.pyx b/spacy/kb/kb_in_memory.pyx index 2a74d047b..3a03ed53e 100644 --- a/spacy/kb/kb_in_memory.pyx +++ b/spacy/kb/kb_in_memory.pyx @@ -1,5 +1,5 @@ # cython: infer_types=True, profile=True -from typing import Iterable, Callable, Dict, Any, Union +from typing import Iterable, Callable, Dict, Any, Union, Optional import srsly from preshed.maps cimport PreshMap @@ -11,7 +11,7 @@ from libcpp.vector cimport vector from pathlib import Path import warnings -from ..tokens import Span +from ..tokens import Span, Doc from ..typedefs cimport hash_t from ..errors import Errors, Warnings from .. import util @@ -49,6 +49,14 @@ cdef class InMemoryLookupKB(KnowledgeBase): def is_empty(self): return len(self) == 0 + @classmethod + def generate_from_disk( + cls, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList() + ) -> "InMemoryLookupKB": + kb = InMemoryLookupKB(vocab=Vocab(strings=["."]), entity_vector_length=1) + kb.from_disk(path) + return kb + def __len__(self): return self.get_size_entities() @@ -227,7 +235,7 @@ cdef class InMemoryLookupKB(KnowledgeBase): self._aliases_table[alias_index] = alias_entry def get_candidates(self, mention: Span) -> Iterable[Candidate]: - return self.get_alias_candidates(mention.text) # type: ignore + return self.get_alias_candidates(mention.text) def get_alias_candidates(self, str alias) -> Iterable[Candidate]: """ @@ -241,14 +249,18 @@ cdef class InMemoryLookupKB(KnowledgeBase): alias_index = self._alias_index.get(alias_hash) alias_entry = self._aliases_table[alias_index] - return [Candidate(kb=self, - entity_hash=self._entries[entry_index].entity_hash, - entity_freq=self._entries[entry_index].freq, - entity_vector=self._vectors_table[self._entries[entry_index].vector_index], - alias_hash=alias_hash, - prior_prob=prior_prob) - for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs) - if entry_index != 0] + return [ + Candidate( + retrieve_string_from_hash=self.vocab.strings.__getitem__, + entity_hash=self._entries[entry_index].entity_hash, + entity_freq=self._entries[entry_index].freq, + entity_vector=self._vectors_table[self._entries[entry_index].vector_index], + alias_hash=alias_hash, + prior_prob=prior_prob + ) + for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs) + if entry_index != 0 + ] def get_vector(self, str entity): cdef hash_t entity_hash = self.vocab.strings[entity] diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index 7332ca199..b5f455cdc 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -1,14 +1,14 @@ from pathlib import Path -from typing import Optional, Callable, Iterable, List, Tuple +from typing import Optional, Callable, Iterable, List, Tuple, Iterator from thinc.types import Floats2d from thinc.api import chain, list2ragged, reduce_mean, residual from thinc.api import Model, Maxout, Linear, tuplify, Ragged from ...util import registry from ...kb import KnowledgeBase, InMemoryLookupKB -from ...kb import Candidate, get_candidates, get_candidates_batch +from ...kb import Candidate from ...vocab import Vocab -from ...tokens import Span, Doc +from ...tokens import Span, Doc, SpanGroup from ..extract_spans import extract_spans from ...errors import Errors @@ -89,14 +89,6 @@ def load_kb( return kb_from_file -@registry.misc("spacy.EmptyKB.v2") -def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]: - def empty_kb_factory(vocab: Vocab, entity_vector_length: int): - return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length) - - return empty_kb_factory - - @registry.misc("spacy.EmptyKB.v1") def empty_kb( entity_vector_length: int, @@ -107,13 +99,44 @@ def empty_kb( return empty_kb_factory +@registry.misc("spacy.EmptyKB.v2") +def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]: + def empty_kb_factory(vocab: Vocab, entity_vector_length: int): + return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length) + + return empty_kb_factory + + +def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: + """ + Return candidate entities for a given mention and fetching appropriate entries from the index. + kb (KnowledgeBase): Knowledge base to query. + mention (Span): Entity mention for which to identify candidates. + RETURNS (Iterable[Candidate]): Identified candidates. + """ + return kb.get_candidates(mention) + + +def get_candidates_all( + kb: KnowledgeBase, mentions: Iterator[SpanGroup] +) -> Iterator[Iterable[Iterable[Candidate]]]: + """ + Return candidate entities for the given mentions and fetching appropriate entries from the index. + kb (KnowledgeBase): Knowledge base to query. + mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance. + RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document. + """ + return kb.get_candidates_all(mentions) + + @registry.misc("spacy.CandidateGenerator.v1") def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: return get_candidates -@registry.misc("spacy.CandidateBatchGenerator.v1") -def create_candidates_batch() -> Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] +@registry.misc("spacy.CandidateAllGenerator.v1") +def create_candidates_all() -> Callable[ + [KnowledgeBase, Iterator[SpanGroup]], + Iterator[Iterable[Iterable[Candidate]]], ]: - return get_candidates_batch + return get_candidates_all diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 76ccc3247..2cbf45320 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -1,4 +1,13 @@ -from typing import Optional, Iterable, Callable, Dict, Union, List, Any +from typing import ( + Optional, + Iterable, + Callable, + Dict, + Union, + List, + Any, + Iterator, +) from thinc.types import Floats2d from pathlib import Path from itertools import islice @@ -8,8 +17,7 @@ from thinc.api import CosineDistance, Model, Optimizer, Config from thinc.api import set_dropout_rate from ..kb import KnowledgeBase, Candidate -from ..ml import empty_kb -from ..tokens import Doc, Span +from ..tokens import Doc, Span, SpanGroup from .pipe import deserialize_config from .legacy.entity_linker import EntityLinker_v1 from .trainable_pipe import TrainablePipe @@ -53,12 +61,12 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] "incl_context": True, "entity_vector_length": 64, "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, - "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"}, + "get_candidates_all": {"@misc": "spacy.CandidateAllGenerator.v1"}, "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"}, "overwrite": True, "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, "use_gold_ents": True, - "candidates_batch_size": 1, + "candidates_doc_mode": False, "threshold": None, }, default_score_weights={ @@ -78,14 +86,14 @@ def make_entity_linker( incl_context: bool, entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], - get_candidates_batch: Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] + get_candidates_all: Callable[ + [KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]] ], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool, scorer: Optional[Callable], use_gold_ents: bool, - candidates_batch_size: int, + candidates_doc_mode: bool, threshold: Optional[float] = None, ): """Construct an EntityLinker component. @@ -98,16 +106,18 @@ def make_entity_linker( incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_context (bool): Whether or not to include the local context in the model. entity_vector_length (int): Size of encoding vectors in the KB. - get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that - produces a list of candidates, given a certain knowledge base and a textual mention. - get_candidates_batch ( - Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] - ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. + get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list of + candidates, given a certain knowledge base and a textual mention. + get_candidates_all (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]): + Function producing a list of candidates per document, given a certain knowledge base and several textual + documents with textual mentions. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. scorer (Optional[Callable]): The scoring method. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another component must provide entity annotations. - candidates_batch_size (int): Size of batches for entity candidate generation. + candidates_doc_mode (bool): Whether or not to operate candidate generation in doc mode, i.e. to provide a generator + yielding entities per document (candidate generator callable is called only once in this case). If False, + the candidate generator is called once per entity. threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold, prediction is discarded. If None, predictions are not filtered by any threshold. """ @@ -137,12 +147,12 @@ def make_entity_linker( incl_context=incl_context, entity_vector_length=entity_vector_length, get_candidates=get_candidates, - get_candidates_batch=get_candidates_batch, + get_candidates_all=get_candidates_all, generate_empty_kb=generate_empty_kb, overwrite=overwrite, scorer=scorer, use_gold_ents=use_gold_ents, - candidates_batch_size=candidates_batch_size, + candidates_doc_mode=candidates_doc_mode, threshold=threshold, ) @@ -176,14 +186,15 @@ class EntityLinker(TrainablePipe): incl_context: bool, entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], - get_candidates_batch: Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] + get_candidates_all: Callable[ + [KnowledgeBase, Iterator[SpanGroup]], + Iterator[Iterable[Iterable[Candidate]]], ], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool = BACKWARD_OVERWRITE, scorer: Optional[Callable] = entity_linker_score, use_gold_ents: bool, - candidates_batch_size: int, + candidates_doc_mode: bool, threshold: Optional[float] = None, ) -> None: """Initialize an entity linker. @@ -197,17 +208,18 @@ class EntityLinker(TrainablePipe): incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_context (bool): Whether or not to include the local context in the model. entity_vector_length (int): Size of encoding vectors in the KB. - get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that - produces a list of candidates, given a certain knowledge base and a textual mention. - get_candidates_batch ( - Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], - Iterable[Candidate]] - ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. + get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list + of candidates, given a certain knowledge base and a textual mention. + get_candidates_all (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]): + Function producing a list of candidates per document, given a certain knowledge base and several textual + documents with textual mentions. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another component must provide entity annotations. - candidates_batch_size (int): Size of batches for entity candidate generation. + candidates_doc_mode (bool): Whether or not to operate candidate generation in doc mode, i.e. to provide a + generator yielding entities per document (candidate generator callable is called only once in this case). If + False, the candidate generator is called once per entity. threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold, prediction is discarded. If None, predictions are not filtered by any threshold. DOCS: https://spacy.io/api/entitylinker#init @@ -231,18 +243,16 @@ class EntityLinker(TrainablePipe): self.incl_prior = incl_prior self.incl_context = incl_context self.get_candidates = get_candidates - self.get_candidates_batch = get_candidates_batch + self.get_candidates_all = get_candidates_all + self.generate_empty_kb = generate_empty_kb self.cfg: Dict[str, Any] = {"overwrite": overwrite} self.distance = CosineDistance(normalize=False) self.kb = generate_empty_kb(self.vocab, entity_vector_length) self.scorer = scorer self.use_gold_ents = use_gold_ents - self.candidates_batch_size = candidates_batch_size + self.candidates_doc_mode = candidates_doc_mode self.threshold = threshold - if candidates_batch_size < 1: - raise ValueError(Errors.E1044) - def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): """Define the KB of this pipe by providing a function that will create it using this object's vocab.""" @@ -319,7 +329,6 @@ class EntityLinker(TrainablePipe): If one isn't present, then the update step needs to be skipped. """ - for eg in examples: for ent in eg.predicted.ents: candidates = list(self.get_candidates(self.kb, ent)) @@ -445,108 +454,125 @@ class EntityLinker(TrainablePipe): return final_kb_ids if isinstance(docs, Doc): docs = [docs] - for i, doc in enumerate(docs): - if len(doc) == 0: + + docs = list(docs) + # Determine which entities are to be ignored due to labels_discard. + valid_ent_idx_per_doc = ( + [ + idx + for idx in range(len(doc.ents)) + if doc.ents[idx].label_ not in self.labels_discard + ] + for doc in docs + if len(doc) and len(doc.ents) + ) + + # Call candidate generator. + if self.candidates_doc_mode: + all_ent_cands = self.get_candidates_all( + self.kb, + ( + SpanGroup( + doc, + spans=[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)], + ) + for doc in docs + if len(doc) and len(doc.ents) + ), + ) + else: + # Alternative: collect entities the old-fashioned way - by retrieving entities individually. + all_ent_cands = ( + [ + self.get_candidates(self.kb, doc.ents[idx]) + for idx in next(valid_ent_idx_per_doc) + ] + for doc in docs + if len(doc) and len(doc.ents) + ) + + for doc_idx, doc in enumerate(docs): + if len(doc) == 0 or len(doc.ents) == 0: continue sentences = [s for s in doc.sents] + doc_ent_cands = list(next(all_ent_cands)) - # Loop over entities in batches. - for ent_idx in range(0, len(doc.ents), self.candidates_batch_size): - ent_batch = doc.ents[ent_idx : ent_idx + self.candidates_batch_size] - - # Look up candidate entities. - valid_ent_idx = [ - idx - for idx in range(len(ent_batch)) - if ent_batch[idx].label_ not in self.labels_discard - ] - - batch_candidates = list( - self.get_candidates_batch( - self.kb, [ent_batch[idx] for idx in valid_ent_idx] - ) - if self.candidates_batch_size > 1 - else [ - self.get_candidates(self.kb, ent_batch[idx]) - for idx in valid_ent_idx - ] + # Looping over candidate entities for this doc. (TODO: rewrite) + for ent_cand_idx, ent in enumerate(doc.ents): + assert hasattr(ent, "sents") + sents = list(ent.sents) + sent_indices = ( + sentences.index(sents[0]), + sentences.index(sents[-1]), ) + assert sent_indices[1] >= sent_indices[0] >= 0 - # Looping through each entity in batch (TODO: rewrite) - for j, ent in enumerate(ent_batch): - assert hasattr(ent, "sents") - sents = list(ent.sents) - sent_indices = ( - sentences.index(sents[0]), - sentences.index(sents[-1]), + if self.incl_context: + # get n_neighbour sentences, clipped to the length of the document + start_sentence = max(0, sent_indices[0] - self.n_sents) + end_sentence = min( + len(sentences) - 1, sent_indices[1] + self.n_sents ) - assert sent_indices[1] >= sent_indices[0] >= 0 - - if self.incl_context: - # get n_neighbour sentences, clipped to the length of the document - start_sentence = max(0, sent_indices[0] - self.n_sents) - end_sentence = min( - len(sentences) - 1, sent_indices[1] + self.n_sents - ) - start_token = sentences[start_sentence].start - end_token = sentences[end_sentence].end - sent_doc = doc[start_token:end_token].as_doc() - - # currently, the context is the same for each entity in a sentence (should be refined) - sentence_encoding = self.model.predict([sent_doc])[0] - sentence_encoding_t = sentence_encoding.T - sentence_norm = xp.linalg.norm(sentence_encoding_t) - entity_count += 1 - if ent.label_ in self.labels_discard: - # ignoring this entity - setting to NIL + start_token = sentences[start_sentence].start + end_token = sentences[end_sentence].end + sent_doc = doc[start_token:end_token].as_doc() + # currently, the context is the same for each entity in a sentence (should be refined) + sentence_encoding = self.model.predict([sent_doc])[0] + sentence_encoding_t = sentence_encoding.T + sentence_norm = xp.linalg.norm(sentence_encoding_t) + entity_count += 1 + if ent.label_ in self.labels_discard: + # ignoring this entity - setting to NIL + final_kb_ids.append(self.NIL) + else: + candidates = list(doc_ent_cands[ent_cand_idx]) + if not candidates: + # no prediction possible for this entity - setting to NIL final_kb_ids.append(self.NIL) + elif len(candidates) == 1 and self.threshold is None: + # shortcut for efficiency reasons: take the 1 candidate + final_kb_ids.append(candidates[0].entity_) else: - candidates = list(batch_candidates[j]) - if not candidates: - # no prediction possible for this entity - setting to NIL - final_kb_ids.append(self.NIL) - elif len(candidates) == 1 and self.threshold is None: - # shortcut for efficiency reasons: take the 1 candidate - final_kb_ids.append(candidates[0].entity_) - else: - random.shuffle(candidates) - # set all prior probabilities to 0 if incl_prior=False - prior_probs = xp.asarray([c.prior_prob for c in candidates]) - if not self.incl_prior: - prior_probs = xp.asarray([0.0 for _ in candidates]) - scores = prior_probs - # add in similarity from the context - if self.incl_context: - entity_encodings = xp.asarray( - [c.entity_vector for c in candidates] - ) - entity_norm = xp.linalg.norm(entity_encodings, axis=1) - if len(entity_encodings) != len(prior_probs): - raise RuntimeError( - Errors.E147.format( - method="predict", - msg="vectors not of equal length", - ) - ) - # cosine similarity - sims = xp.dot(entity_encodings, sentence_encoding_t) / ( - sentence_norm * entity_norm - ) - if sims.shape != prior_probs.shape: - raise ValueError(Errors.E161) - scores = prior_probs + sims - (prior_probs * sims) - final_kb_ids.append( - candidates[scores.argmax().item()].entity_ - if self.threshold is None - or scores.max() >= self.threshold - else EntityLinker.NIL + random.shuffle(candidates) + # set all prior probabilities to 0 if incl_prior=False + scores = prior_probs = xp.asarray( + [ + c.prior_prob if self.incl_prior else 0.0 + for c in candidates + ] + ) + # add in similarity from the context + if self.incl_context: + entity_encodings = xp.asarray( + [c.entity_vector for c in candidates] ) + entity_norm = xp.linalg.norm(entity_encodings, axis=1) + if len(entity_encodings) != len(prior_probs): + raise RuntimeError( + Errors.E147.format( + method="predict", + msg="vectors not of equal length", + ) + ) + # cosine similarity + sims = xp.dot(entity_encodings, sentence_encoding_t) / ( + sentence_norm * entity_norm + ) + if sims.shape != prior_probs.shape: + raise ValueError(Errors.E161) + scores = prior_probs + sims - (prior_probs * sims) + final_kb_ids.append( + candidates[scores.argmax().item()].entity_ + if self.threshold is None or scores.max() >= self.threshold + else EntityLinker.NIL + ) if not (len(final_kb_ids) == entity_count): err = Errors.E147.format( method="predict", msg="result variables not of equal length" ) raise RuntimeError(err) + return final_kb_ids def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None: diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index fc960cb01..1b83bb723 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Dict, Any, Tuple +from typing import Callable, Iterable, Dict, Any, Iterator, Tuple import pytest from numpy.testing import assert_equal @@ -6,16 +6,16 @@ from numpy.testing import assert_equal from spacy import registry, util, Language from spacy.attrs import ENT_KB_ID from spacy.compat import pickle -from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase +from spacy.kb import Candidate, InMemoryLookupKB, KnowledgeBase from spacy.lang.en import English from spacy.ml import load_kb -from spacy.ml.models.entity_linker import build_span_maker +from spacy.ml.models.entity_linker import build_span_maker, get_candidates from spacy.pipeline import EntityLinker from spacy.pipeline.legacy import EntityLinker_v1 from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL from spacy.scorer import Scorer from spacy.tests.util import make_tempdir -from spacy.tokens import Span, Doc +from spacy.tokens import Span, Doc, SpanGroup from spacy.training import Example from spacy.util import ensure_path from spacy.vocab import Vocab @@ -168,7 +168,7 @@ def test_no_entities(): { "sent_starts": [1, 0, 0, 0, 0], }, - ) + ), ] nlp = English() vector_length = 3 @@ -489,11 +489,19 @@ def test_el_pipe_configuration(nlp): assert doc[1].ent_kb_id_ == "" assert doc[2].ent_kb_id_ == "Q2" - def get_lowercased_candidates(kb, span): + # Replace the pipe with a new one with with a different candidate generator. + + def get_lowercased_candidates(kb: InMemoryLookupKB, span: Span): return kb.get_alias_candidates(span.text.lower()) - def get_lowercased_candidates_batch(kb, spans): - return [get_lowercased_candidates(kb, span) for span in spans] + def get_lowercased_candidates_all( + kb: InMemoryLookupKB, mentions: Iterator[SpanGroup] + ): + for doc_mentions in mentions: + yield [ + get_lowercased_candidates(kb, doc_mentions[idx]) + for idx in range(len(doc_mentions)) + ] @registry.misc("spacy.LowercaseCandidateGenerator.v1") def create_candidates() -> Callable[ @@ -501,29 +509,40 @@ def test_el_pipe_configuration(nlp): ]: return get_lowercased_candidates - @registry.misc("spacy.LowercaseCandidateBatchGenerator.v1") + @registry.misc("spacy.LowercaseCandidateAllGenerator.v1") def create_candidates_batch() -> Callable[ - [InMemoryLookupKB, Iterable["Span"]], Iterable[Iterable[Candidate]] + [InMemoryLookupKB, Iterator[SpanGroup]], + Iterator[Iterable[Iterable[Candidate]]], ]: - return get_lowercased_candidates_batch + return get_lowercased_candidates_all - # replace the pipe with a new one with with a different candidate generator - entity_linker = nlp.replace_pipe( - "entity_linker", - "entity_linker", - config={ - "incl_context": False, - "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"}, - "get_candidates_batch": { - "@misc": "spacy.LowercaseCandidateBatchGenerator.v1" + def test_reconfigured_el(candidates_doc_mode: bool, doc_text: str) -> None: + """Test reconfigured EL for correct results. + candidates_doc_mode (bool): candidates_doc_mode in pipe config. + doc_text (str): Text to infer. + """ + _entity_linker = nlp.replace_pipe( + "entity_linker", + "entity_linker", + config={ + "incl_context": False, + "incl_prior": True, + "candidates_doc_mode": candidates_doc_mode, + "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"}, + "get_candidates_all": { + "@misc": "spacy.LowercaseCandidateAllGenerator.v1" + }, }, - }, - ) - entity_linker.set_kb(create_kb) - doc = nlp(text) - assert doc[0].ent_kb_id_ == "Q2" - assert doc[1].ent_kb_id_ == "" - assert doc[2].ent_kb_id_ == "Q2" + ) + _entity_linker.set_kb(create_kb) + _doc = nlp(doc_text) + assert _doc[0].ent_kb_id_ == "Q2" + assert _doc[1].ent_kb_id_ == "" + assert _doc[2].ent_kb_id_ == "Q2" + + # Test individual and doc-wise candidate generation. + test_reconfigured_el(False, text) + test_reconfigured_el(True, text) def test_nel_nsents(nlp): @@ -1169,18 +1188,19 @@ def test_threshold(meet_threshold: bool, config: Dict[str, Any]): # create artificial KB mykb = InMemoryLookupKB(vocab, entity_vector_length=3) mykb.add_entity(entity=entity_id, freq=12, entity_vector=[6, -4, 3]) - mykb.add_alias( - alias="Mahler", - entities=[entity_id], - probabilities=[1 if meet_threshold else 0.01], - ) + mykb.add_alias(alias="Mahler", entities=[entity_id], probabilities=[1]) return mykb # Create the Entity Linker component and add it to the pipeline entity_linker = nlp.add_pipe( "entity_linker", last=True, - config={"threshold": 0.99, "model": config}, + config={ + "threshold": None if meet_threshold else 1.0, + # Prior for candidate may be 1.0, rendering the our test setting with threshold 1.0 useless otherwise. + "incl_prior": meet_threshold, + "model": config, + }, ) entity_linker.set_kb(create_kb) # type: ignore nlp.initialize(get_examples=lambda: train_examples) @@ -1191,7 +1211,7 @@ def test_threshold(meet_threshold: bool, config: Dict[str, Any]): doc = nlp(text) assert len(doc.ents) == 1 - assert doc.ents[0].kb_id_ == entity_id if meet_threshold else EntityLinker.NIL + assert doc.ents[0].kb_id_ == (entity_id if meet_threshold else EntityLinker.NIL) def test_span_maker_forward_with_empty(): @@ -1207,3 +1227,62 @@ def test_span_maker_forward_with_empty(): # just to get a model span_maker = build_span_maker() span_maker([doc1, doc2], False) + + +def test_nel_candidate_processing(): + """Test that NEL handles candidate streams correctly in a set of documents with & without entities as well as empty + documents. + """ + train_data = [ + ( + "The sky is blue.", + { + "sent_starts": [1, 0, 0, 0, 0], + }, + ), + ( + "They visited New York.", + { + "sent_starts": [1, 0, 0, 0, 0], + "entities": [(13, 21, "GPE")], + }, + ), + ("", {}), + ( + "New York is a city.", + { + "sent_starts": [1, 0, 0, 0, 0, 0], + "entities": [(0, 8, "GPE")], + }, + ), + ] + + nlp = English() + nlp.add_pipe("sentencizer") + + vector_length = 3 + train_examples = [] + for text, annotation in train_data: + train_examples.append(Example.from_dict(nlp(text), annotation)) + + def create_kb(vocab): + # create artificial KB + mykb = InMemoryLookupKB(vocab, entity_vector_length=vector_length) + mykb.add_entity(entity="Q60", freq=12, entity_vector=[1, 2, 3]) + mykb.add_alias("New York", ["Q60"], [0.9]) + return mykb + + # Create and train the Entity Linker + entity_linker = nlp.add_pipe("entity_linker", last=True) + entity_linker.set_kb(create_kb) + optimizer = nlp.initialize(get_examples=lambda: train_examples) + for i in range(2): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + + # Add a custom rule-based component to mimick NER + ruler = nlp.add_pipe("entity_ruler", before="entity_linker") + ruler.add_patterns([{"label": "GPE", "pattern": [{"LOWER": "new york"}]}]) # type: ignore + + # this will run the pipeline on the examples and shouldn't crash + nlp.evaluate(train_examples) diff --git a/spacy/tokens/doc.pyi b/spacy/tokens/doc.pyi index 9d45960ab..6deb83ec0 100644 --- a/spacy/tokens/doc.pyi +++ b/spacy/tokens/doc.pyi @@ -2,7 +2,9 @@ from typing import Callable, Protocol, Iterable, Iterator, Optional from typing import Union, Tuple, List, Dict, Any, overload from cymem.cymem import Pool from thinc.types import Floats1d, Floats2d, Ints2d + from .span import Span +from .span_group import SpanGroup from .token import Token from ._dict_proxies import SpanGroups from ._retokenize import Retokenizer @@ -129,6 +131,7 @@ class Doc: outside: Optional[List[Span]] = ..., default: str = ... ) -> None: ... + ents_spangroup: SpanGroup @property def noun_chunks(self) -> Iterator[Span]: ... @property diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index a54b4ad3c..c5869b967 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -19,6 +19,8 @@ import warnings from .span cimport Span from .token cimport MISSING_DEP +from .span_group cimport SpanGroup + from ._dict_proxies import SpanGroups from .token cimport Token from ..lexeme cimport Lexeme, EMPTY_LEXEME @@ -701,6 +703,14 @@ cdef class Doc: """ return self.text + @property + def ents_spangroup(self) -> SpanGroup: + """ + Returns entities (in `.ents`) as `SpanGroup`. + RETURNS (SpanGroup): All entities (in `.ents`) as `SpanGroup`. + """ + return SpanGroup(self, spans=self.ents, name="ents") + property ents: """The named entities in the document. Returns a tuple of named entity `Span` objects, if the entity recognizer has been applied.