diff --git a/spacy/kb/__init__.py b/spacy/kb/__init__.py index ff0e209e3..c8a657d62 100644 --- a/spacy/kb/__init__.py +++ b/spacy/kb/__init__.py @@ -2,5 +2,4 @@ from .kb import KnowledgeBase from .kb_in_memory import InMemoryLookupKB from .candidate import Candidate, InMemoryCandidate - __all__ = ["KnowledgeBase", "InMemoryLookupKB", "Candidate", "InMemoryCandidate"] diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index 1cb08f488..2d0e1d5a1 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -5,7 +5,7 @@ from typing import Iterable, Tuple, Union from cymem.cymem cimport Pool from .candidate import Candidate -from ..tokens import Span +from ..tokens import Span, SpanGroup from ..util import SimpleFrozenList from ..errors import Errors @@ -30,13 +30,13 @@ cdef class KnowledgeBase: self.entity_vector_length = entity_vector_length self.mem = Pool() - def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]: + def get_candidates_batch(self, mentions: SpanGroup) -> Iterable[Iterable[Candidate]]: """ Return candidate entities for a specified Span mention. Each candidate defines at least the entity and the entity's embedding vector. Depending on the KB implementation, further properties - such as the prior probability of the specified mention text resolving to that entity - might be included. If no candidates are found for a given mention, an empty list is returned. - mentions (Iterable[Span]): Mentions for which to get candidates. + mentions (SpanGroup): Mentions for which to get candidates. RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. """ return [self.get_candidates(span) for span in mentions] diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index 7fe0b4741..b5122b164 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -8,7 +8,7 @@ from ...util import registry from ...kb import KnowledgeBase, InMemoryLookupKB from ...kb import Candidate from ...vocab import Vocab -from ...tokens import Span, Doc +from ...tokens import Doc, Span, SpanGroup from ..extract_spans import extract_spans from ...errors import Errors @@ -114,7 +114,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: @registry.misc("spacy.CandidateBatchGenerator.v1") def create_candidates_batch() -> Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] + [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] ]: return get_candidates_batch @@ -130,12 +130,12 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: def get_candidates_batch( - kb: KnowledgeBase, mentions: Iterable[Span] + kb: KnowledgeBase, mentions: SpanGroup ) -> Iterable[Iterable[Candidate]]: """ Return candidate entities for the given mentions and fetching appropriate entries from the index. kb (KnowledgeBase): Knowledge base to query. - mentions (Iterable[Span]): Entity mentions for which to identify candidates. + mentions (SpanGroup): Entity mentions for which to identify candidates. RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. """ return kb.get_candidates_batch(mentions) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index caced9cfd..ecd156db5 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -11,6 +11,8 @@ from thinc.api import set_dropout_rate from ..kb import KnowledgeBase, Candidate from ..tokens import Doc, Span +from ..ml import empty_kb +from ..tokens import Doc, Span, SpanGroup from .pipe import deserialize_config from .trainable_pipe import TrainablePipe from ..language import Language @@ -82,7 +84,7 @@ def make_entity_linker( entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates_batch: Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] + [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] ], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool, @@ -105,7 +107,7 @@ def make_entity_linker( get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that produces a list of candidates, given a certain knowledge base and a textual mention. get_candidates_batch ( - Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] + Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]] ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. scorer (Optional[Callable]): The scoring method. @@ -170,7 +172,7 @@ class EntityLinker(TrainablePipe): entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates_batch: Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] + [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] ], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool = False, @@ -194,7 +196,7 @@ class EntityLinker(TrainablePipe): get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that produces a list of candidates, given a certain knowledge base and a textual mention. get_candidates_batch ( - Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], + Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]] ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. @@ -473,7 +475,8 @@ class EntityLinker(TrainablePipe): batch_candidates = list( self.get_candidates_batch( - self.kb, [ent_batch[idx] for idx in valid_ent_idx] + self.kb, + SpanGroup(doc, spans=[ent_batch[idx] for idx in valid_ent_idx]), ) if self.candidates_batch_size > 1 else [ diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 65406a36e..773a5b8f3 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -997,7 +997,6 @@ def test_scorer_links(): ) # fmt: on def test_legacy_architectures(name, config): - # Ensure that the legacy architectures still work vector_length = 3 nlp = English() diff --git a/website/docs/api/inmemorylookupkb.mdx b/website/docs/api/inmemorylookupkb.mdx index 6fa6cb235..3b33f7fb7 100644 --- a/website/docs/api/inmemorylookupkb.mdx +++ b/website/docs/api/inmemorylookupkb.mdx @@ -189,14 +189,15 @@ to you. > > ```python > from spacy.lang.en import English +> from spacy.tokens import SpanGroup > nlp = English() > doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.") -> candidates = kb.get_candidates((doc[0:2], doc[3:])) +> candidates = kb.get_candidates_batch([SpanGroup(doc, spans=[doc[0:2], doc[3:]]]) > ``` | Name | Description | | ----------- | ------------------------------------------------------------------------------------------------------------ | -| `mentions` | The textual mentions. ~~Iterable[Span]~~ | +| `mentions` | The textual mentions. ~~SpanGroup~~ | | **RETURNS** | An iterable of iterable with relevant `InMemoryCandidate` objects. ~~Iterable[Iterable[InMemoryCandidate]]~~ | ## InMemoryLookupKB.get_vector {id="get_vector",tag="method"} diff --git a/website/docs/api/kb.mdx b/website/docs/api/kb.mdx index 9536a3fe3..94506162f 100644 --- a/website/docs/api/kb.mdx +++ b/website/docs/api/kb.mdx @@ -93,14 +93,15 @@ to you. > > ```python > from spacy.lang.en import English +> from spacy.tokens import SpanGroup > nlp = English() > doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.") -> candidates = kb.get_candidates((doc[0:2], doc[3:])) +> candidates = kb.get_candidates([SpanGroup(doc, spans=[doc[0:2], doc[3:]]]) > ``` | Name | Description | | ----------- | -------------------------------------------------------------------------------------------- | -| `mentions` | The textual mention or alias. ~~Iterable[Span]~~ | +| `mentions` | The textual mentions. ~~SpanGroup~~ | | **RETURNS** | An iterable of iterable with relevant `Candidate` objects. ~~Iterable[Iterable[Candidate]]~~ | ## KnowledgeBase.get_vector {id="get_vector",tag="method"} @@ -187,13 +188,11 @@ Construct an `InMemoryCandidate` object. Usually this constructor is not called directly, but instead these objects are returned by the `get_candidates` method of the [`entity_linker`](/api/entitylinker) pipe. -> #### Example```python +> #### Example > +> ```python > from spacy.kb import InMemoryCandidate candidate = InMemoryCandidate(kb, > entity_hash, entity_freq, entity_vector, alias_hash, prior_prob) -> -> ``` -> > ``` | Name | Description |