Mirror of https://github.com/explosion/spaCy.git, synced 2025-07-01 02:13:07 +03:00
Switch to SpanGroup (from Doc) for bundling Spans for candidate retrieval.
This commit is contained in:
parent 581c2fd40f
commit b6bc6885d9
@@ -1,11 +1,11 @@
 # cython: infer_types=True, profile=True

 from pathlib import Path
-from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Optional
+from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Callable
 from cymem.cymem cimport Pool

 from .candidate import Candidate
-from ..tokens import Span, Doc
+from ..tokens import Span, SpanGroup, Doc
 from ..util import SimpleFrozenList
 from ..errors import Errors

@@ -32,16 +32,26 @@ cdef class KnowledgeBase:
         self.entity_vector_length = entity_vector_length
         self.mem = Pool()

-    def get_candidates_all(self, docs: Iterator[Doc]) -> Iterator[Iterable[Iterable[Candidate]]]:
+    def get_candidates_all(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[Iterable[Candidate]]]:
         """
         Return candidate entities for mentions stored in `ent` attribute in passed docs. Each candidate defines the
         entity, the original alias, and the prior probability of that alias resolving to that entity.
         If no candidate is found for a given mention, an empty list is returned.
-        docs (Iterator[Doc]): Doc instances with mentions (stored in `.ent`).
+        mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance.
         RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
         """
-        for doc in docs:
-            yield [self.get_candidates(ent_span) for ent_span in doc.ents]
+        for doc_mentions in mentions:
+            yield [self.get_candidates(ent_span) for ent_span in doc_mentions]
+
+    @staticmethod
+    def get_ents_as_spangroup(doc: Doc, extractor: Union[str, Callable[[Iterable[Span]], Doc]] = "ent") -> SpanGroup:
+        """
+        Fetch entities from doc and returns them as a SpanGroup ready to be used in
+        `KnowledgeBase.get_candidates_all()`.
+        doc (Doc): Doc whose entities should be fetched.
+        extractor (Union[str, Callable[[Iterable[Span]], Doc]]): Defines how to retrieve object holding spans
+            used to describe entities. This can be a key referring to a property of the doc instance (e.g. "
+        """

     def get_candidates(self, mention: Span) -> Iterable[Candidate]:
         """
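
Note: a minimal usage sketch of the SpanGroup-based batch API above, assuming this branch is installed; `nlp` (a loaded pipeline with an entity recognizer) and `kb` (a populated knowledge base) are placeholders:

    from spacy.tokens import SpanGroup

    # Bundle each doc's entity spans into one SpanGroup per doc, then query the KB in one pass.
    docs = list(nlp.pipe(["Douglas Adams wrote the guide.", "Berlin is a city."]))
    mention_groups = (SpanGroup(doc, spans=list(doc.ents)) for doc in docs)
    for candidates_per_doc in kb.get_candidates_all(mention_groups):
        for candidates in candidates_per_doc:
            print([c.entity_ for c in candidates])
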
@@ -231,8 +231,8 @@ cdef class InMemoryLookupKB(KnowledgeBase):
         alias_entry.probs = probs
         self._aliases_table[alias_index] = alias_entry

-    def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]:
-        return self.get_alias_candidates(mention.text)  # type: ignore
+    def get_candidates(self, mention: Span) -> Iterable[Candidate]:
+        return self.get_alias_candidates(mention.text)

     def get_alias_candidates(self, str alias) -> Iterable[Candidate]:
         """
@@ -8,7 +8,7 @@ from ...util import registry
 from ...kb import KnowledgeBase, InMemoryLookupKB
 from ...kb import Candidate
 from ...vocab import Vocab
-from ...tokens import Span, Doc
+from ...tokens import Span, Doc, SpanGroup
 from ..extract_spans import extract_spans
 from ...errors import Errors

@@ -118,15 +118,15 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:


 def get_candidates_all(
-    kb: KnowledgeBase, docs: Iterator[Doc]
+    kb: KnowledgeBase, mentions: Iterator[SpanGroup]
 ) -> Iterator[Iterable[Iterable[Candidate]]]:
     """
     Return candidate entities for the given mentions and fetching appropriate entries from the index.
     kb (KnowledgeBase): Knowledge base to query.
-    docs (Iterator[Doc]): Doc instances with mentions (stored in `.ent`).
+    mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance.
     RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
     """
-    return kb.get_candidates_all(docs)
+    return kb.get_candidates_all(mentions)


 @registry.misc("spacy.CandidateGenerator.v1")
@@ -136,7 +136,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:

 @registry.misc("spacy.CandidateAllGenerator.v1")
 def create_candidates_all() -> Callable[
-    [KnowledgeBase, Iterator[Doc]],
+    [KnowledgeBase, Iterator[SpanGroup]],
     Iterator[Iterable[Iterable[Candidate]]],
 ]:
     return get_candidates_all
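
Note: on this branch the misc registry entry above resolves to the batched generator; a sketch of retrieving and calling it directly, where `kb` and `mention_groups` are assumed to exist as in the earlier sketch:

    from spacy import registry

    # The registry entry returns a factory; calling it yields the batched candidate generator.
    get_all_candidates = registry.misc.get("spacy.CandidateAllGenerator.v1")()
    candidates_per_doc = list(get_all_candidates(kb, mention_groups))
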
@@ -17,7 +17,7 @@ from thinc.api import CosineDistance, Model, Optimizer, Config
 from thinc.api import set_dropout_rate

 from ..kb import KnowledgeBase, Candidate
-from ..tokens import Doc, Span
+from ..tokens import Doc, Span, SpanGroup
 from .pipe import deserialize_config
 from .legacy.entity_linker import EntityLinker_v1
 from .trainable_pipe import TrainablePipe
@@ -87,7 +87,7 @@ def make_entity_linker(
     entity_vector_length: int,
     get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
     get_candidates_all: Callable[
-        [KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
+        [KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]
     ],
     generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
     overwrite: bool,
@@ -108,9 +108,9 @@ def make_entity_linker(
     entity_vector_length (int): Size of encoding vectors in the KB.
     get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list of
         candidates, given a certain knowledge base and a textual mention.
-    get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]): Function
-        that produces a list of candidates per document, given a certain knowledge base and several textual documents
-        with textual mentions.
+    get_candidates_all (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
+        Function producing a list of candidates per document, given a certain knowledge base and several textual
+        documents with textual mentions.
     generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
     scorer (Optional[Callable]): The scoring method.
     use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
@@ -187,7 +187,8 @@ class EntityLinker(TrainablePipe):
         entity_vector_length: int,
         get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
         get_candidates_all: Callable[
-            [KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
+            [KnowledgeBase, Iterator[SpanGroup]],
+            Iterator[Iterable[Iterable[Candidate]]],
         ],
         generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
         overwrite: bool = BACKWARD_OVERWRITE,
@@ -209,8 +210,8 @@ class EntityLinker(TrainablePipe):
         entity_vector_length (int): Size of encoding vectors in the KB.
         get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list
             of candidates, given a certain knowledge base and a textual mention.
-        get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]):
-            Function that produces a list of candidates per document, given a certain knowledge base and several textual
+        get_candidates_all (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
+            Function producing a list of candidates per document, given a certain knowledge base and several textual
             documents with textual mentions.
         generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
         scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
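
Note: a sketch of a custom batched candidate generator matching the new `get_candidates_all` signature documented above; the function name `my_candidates_all` is illustrative, and the body simply mirrors the default per-mention lookup:

    from typing import Iterable, Iterator
    from spacy.kb import Candidate, KnowledgeBase
    from spacy.tokens import SpanGroup

    def my_candidates_all(
        kb: KnowledgeBase, mentions: Iterator[SpanGroup]
    ) -> Iterator[Iterable[Iterable[Candidate]]]:
        # One inner list per doc, one candidate list per mention in that doc.
        for doc_mentions in mentions:
            yield [kb.get_candidates(mention) for mention in doc_mentions]
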
@@ -468,24 +469,13 @@ class EntityLinker(TrainablePipe):

        # Call candidate generator.
        if self.candidates_doc_mode:
-
-            def _adjust_ents_in_doc(doc: Doc, valid_ent_idx: Iterable[int]) -> Doc:
-                """
-                Generates copy of doc object with only those ents that are candidates are to be retrieved for.
-                doc (Doc): Doc object to adjust.
-                valid_ent_idx (Iterable[int]): Indices of entities to keep.
-                RETURN (doc): Doc instance with only valid entities (i.e. those to retrieve candidates for).
-                """
-                _doc = doc.copy()
-                # mypy complains about mismatching types here (Tuple[str] vs. Tuple[str, ...]), which isn't correct and
-                # probably an artifact of a misreading of the Cython code.
-                _doc.ents = tuple([doc.ents[i] for i in valid_ent_idx])  # type: ignore
-                return _doc
-
            all_ent_cands = self.get_candidates_all(
                self.kb,
                (
-                    _adjust_ents_in_doc(doc, next(valid_ent_idx_per_doc))
+                    SpanGroup(
+                        doc,
+                        spans=[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)],
+                    )
                    for doc in docs
                    if len(doc) and len(doc.ents)
                ),
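
Note: the SpanGroup construction above bundles only the entities whose indices `valid_ent_idx_per_doc` yields for each doc; a standalone sketch of that filtering step, with `doc` and the index list as placeholders:

    from spacy.tokens import SpanGroup

    valid_ent_idx = [0, 2]  # indices of entities to retrieve candidates for
    mentions = SpanGroup(doc, spans=[doc.ents[i] for i in valid_ent_idx])
    print(len(mentions), [span.text for span in mentions])
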
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, Dict, Any, Generator, Iterator
+from typing import Callable, Iterable, Dict, Any, Iterator

 import pytest
 from numpy.testing import assert_equal
@@ -15,7 +15,7 @@ from spacy.pipeline.legacy import EntityLinker_v1
 from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
 from spacy.scorer import Scorer
 from spacy.tests.util import make_tempdir
-from spacy.tokens import Span, Doc
+from spacy.tokens import Span, Doc, SpanGroup
 from spacy.training import Example
 from spacy.util import ensure_path
 from spacy.vocab import Vocab
@@ -500,12 +500,14 @@ def test_el_pipe_configuration(nlp):

     # Replace the pipe with a new one with with a different candidate generator.

-    def get_lowercased_candidates(kb, span):
+    def get_lowercased_candidates(kb: InMemoryLookupKB, span: Span):
         return kb.get_alias_candidates(span.text.lower())

-    def get_lowercased_candidates_all(kb, docs):
-        for _doc in docs:
-            yield [get_lowercased_candidates(kb, ent_span) for ent_span in _doc.ents]
+    def get_lowercased_candidates_all(
+        kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]
+    ):
+        for doc_mentions in mentions:
+            yield [get_lowercased_candidates(kb, mention) for mention in doc_mentions]

     @registry.misc("spacy.LowercaseCandidateGenerator.v1")
     def create_candidates() -> Callable[
@@ -515,7 +517,7 @@ def test_el_pipe_configuration(nlp):

     @registry.misc("spacy.LowercaseCandidateAllGenerator.v1")
     def create_candidates_batch() -> Callable[
-        [InMemoryLookupKB, Generator[Iterable["Span"], None, None]],
+        [InMemoryLookupKB, Iterator[SpanGroup]],
         Iterator[Iterable[Iterable[Candidate]]],
     ]:
         return get_lowercased_candidates_all
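
Note: a sketch of how the lowercasing generators registered in this test could be wired into an entity_linker pipe via its config; the `candidates_doc_mode` key and the exact config shape are assumptions, and the whole snippet is illustrative rather than the test's actual wiring:

    # Assumes `nlp` and the registered misc functions above are already set up.
    entity_linker = nlp.replace_pipe(
        "entity_linker",
        "entity_linker",
        config={
            "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
            "get_candidates_all": {"@misc": "spacy.LowercaseCandidateAllGenerator.v1"},
            "candidates_doc_mode": True,  # assumed key enabling the batched path
        },
    )
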
@@ -19,6 +19,8 @@ import warnings

 from .span cimport Span
 from .token cimport MISSING_DEP
+from .span_group cimport SpanGroup
+
 from ._dict_proxies import SpanGroups
 from .token cimport Token
 from ..lexeme cimport Lexeme, EMPTY_LEXEME
@@ -704,6 +706,14 @@ cdef class Doc:
         """
         return self.text

+    @property
+    def ents_spangroup(self) -> SpanGroup:
+        """
+        Returns entities (in `.ents`) as `SpanGroup`.
+        RETURNS (SpanGroup): All entities (in `.ents`) as `SpanGroup`.
+        """
+        return SpanGroup(self, spans=self.ents, name="ents")
+
     property ents:
         """The named entities in the document. Returns a tuple of named entity
         `Span` objects, if the entity recognizer has been applied.
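
Note: a small sketch of the new `Doc.ents_spangroup` property, assuming this branch is installed; `nlp` is a placeholder pipeline with an entity recognizer and `kb` a placeholder knowledge base:

    doc = nlp("Douglas Adams lived in Santa Barbara.")
    group = doc.ents_spangroup
    print(group.name, len(group))  # "ents", number of recognized entities
    # The group can be fed directly to the batched lookup:
    # candidates = list(kb.get_candidates_all(iter([group])))
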