mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-11 00:32:40 +03:00
Modify candidate retrieval interface to accept docs instead of individual spans.
This commit is contained in:
parent
df6e4ab055
commit
53a24abd8b
|
@ -1,11 +1,11 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type
|
||||
from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Optional
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from .candidate import Candidate
|
||||
from ..tokens import Span
|
||||
from ..tokens import Span, Doc
|
||||
from ..util import SimpleFrozenList
|
||||
from ..errors import Errors
|
||||
|
||||
|
@ -32,24 +32,25 @@ cdef class KnowledgeBase:
|
|||
self.entity_vector_length = entity_vector_length
|
||||
self.mem = Pool()
|
||||
|
||||
def get_candidates_all(self, mentions: Iterator[Iterable[Span]]) -> Iterator[Iterable[Iterable[Candidate]]]:
|
||||
def get_candidates_all(self, docs: Iterator[Doc]) -> Iterator[Iterable[Iterable[Candidate]]]:
|
||||
"""
|
||||
Return candidate entities for specified mentions. Each candidate defines the entity, the original alias,
|
||||
and the prior probability of that alias resolving to that entity.
|
||||
Return candidate entities for mentions stored in the `ents` attribute of the passed docs. Each candidate defines the
|
||||
entity, the original alias, and the prior probability of that alias resolving to that entity.
|
||||
If no candidate is found for a given mention, an empty list is returned.
|
||||
mentions (Generator[Iterable[Span]]): Mentions per documents for which to get candidates.
|
||||
RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
|
||||
docs (Iterator[Doc]): Doc instances with mentions (stored in `.ents`).
|
||||
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
|
||||
"""
|
||||
for doc in docs:
|
||||
yield [self.get_candidates(ent_span, doc) for ent_span in doc.ents]
|
||||
|
||||
for doc_mentions in mentions:
|
||||
yield [self.get_candidates(span) for span in doc_mentions]
|
||||
|
||||
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
|
||||
def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]:
|
||||
"""
|
||||
Return candidate entities for specified text. Each candidate defines the entity, the original alias,
|
||||
and the prior probability of that alias resolving to that entity.
|
||||
If no candidate is found for a given text, an empty list is returned.
|
||||
Note that doc is not utilized for further context in this implementation.
|
||||
mention (Span): Mention for which to get candidates.
|
||||
doc (Optional[Doc]): Doc to use for context.
|
||||
RETURNS (Iterable[Candidate]): Identified candidates.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
from typing import Iterable, Callable, Dict, Any, Union
|
||||
from typing import Iterable, Callable, Dict, Any, Union, Optional
|
||||
|
||||
import srsly
|
||||
from preshed.maps cimport PreshMap
|
||||
|
@ -11,7 +11,7 @@ from libcpp.vector cimport vector
|
|||
from pathlib import Path
|
||||
import warnings
|
||||
|
||||
from ..tokens import Span
|
||||
from ..tokens import Span, Doc
|
||||
from ..typedefs cimport hash_t
|
||||
from ..errors import Errors, Warnings
|
||||
from .. import util
|
||||
|
@ -231,7 +231,7 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
alias_entry.probs = probs
|
||||
self._aliases_table[alias_index] = alias_entry
|
||||
|
||||
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
|
||||
def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]:
|
||||
return self.get_alias_candidates(mention.text) # type: ignore
|
||||
|
||||
def get_alias_candidates(self, str alias) -> Iterable[Candidate]:
|
||||
|
|
|
@ -118,15 +118,15 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
|
|||
|
||||
|
||||
def get_candidates_all(
|
||||
kb: KnowledgeBase, mentions: Iterator[Iterable[Span]]
|
||||
kb: KnowledgeBase, docs: Iterator[Doc]
|
||||
) -> Iterator[Iterable[Iterable[Candidate]]]:
|
||||
"""
|
||||
Return candidate entities for the given mentions by fetching appropriate entries from the index.
|
||||
kb (KnowledgeBase): Knowledge base to query.
|
||||
mention (Iterator[Iterable[Span]]): Entity mentions per document for which to identify candidates.
|
||||
docs (Iterator[Doc]): Doc instances with mentions (stored in `.ents`).
|
||||
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
|
||||
"""
|
||||
return kb.get_candidates_all(mentions)
|
||||
return kb.get_candidates_all(docs)
|
||||
|
||||
|
||||
@registry.misc("spacy.CandidateGenerator.v1")
|
||||
|
|
|
@ -87,8 +87,7 @@ def make_entity_linker(
|
|||
entity_vector_length: int,
|
||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||
get_candidates_all: Callable[
|
||||
[KnowledgeBase, Iterator[Iterable[Span]]],
|
||||
Iterator[Iterable[Iterable[Candidate]]],
|
||||
[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
|
||||
],
|
||||
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||
overwrite: bool,
|
||||
|
@ -107,11 +106,11 @@ def make_entity_linker(
|
|||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||
incl_context (bool): Whether or not to include the local context in the model.
|
||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
|
||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||
get_candidates_all (Callable[[KnowledgeBase, Iterator[Iterable[Span]]], Iterator[Iterable[Iterable[Candidate]]]]):
|
||||
Function that produces a list of candidates per document, given a certain knowledge base and several textual
|
||||
documents with textual mentions.
|
||||
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list of
|
||||
candidates, given a certain knowledge base and a textual mention.
|
||||
get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]): Function
|
||||
that produces a list of candidates per document, given a certain knowledge base and several textual documents
|
||||
with textual mentions.
|
||||
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||
|
@ -188,8 +187,7 @@ class EntityLinker(TrainablePipe):
|
|||
entity_vector_length: int,
|
||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||
get_candidates_all: Callable[
|
||||
[KnowledgeBase, Iterator[Iterable[Span]]],
|
||||
Iterator[Iterable[Iterable[Candidate]]],
|
||||
[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
|
||||
],
|
||||
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||
overwrite: bool = BACKWARD_OVERWRITE,
|
||||
|
@ -209,9 +207,9 @@ class EntityLinker(TrainablePipe):
|
|||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||
incl_context (bool): Whether or not to include the local context in the model.
|
||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
|
||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||
get_candidates_all (Callable[[KnowledgeBase, Iterator[Iterable[Span]]], Iterator[Iterable[Iterable[Candidate]]]]):
|
||||
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list
|
||||
of candidates, given a certain knowledge base and a textual mention.
|
||||
get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]):
|
||||
Function that produces a list of candidates per document, given a certain knowledge base and several textual
|
||||
documents with textual mentions.
|
||||
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
||||
|
@ -330,7 +328,6 @@ class EntityLinker(TrainablePipe):
|
|||
|
||||
If one isn't present, then the update step needs to be skipped.
|
||||
"""
|
||||
|
||||
for eg in examples:
|
||||
for ent in eg.predicted.ents:
|
||||
candidates = list(self.get_candidates(self.kb, ent))
|
||||
|
@ -471,10 +468,22 @@ class EntityLinker(TrainablePipe):
|
|||
|
||||
# Call candidate generator.
|
||||
if self.candidates_doc_mode:
|
||||
|
||||
def _adjust_ents_in_doc(doc: Doc, valid_ent_idx: Iterable[int]) -> Doc:
|
||||
"""
|
||||
Generates a copy of the doc object containing only those ents for which candidates are to be retrieved.
|
||||
doc (Doc): Doc object to adjust.
|
||||
valid_ent_idx (Iterable[int]): Indices of entities to keep.
|
||||
RETURNS (Doc): Doc instance with only valid entities (i.e. those to retrieve candidates for).
|
||||
"""
|
||||
_doc = doc.copy()
|
||||
_doc.ents = [doc.ents[i] for i in valid_ent_idx]
|
||||
return _doc
|
||||
|
||||
all_ent_cands = self.get_candidates_all(
|
||||
self.kb,
|
||||
(
|
||||
[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)]
|
||||
_adjust_ents_in_doc(doc, next(valid_ent_idx_per_doc))
|
||||
for doc in docs
|
||||
if len(doc) and len(doc.ents)
|
||||
),
|
||||
|
@ -564,6 +573,7 @@ class EntityLinker(TrainablePipe):
|
|||
method="predict", msg="result variables not of equal length"
|
||||
)
|
||||
raise RuntimeError(err)
|
||||
|
||||
return final_kb_ids
|
||||
|
||||
def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
|
||||
|
|
|
@ -503,9 +503,9 @@ def test_el_pipe_configuration(nlp):
|
|||
def get_lowercased_candidates(kb, span):
|
||||
return kb.get_alias_candidates(span.text.lower())
|
||||
|
||||
def get_lowercased_candidates_all(kb, spans_per_doc):
|
||||
for doc_spans in spans_per_doc:
|
||||
yield [get_lowercased_candidates(kb, span) for span in doc_spans]
|
||||
def get_lowercased_candidates_all(kb, docs):
|
||||
for _doc in docs:
|
||||
yield [get_lowercased_candidates(kb, ent_span) for ent_span in _doc.ents]
|
||||
|
||||
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
|
||||
def create_candidates() -> Callable[
|
||||
|
|
Loading…
Reference in New Issue
Block a user