From b6bc6885d9486a8611d784c97f5c73b0ac71c3ba Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Thu, 15 Dec 2022 10:17:25 +0100 Subject: [PATCH] Switch to SpanGroup (from Doc) for bundling Spans for candidate retrieval. --- spacy/kb/kb.pyx | 22 +++++++++---- spacy/kb/kb_in_memory.pyx | 4 +-- spacy/ml/models/entity_linker.py | 10 +++--- spacy/pipeline/entity_linker.py | 36 ++++++++-------------- spacy/tests/pipeline/test_entity_linker.py | 16 +++++----- spacy/tokens/doc.pyx | 10 ++++++ 6 files changed, 55 insertions(+), 43 deletions(-) diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index b72378323..bc8d54761 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -1,11 +1,11 @@ # cython: infer_types=True, profile=True from pathlib import Path -from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Optional +from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Callable from cymem.cymem cimport Pool from .candidate import Candidate -from ..tokens import Span, Doc +from ..tokens import Span, SpanGroup, Doc from ..util import SimpleFrozenList from ..errors import Errors @@ -32,16 +32,26 @@ cdef class KnowledgeBase: self.entity_vector_length = entity_vector_length self.mem = Pool() - def get_candidates_all(self, docs: Iterator[Doc]) -> Iterator[Iterable[Iterable[Candidate]]]: + def get_candidates_all(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[Iterable[Candidate]]]: """ Return candidate entities for mentions stored in `ent` attribute in passed docs. Each candidate defines the entity, the original alias, and the prior probability of that alias resolving to that entity. If no candidate is found for a given mention, an empty list is returned. - docs (Iterator[Doc]): Doc instances with mentions (stored in `.ent`). + mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance. RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document. 
""" - for doc in docs: - yield [self.get_candidates(ent_span) for ent_span in doc.ents] + for doc_mentions in mentions: + yield [self.get_candidates(ent_span) for ent_span in doc_mentions] + + @staticmethod + def get_ents_as_spangroup(doc: Doc, extractor: Union[str, Callable[[Iterable[Span]], Doc]] = "ent") -> SpanGroup: + """ + Fetch entities from doc and return them as a SpanGroup ready to be used in + `KnowledgeBase.get_candidates_all()`. + doc (Doc): Doc whose entities should be fetched. + extractor (Union[str, Callable[[Iterable[Span]], Doc]]): Defines how to retrieve object holding spans + used to describe entities. This can be a key referring to a property of the doc instance (e.g. "ents"). + """ def get_candidates(self, mention: Span) -> Iterable[Candidate]: """ diff --git a/spacy/kb/kb_in_memory.pyx b/spacy/kb/kb_in_memory.pyx index 133dc3abb..a87ddf0f6 100644 --- a/spacy/kb/kb_in_memory.pyx +++ b/spacy/kb/kb_in_memory.pyx @@ -231,8 +231,8 @@ cdef class InMemoryLookupKB(KnowledgeBase): alias_entry.probs = probs self._aliases_table[alias_index] = alias_entry - def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]: - return self.get_alias_candidates(mention.text) # type: ignore + def get_candidates(self, mention: Span) -> Iterable[Candidate]: + return self.get_alias_candidates(mention.text) def get_alias_candidates(self, str alias) -> Iterable[Candidate]: """ diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index d683a8fe7..b5f455cdc 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -8,7 +8,7 @@ from ...util import registry from ...kb import KnowledgeBase, InMemoryLookupKB from ...kb import Candidate from ...vocab import Vocab -from ...tokens import Span, Doc +from ...tokens import Span, Doc, SpanGroup from ..extract_spans import extract_spans from ...errors import Errors @@ -118,15 +118,15 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> 
Iterable[Candidate]: def get_candidates_all( - kb: KnowledgeBase, docs: Iterator[Doc] + kb: KnowledgeBase, mentions: Iterator[SpanGroup] ) -> Iterator[Iterable[Iterable[Candidate]]]: """ Return candidate entities for the given mentions and fetching appropriate entries from the index. kb (KnowledgeBase): Knowledge base to query. - docs (Iterator[Doc]): Doc instances with mentions (stored in `.ent`). + mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance. RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document. """ - return kb.get_candidates_all(docs) + return kb.get_candidates_all(mentions) @registry.misc("spacy.CandidateGenerator.v1") @@ -136,7 +136,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: @registry.misc("spacy.CandidateAllGenerator.v1") def create_candidates_all() -> Callable[ - [KnowledgeBase, Iterator[Doc]], + [KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]], ]: return get_candidates_all diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 6d2a114cb..169d375e1 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -17,7 +17,7 @@ from thinc.api import CosineDistance, Model, Optimizer, Config from thinc.api import set_dropout_rate from ..kb import KnowledgeBase, Candidate -from ..tokens import Doc, Span +from ..tokens import Doc, Span, SpanGroup from .pipe import deserialize_config from .legacy.entity_linker import EntityLinker_v1 from .trainable_pipe import TrainablePipe @@ -87,7 +87,7 @@ def make_entity_linker( entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates_all: Callable[ - [KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]] + [KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]] ], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool, @@ -108,9 +108,9 
@@ def make_entity_linker( entity_vector_length (int): Size of encoding vectors in the KB. get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list of candidates, given a certain knowledge base and a textual mention. - get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]): Function - that produces a list of candidates per document, given a certain knowledge base and several textual documents - with textual mentions. + get_candidates_all (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]): + Function producing a list of candidates per document, given a certain knowledge base and several textual + documents with textual mentions. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. scorer (Optional[Callable]): The scoring method. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another @@ -187,7 +187,8 @@ class EntityLinker(TrainablePipe): entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates_all: Callable[ - [KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]] + [KnowledgeBase, Iterator[SpanGroup]], + Iterator[Iterable[Iterable[Candidate]]], ], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool = BACKWARD_OVERWRITE, @@ -209,8 +210,8 @@ class EntityLinker(TrainablePipe): entity_vector_length (int): Size of encoding vectors in the KB. get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list of candidates, given a certain knowledge base and a textual mention. 
- get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]): - Function that produces a list of candidates per document, given a certain knowledge base and several textual + get_candidates_all (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]): + Function producing a list of candidates per document, given a certain knowledge base and several textual documents with textual mentions. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. @@ -468,24 +469,13 @@ class EntityLinker(TrainablePipe): # Call candidate generator. if self.candidates_doc_mode: - - def _adjust_ents_in_doc(doc: Doc, valid_ent_idx: Iterable[int]) -> Doc: - """ - Generates copy of doc object with only those ents that are candidates are to be retrieved for. - doc (Doc): Doc object to adjust. - valid_ent_idx (Iterable[int]): Indices of entities to keep. - RETURN (doc): Doc instance with only valid entities (i.e. those to retrieve candidates for). - """ - _doc = doc.copy() - # mypy complains about mismatching types here (Tuple[str] vs. Tuple[str, ...]), which isn't correct and - # probably an artifact of a misreading of the Cython code. 
- _doc.ents = tuple([doc.ents[i] for i in valid_ent_idx]) # type: ignore - return _doc - all_ent_cands = self.get_candidates_all( self.kb, ( - _adjust_ents_in_doc(doc, next(valid_ent_idx_per_doc)) + SpanGroup( + doc, + spans=[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)], + ) for doc in docs if len(doc) and len(doc.ents) ), diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 19b276045..c6ae4d5a9 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Dict, Any, Generator, Iterator +from typing import Callable, Iterable, Dict, Any, Iterator import pytest from numpy.testing import assert_equal @@ -15,7 +15,7 @@ from spacy.pipeline.legacy import EntityLinker_v1 from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL from spacy.scorer import Scorer from spacy.tests.util import make_tempdir -from spacy.tokens import Span, Doc +from spacy.tokens import Span, Doc, SpanGroup from spacy.training import Example from spacy.util import ensure_path from spacy.vocab import Vocab @@ -500,12 +500,14 @@ def test_el_pipe_configuration(nlp): # Replace the pipe with a new one with with a different candidate generator. 
- def get_lowercased_candidates(kb, span): + def get_lowercased_candidates(kb: InMemoryLookupKB, span: Span): return kb.get_alias_candidates(span.text.lower()) - def get_lowercased_candidates_all(kb, docs): - for _doc in docs: - yield [get_lowercased_candidates(kb, ent_span) for ent_span in _doc.ents] + def get_lowercased_candidates_all( + kb: InMemoryLookupKB, mentions: Iterator[SpanGroup] + ): + for doc_mentions in mentions: + yield [get_lowercased_candidates(kb, mention) for mention in doc_mentions] @registry.misc("spacy.LowercaseCandidateGenerator.v1") def create_candidates() -> Callable[ @@ -515,7 +517,7 @@ def test_el_pipe_configuration(nlp): @registry.misc("spacy.LowercaseCandidateAllGenerator.v1") def create_candidates_batch() -> Callable[ - [InMemoryLookupKB, Generator[Iterable["Span"], None, None]], + [InMemoryLookupKB, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]], ]: return get_lowercased_candidates_all diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 075bc4d15..ca190cbe0 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -19,6 +19,8 @@ import warnings from .span cimport Span from .token cimport MISSING_DEP +from .span_group cimport SpanGroup + from ._dict_proxies import SpanGroups from .token cimport Token from ..lexeme cimport Lexeme, EMPTY_LEXEME @@ -704,6 +706,14 @@ cdef class Doc: """ return self.text + @property + def ents_spangroup(self) -> SpanGroup: + """ + Returns entities (in `.ents`) as `SpanGroup`. + RETURNS (SpanGroup): All entities (in `.ents`) as `SpanGroup`. + """ + return SpanGroup(self, spans=self.ents, name="ents") + property ents: """The named entities in the document. Returns a tuple of named entity `Span` objects, if the entity recognizer has been applied.