Switch to SpanGroup (from Doc) for bundling Spans for candidate retrieval.

This commit is contained in:
Raphael Mitsch 2022-12-15 10:17:25 +01:00
parent 581c2fd40f
commit b6bc6885d9
6 changed files with 55 additions and 43 deletions

View File

@ -1,11 +1,11 @@
# cython: infer_types=True, profile=True
from pathlib import Path
from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Optional
from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Callable
from cymem.cymem cimport Pool
from .candidate import Candidate
from ..tokens import Span, Doc
from ..tokens import Span, SpanGroup, Doc
from ..util import SimpleFrozenList
from ..errors import Errors
@ -32,16 +32,26 @@ cdef class KnowledgeBase:
self.entity_vector_length = entity_vector_length
self.mem = Pool()
def get_candidates_all(self, docs: Iterator[Doc]) -> Iterator[Iterable[Iterable[Candidate]]]:
def get_candidates_all(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[Iterable[Candidate]]]:
"""
Return candidate entities for the mentions bundled per doc in the passed `SpanGroup` instances. Each candidate defines the
entity, the original alias, and the prior probability of that alias resolving to that entity.
If no candidate is found for a given mention, an empty list is returned.
docs (Iterator[Doc]): Doc instances with mentions (stored in `.ent`).
mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance.
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
"""
for doc in docs:
yield [self.get_candidates(ent_span) for ent_span in doc.ents]
for doc_mentions in mentions:
yield [self.get_candidates(ent_span) for ent_span in doc_mentions]
@staticmethod
def get_ents_as_spangroup(doc: Doc, extractor: Union[str, Callable[[Iterable[Span]], Doc]] = "ent") -> SpanGroup:
"""
Fetch entities from doc and return them as a SpanGroup ready to be used in
`KnowledgeBase.get_candidates_all()`.
NOTE(review): only the signature and this docstring are visible here; no extraction logic or return statement
is shown, so confirm the implementation exists before relying on this method.
doc (Doc): Doc whose entities should be fetched.
extractor (Union[str, Callable[[Iterable[Span]], Doc]]): Defines how to retrieve the object holding the spans
used to describe entities. Per the annotation this is either a string key referring to a property of the doc
instance or a callable. NOTE(review): the original description broke off mid-sentence after its example, and
the default key is "ent" while entities are read from `.ents` elsewhere in this change - confirm the intended
key and the callable's signature.
RETURNS (SpanGroup): Entities of the doc bundled as a SpanGroup (per the return annotation).
"""
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
"""

View File

@ -231,8 +231,8 @@ cdef class InMemoryLookupKB(KnowledgeBase):
alias_entry.probs = probs
self._aliases_table[alias_index] = alias_entry
def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]:
return self.get_alias_candidates(mention.text) # type: ignore
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
    """
    Look up candidates for a single mention by treating its text as an alias.
    mention (Span): Mention whose text is passed to `get_alias_candidates()`.
    RETURNS (Iterable[Candidate]): Whatever `get_alias_candidates()` returns for the mention text.
    """
    alias = mention.text
    return self.get_alias_candidates(alias)
def get_alias_candidates(self, str alias) -> Iterable[Candidate]:
"""

View File

@ -8,7 +8,7 @@ from ...util import registry
from ...kb import KnowledgeBase, InMemoryLookupKB
from ...kb import Candidate
from ...vocab import Vocab
from ...tokens import Span, Doc
from ...tokens import Span, Doc, SpanGroup
from ..extract_spans import extract_spans
from ...errors import Errors
@ -118,15 +118,15 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
def get_candidates_all(
kb: KnowledgeBase, docs: Iterator[Doc]
kb: KnowledgeBase, mentions: Iterator[SpanGroup]
) -> Iterator[Iterable[Iterable[Candidate]]]:
"""
Return candidate entities for the given mentions by fetching appropriate entries from the index.
kb (KnowledgeBase): Knowledge base to query.
docs (Iterator[Doc]): Doc instances with mentions (stored in `.ent`).
mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance.
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
"""
return kb.get_candidates_all(docs)
return kb.get_candidates_all(mentions)
@registry.misc("spacy.CandidateGenerator.v1")
@ -136,7 +136,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
@registry.misc("spacy.CandidateAllGenerator.v1")
def create_candidates_all() -> Callable[
[KnowledgeBase, Iterator[Doc]],
[KnowledgeBase, Iterator[SpanGroup]],
Iterator[Iterable[Iterable[Candidate]]],
]:
return get_candidates_all

View File

@ -17,7 +17,7 @@ from thinc.api import CosineDistance, Model, Optimizer, Config
from thinc.api import set_dropout_rate
from ..kb import KnowledgeBase, Candidate
from ..tokens import Doc, Span
from ..tokens import Doc, Span, SpanGroup
from .pipe import deserialize_config
from .legacy.entity_linker import EntityLinker_v1
from .trainable_pipe import TrainablePipe
@ -87,7 +87,7 @@ def make_entity_linker(
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_all: Callable[
[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool,
@ -108,9 +108,9 @@ def make_entity_linker(
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list of
candidates, given a certain knowledge base and a textual mention.
get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]): Function
that produces a list of candidates per document, given a certain knowledge base and several textual documents
with textual mentions.
get_candidates_all (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
Function producing a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
@ -187,7 +187,8 @@ class EntityLinker(TrainablePipe):
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_all: Callable[
[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
[KnowledgeBase, Iterator[SpanGroup]],
Iterator[Iterable[Iterable[Candidate]]],
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool = BACKWARD_OVERWRITE,
@ -209,8 +210,8 @@ class EntityLinker(TrainablePipe):
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list
of candidates, given a certain knowledge base and a textual mention.
get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]):
Function that produces a list of candidates per document, given a certain knowledge base and several textual
get_candidates_all (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
Function producing a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
@ -468,24 +469,13 @@ class EntityLinker(TrainablePipe):
# Call candidate generator.
if self.candidates_doc_mode:
def _adjust_ents_in_doc(doc: Doc, valid_ent_idx: Iterable[int]) -> Doc:
"""
Generates a copy of the doc object that keeps only those ents for which candidates are to be retrieved.
doc (Doc): Doc object to adjust.
valid_ent_idx (Iterable[int]): Indices of entities to keep.
RETURN (doc): Doc instance with only valid entities (i.e. those to retrieve candidates for).
"""
_doc = doc.copy()
# mypy complains about mismatching types here (Tuple[str] vs. Tuple[str, ...]), which isn't correct and
# probably an artifact of a misreading of the Cython code.
_doc.ents = tuple([doc.ents[i] for i in valid_ent_idx]) # type: ignore
return _doc
all_ent_cands = self.get_candidates_all(
self.kb,
(
_adjust_ents_in_doc(doc, next(valid_ent_idx_per_doc))
SpanGroup(
doc,
spans=[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)],
)
for doc in docs
if len(doc) and len(doc.ents)
),

View File

@ -1,4 +1,4 @@
from typing import Callable, Iterable, Dict, Any, Generator, Iterator
from typing import Callable, Iterable, Dict, Any, Iterator
import pytest
from numpy.testing import assert_equal
@ -15,7 +15,7 @@ from spacy.pipeline.legacy import EntityLinker_v1
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
from spacy.scorer import Scorer
from spacy.tests.util import make_tempdir
from spacy.tokens import Span, Doc
from spacy.tokens import Span, Doc, SpanGroup
from spacy.training import Example
from spacy.util import ensure_path
from spacy.vocab import Vocab
@ -500,12 +500,14 @@ def test_el_pipe_configuration(nlp):
# Replace the pipe with a new one with a different candidate generator.
def get_lowercased_candidates(kb, span):
def get_lowercased_candidates(kb: InMemoryLookupKB, span: Span):
    """Return the KB's alias candidates for the lowercased text of `span`."""
    lowered_alias = span.text.lower()
    return kb.get_alias_candidates(lowered_alias)
def get_lowercased_candidates_all(kb, docs):
for _doc in docs:
yield [get_lowercased_candidates(kb, ent_span) for ent_span in _doc.ents]
def get_lowercased_candidates_all(
    kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]
):
    """
    Yield, for each group of mentions, a list holding the lowercased-alias
    candidates of every mention in that group (same order as the group).
    """
    for span_group in mentions:
        group_candidates = []
        for span in span_group:
            group_candidates.append(get_lowercased_candidates(kb, span))
        yield group_candidates
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
def create_candidates() -> Callable[
@ -515,7 +517,7 @@ def test_el_pipe_configuration(nlp):
@registry.misc("spacy.LowercaseCandidateAllGenerator.v1")
def create_candidates_batch() -> Callable[
[InMemoryLookupKB, Generator[Iterable["Span"], None, None]],
[InMemoryLookupKB, Iterator[SpanGroup]],
Iterator[Iterable[Iterable[Candidate]]],
]:
return get_lowercased_candidates_all

View File

@ -19,6 +19,8 @@ import warnings
from .span cimport Span
from .token cimport MISSING_DEP
from .span_group cimport SpanGroup
from ._dict_proxies import SpanGroups
from .token cimport Token
from ..lexeme cimport Lexeme, EMPTY_LEXEME
@ -704,6 +706,14 @@ cdef class Doc:
"""
return self.text
@property
def ents_spangroup(self) -> SpanGroup:
    """
    Bundle the spans currently in `.ents` into a single `SpanGroup`.
    RETURNS (SpanGroup): Group named "ents", attached to this doc, holding all spans from `.ents`.
    """
    entity_spans = self.ents
    return SpanGroup(self, name="ents", spans=entity_spans)
property ents:
"""The named entities in the document. Returns a tuple of named entity
`Span` objects, if the entity recognizer has been applied.