mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-01 10:23:07 +03:00
Switch to SpanGroup (from Doc) for bundling Spans for candidate retrieval.
This commit is contained in:
parent
581c2fd40f
commit
b6bc6885d9
|
@ -1,11 +1,11 @@
|
||||||
# cython: infer_types=True, profile=True
|
# cython: infer_types=True, profile=True
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Optional
|
from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Callable
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
|
|
||||||
from .candidate import Candidate
|
from .candidate import Candidate
|
||||||
from ..tokens import Span, Doc
|
from ..tokens import Span, SpanGroup, Doc
|
||||||
from ..util import SimpleFrozenList
|
from ..util import SimpleFrozenList
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
|
|
||||||
|
@ -32,16 +32,26 @@ cdef class KnowledgeBase:
|
||||||
self.entity_vector_length = entity_vector_length
|
self.entity_vector_length = entity_vector_length
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
|
|
||||||
def get_candidates_all(self, docs: Iterator[Doc]) -> Iterator[Iterable[Iterable[Candidate]]]:
|
def get_candidates_all(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[Iterable[Candidate]]]:
|
||||||
"""
|
"""
|
||||||
Return candidate entities for mentions stored in `ent` attribute in passed docs. Each candidate defines the
|
Return candidate entities for the mentions passed as one `SpanGroup` per doc. Each candidate defines the
|
||||||
entity, the original alias, and the prior probability of that alias resolving to that entity.
|
entity, the original alias, and the prior probability of that alias resolving to that entity.
|
||||||
If no candidate is found for a given mention, an empty list is returned.
|
If no candidate is found for a given mention, an empty list is returned.
|
||||||
docs (Iterator[Doc]): Doc instances with mentions (stored in `.ents`).
|
mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance.
|
||||||
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
|
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
|
||||||
"""
|
"""
|
||||||
for doc in docs:
|
for doc_mentions in mentions:
|
||||||
yield [self.get_candidates(ent_span) for ent_span in doc.ents]
|
yield [self.get_candidates(ent_span) for ent_span in doc_mentions]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_ents_as_spangroup(doc: Doc, extractor: Union[str, Callable[[Iterable[Span]], Doc]] = "ent") -> SpanGroup:
|
||||||
|
"""
|
||||||
|
Fetch entities from doc and return them as a SpanGroup ready to be used in
|
||||||
|
`KnowledgeBase.get_candidates_all()`.
|
||||||
|
doc (Doc): Doc whose entities should be fetched.
|
||||||
|
extractor (Union[str, Callable[[Iterable[Span]], Doc]]): Defines how to retrieve object holding spans
|
||||||
|
used to describe entities. This can be a key referring to a property of the doc instance (e.g. "
|
||||||
|
"""
|
||||||
|
|
||||||
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
|
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -231,8 +231,8 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
||||||
alias_entry.probs = probs
|
alias_entry.probs = probs
|
||||||
self._aliases_table[alias_index] = alias_entry
|
self._aliases_table[alias_index] = alias_entry
|
||||||
|
|
||||||
def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]:
|
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
|
||||||
return self.get_alias_candidates(mention.text) # type: ignore
|
return self.get_alias_candidates(mention.text)
|
||||||
|
|
||||||
def get_alias_candidates(self, str alias) -> Iterable[Candidate]:
|
def get_alias_candidates(self, str alias) -> Iterable[Candidate]:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -8,7 +8,7 @@ from ...util import registry
|
||||||
from ...kb import KnowledgeBase, InMemoryLookupKB
|
from ...kb import KnowledgeBase, InMemoryLookupKB
|
||||||
from ...kb import Candidate
|
from ...kb import Candidate
|
||||||
from ...vocab import Vocab
|
from ...vocab import Vocab
|
||||||
from ...tokens import Span, Doc
|
from ...tokens import Span, Doc, SpanGroup
|
||||||
from ..extract_spans import extract_spans
|
from ..extract_spans import extract_spans
|
||||||
from ...errors import Errors
|
from ...errors import Errors
|
||||||
|
|
||||||
|
@ -118,15 +118,15 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
|
||||||
|
|
||||||
|
|
||||||
def get_candidates_all(
|
def get_candidates_all(
|
||||||
kb: KnowledgeBase, docs: Iterator[Doc]
|
kb: KnowledgeBase, mentions: Iterator[SpanGroup]
|
||||||
) -> Iterator[Iterable[Iterable[Candidate]]]:
|
) -> Iterator[Iterable[Iterable[Candidate]]]:
|
||||||
"""
|
"""
|
||||||
Return candidate entities for the given mentions and fetching appropriate entries from the index.
|
Return candidate entities for the given mentions and fetching appropriate entries from the index.
|
||||||
kb (KnowledgeBase): Knowledge base to query.
|
kb (KnowledgeBase): Knowledge base to query.
|
||||||
docs (Iterator[Doc]): Doc instances with mentions (stored in `.ents`).
|
mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance.
|
||||||
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
|
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
|
||||||
"""
|
"""
|
||||||
return kb.get_candidates_all(docs)
|
return kb.get_candidates_all(mentions)
|
||||||
|
|
||||||
|
|
||||||
@registry.misc("spacy.CandidateGenerator.v1")
|
@registry.misc("spacy.CandidateGenerator.v1")
|
||||||
|
@ -136,7 +136,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
|
||||||
|
|
||||||
@registry.misc("spacy.CandidateAllGenerator.v1")
|
@registry.misc("spacy.CandidateAllGenerator.v1")
|
||||||
def create_candidates_all() -> Callable[
|
def create_candidates_all() -> Callable[
|
||||||
[KnowledgeBase, Iterator[Doc]],
|
[KnowledgeBase, Iterator[SpanGroup]],
|
||||||
Iterator[Iterable[Iterable[Candidate]]],
|
Iterator[Iterable[Iterable[Candidate]]],
|
||||||
]:
|
]:
|
||||||
return get_candidates_all
|
return get_candidates_all
|
||||||
|
|
|
@ -17,7 +17,7 @@ from thinc.api import CosineDistance, Model, Optimizer, Config
|
||||||
from thinc.api import set_dropout_rate
|
from thinc.api import set_dropout_rate
|
||||||
|
|
||||||
from ..kb import KnowledgeBase, Candidate
|
from ..kb import KnowledgeBase, Candidate
|
||||||
from ..tokens import Doc, Span
|
from ..tokens import Doc, Span, SpanGroup
|
||||||
from .pipe import deserialize_config
|
from .pipe import deserialize_config
|
||||||
from .legacy.entity_linker import EntityLinker_v1
|
from .legacy.entity_linker import EntityLinker_v1
|
||||||
from .trainable_pipe import TrainablePipe
|
from .trainable_pipe import TrainablePipe
|
||||||
|
@ -87,7 +87,7 @@ def make_entity_linker(
|
||||||
entity_vector_length: int,
|
entity_vector_length: int,
|
||||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||||
get_candidates_all: Callable[
|
get_candidates_all: Callable[
|
||||||
[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
|
[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]
|
||||||
],
|
],
|
||||||
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||||
overwrite: bool,
|
overwrite: bool,
|
||||||
|
@ -108,9 +108,9 @@ def make_entity_linker(
|
||||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||||
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list of
|
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list of
|
||||||
candidates, given a certain knowledge base and a textual mention.
|
candidates, given a certain knowledge base and a textual mention.
|
||||||
get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]): Function
|
get_candidates_all (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
|
||||||
that produces a list of candidates per document, given a certain knowledge base and several textual documents
|
Function producing a list of candidates per document, given a certain knowledge base and several textual
|
||||||
with textual mentions.
|
documents with textual mentions.
|
||||||
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
||||||
scorer (Optional[Callable]): The scoring method.
|
scorer (Optional[Callable]): The scoring method.
|
||||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||||
|
@ -187,7 +187,8 @@ class EntityLinker(TrainablePipe):
|
||||||
entity_vector_length: int,
|
entity_vector_length: int,
|
||||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||||
get_candidates_all: Callable[
|
get_candidates_all: Callable[
|
||||||
[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
|
[KnowledgeBase, Iterator[SpanGroup]],
|
||||||
|
Iterator[Iterable[Iterable[Candidate]]],
|
||||||
],
|
],
|
||||||
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||||
overwrite: bool = BACKWARD_OVERWRITE,
|
overwrite: bool = BACKWARD_OVERWRITE,
|
||||||
|
@ -209,8 +210,8 @@ class EntityLinker(TrainablePipe):
|
||||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||||
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list
|
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list
|
||||||
of candidates, given a certain knowledge base and a textual mention.
|
of candidates, given a certain knowledge base and a textual mention.
|
||||||
get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]):
|
get_candidates_all (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
|
||||||
Function that produces a list of candidates per document, given a certain knowledge base and several textual
|
Function producing a list of candidates per document, given a certain knowledge base and several textual
|
||||||
documents with textual mentions.
|
documents with textual mentions.
|
||||||
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
||||||
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
|
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
|
||||||
|
@ -468,24 +469,13 @@ class EntityLinker(TrainablePipe):
|
||||||
|
|
||||||
# Call candidate generator.
|
# Call candidate generator.
|
||||||
if self.candidates_doc_mode:
|
if self.candidates_doc_mode:
|
||||||
|
|
||||||
def _adjust_ents_in_doc(doc: Doc, valid_ent_idx: Iterable[int]) -> Doc:
|
|
||||||
"""
|
|
||||||
Generates a copy of the doc object containing only those ents for which candidates are to be retrieved.
|
|
||||||
doc (Doc): Doc object to adjust.
|
|
||||||
valid_ent_idx (Iterable[int]): Indices of entities to keep.
|
|
||||||
RETURN (doc): Doc instance with only valid entities (i.e. those to retrieve candidates for).
|
|
||||||
"""
|
|
||||||
_doc = doc.copy()
|
|
||||||
# mypy complains about mismatching types here (Tuple[str] vs. Tuple[str, ...]), which isn't correct and
|
|
||||||
# probably an artifact of a misreading of the Cython code.
|
|
||||||
_doc.ents = tuple([doc.ents[i] for i in valid_ent_idx]) # type: ignore
|
|
||||||
return _doc
|
|
||||||
|
|
||||||
all_ent_cands = self.get_candidates_all(
|
all_ent_cands = self.get_candidates_all(
|
||||||
self.kb,
|
self.kb,
|
||||||
(
|
(
|
||||||
_adjust_ents_in_doc(doc, next(valid_ent_idx_per_doc))
|
SpanGroup(
|
||||||
|
doc,
|
||||||
|
spans=[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)],
|
||||||
|
)
|
||||||
for doc in docs
|
for doc in docs
|
||||||
if len(doc) and len(doc.ents)
|
if len(doc) and len(doc.ents)
|
||||||
),
|
),
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Callable, Iterable, Dict, Any, Generator, Iterator
|
from typing import Callable, Iterable, Dict, Any, Iterator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from numpy.testing import assert_equal
|
from numpy.testing import assert_equal
|
||||||
|
@ -15,7 +15,7 @@ from spacy.pipeline.legacy import EntityLinker_v1
|
||||||
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||||
from spacy.scorer import Scorer
|
from spacy.scorer import Scorer
|
||||||
from spacy.tests.util import make_tempdir
|
from spacy.tests.util import make_tempdir
|
||||||
from spacy.tokens import Span, Doc
|
from spacy.tokens import Span, Doc, SpanGroup
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
from spacy.util import ensure_path
|
from spacy.util import ensure_path
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
@ -500,12 +500,14 @@ def test_el_pipe_configuration(nlp):
|
||||||
|
|
||||||
# Replace the pipe with a new one with a different candidate generator.
|
# Replace the pipe with a new one with a different candidate generator.
|
||||||
|
|
||||||
def get_lowercased_candidates(kb, span):
|
def get_lowercased_candidates(kb: InMemoryLookupKB, span: Span):
|
||||||
return kb.get_alias_candidates(span.text.lower())
|
return kb.get_alias_candidates(span.text.lower())
|
||||||
|
|
||||||
def get_lowercased_candidates_all(kb, docs):
|
def get_lowercased_candidates_all(
|
||||||
for _doc in docs:
|
kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]
|
||||||
yield [get_lowercased_candidates(kb, ent_span) for ent_span in _doc.ents]
|
):
|
||||||
|
for doc_mentions in mentions:
|
||||||
|
yield [get_lowercased_candidates(kb, mention) for mention in doc_mentions]
|
||||||
|
|
||||||
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
|
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
|
||||||
def create_candidates() -> Callable[
|
def create_candidates() -> Callable[
|
||||||
|
@ -515,7 +517,7 @@ def test_el_pipe_configuration(nlp):
|
||||||
|
|
||||||
@registry.misc("spacy.LowercaseCandidateAllGenerator.v1")
|
@registry.misc("spacy.LowercaseCandidateAllGenerator.v1")
|
||||||
def create_candidates_batch() -> Callable[
|
def create_candidates_batch() -> Callable[
|
||||||
[InMemoryLookupKB, Generator[Iterable["Span"], None, None]],
|
[InMemoryLookupKB, Iterator[SpanGroup]],
|
||||||
Iterator[Iterable[Iterable[Candidate]]],
|
Iterator[Iterable[Iterable[Candidate]]],
|
||||||
]:
|
]:
|
||||||
return get_lowercased_candidates_all
|
return get_lowercased_candidates_all
|
||||||
|
|
|
@ -19,6 +19,8 @@ import warnings
|
||||||
|
|
||||||
from .span cimport Span
|
from .span cimport Span
|
||||||
from .token cimport MISSING_DEP
|
from .token cimport MISSING_DEP
|
||||||
|
from .span_group cimport SpanGroup
|
||||||
|
|
||||||
from ._dict_proxies import SpanGroups
|
from ._dict_proxies import SpanGroups
|
||||||
from .token cimport Token
|
from .token cimport Token
|
||||||
from ..lexeme cimport Lexeme, EMPTY_LEXEME
|
from ..lexeme cimport Lexeme, EMPTY_LEXEME
|
||||||
|
@ -704,6 +706,14 @@ cdef class Doc:
|
||||||
"""
|
"""
|
||||||
return self.text
|
return self.text
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ents_spangroup(self) -> SpanGroup:
|
||||||
|
"""
|
||||||
|
Returns entities (in `.ents`) as `SpanGroup`.
|
||||||
|
RETURNS (SpanGroup): All entities (in `.ents`) as `SpanGroup`.
|
||||||
|
"""
|
||||||
|
return SpanGroup(self, spans=self.ents, name="ents")
|
||||||
|
|
||||||
property ents:
|
property ents:
|
||||||
"""The named entities in the document. Returns a tuple of named entity
|
"""The named entities in the document. Returns a tuple of named entity
|
||||||
`Span` objects, if the entity recognizer has been applied.
|
`Span` objects, if the entity recognizer has been applied.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user