mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-22 18:12:00 +03:00
Modify EL batching system.
This commit is contained in:
parent
3beda2b23a
commit
bb7418ebdd
|
@ -940,7 +940,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"case pass an empty list for the previously not specified argument to avoid this error.")
|
||||
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
|
||||
"{value}.")
|
||||
E1044 = ("Expected `candidates_batch_size` to be >= 1, but got: {value}")
|
||||
# E1044 is unused
|
||||
E1045 = ("Encountered {parent} subclass without `{parent}.{method}` "
|
||||
"method in '{name}'. If you want to use this method, make "
|
||||
"sure it's overwritten on the subclass.")
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import abc
|
||||
from typing import List, Union, Callable
|
||||
from typing import List, Callable
|
||||
|
||||
|
||||
class Candidate(abc.ABC):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple, Union
|
||||
from typing import Iterable, Tuple, Union, Iterator
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from .candidate import Candidate
|
||||
|
@ -30,23 +30,13 @@ cdef class KnowledgeBase:
|
|||
self.entity_vector_length = entity_vector_length
|
||||
self.mem = Pool()
|
||||
|
||||
def get_candidates_batch(self, mentions: SpanGroup) -> Iterable[Iterable[Candidate]]:
|
||||
def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[Iterable[Candidate]]]:
|
||||
"""
|
||||
Return candidate entities for specified texts. Each candidate defines the entity, the original alias,
|
||||
and the prior probability of that alias resolving to that entity.
|
||||
If no candidate is found for a given text, an empty list is returned.
|
||||
mentions (SpanGroup): Mentions for which to get candidates.
|
||||
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
|
||||
"""
|
||||
return [self.get_candidates(span) for span in mentions]
|
||||
|
||||
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
|
||||
"""
|
||||
Return candidate entities for specified text. Each candidate defines the entity, the original alias,
|
||||
and the prior probability of that alias resolving to that entity.
|
||||
If the no candidate is found for a given text, an empty list is returned.
|
||||
mention (Span): Mention for which to get candidates.
|
||||
RETURNS (Iterable[Candidate]): Identified candidates.
|
||||
Return candidate entities for mentions stored in `ent` attribute in passed docs. Each candidate defines the
|
||||
entity, the original alias, and the prior probability of that alias resolving to that entity.
|
||||
If no candidate is found for a given mention, an empty list is returned.
|
||||
mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance.
|
||||
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
from typing import Iterable, Callable, Dict, Any, Union
|
||||
from typing import Iterable, Callable, Dict, Any, Union, Iterator
|
||||
|
||||
import srsly
|
||||
from preshed.maps cimport PreshMap
|
||||
|
@ -11,7 +11,7 @@ from libcpp.vector cimport vector
|
|||
from pathlib import Path
|
||||
import warnings
|
||||
|
||||
from ..tokens import Span
|
||||
from ..tokens import SpanGroup
|
||||
from ..typedefs cimport hash_t
|
||||
from ..errors import Errors, Warnings
|
||||
from .. import util
|
||||
|
@ -223,8 +223,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
alias_entry.probs = probs
|
||||
self._aliases_table[alias_index] = alias_entry
|
||||
|
||||
def get_candidates(self, mention: Span) -> Iterable[InMemoryCandidate]:
|
||||
return self.get_alias_candidates(mention.text) # type: ignore
|
||||
def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[Iterable[InMemoryCandidate]]]:
|
||||
for mentions_for_doc in mentions:
|
||||
yield [self.get_alias_candidates(ent_span.text) for ent_span in mentions_for_doc]
|
||||
|
||||
def get_alias_candidates(self, str alias) -> Iterable[InMemoryCandidate]:
|
||||
"""
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from pathlib import Path
|
||||
from typing import Optional, Callable, Iterable, List, Tuple
|
||||
from typing import Optional, Callable, Iterable, List, Tuple, Iterator
|
||||
from thinc.types import Floats2d
|
||||
from thinc.api import chain, list2ragged, reduce_mean, residual
|
||||
from thinc.api import Model, Maxout, Linear, tuplify, Ragged
|
||||
|
@ -100,34 +100,20 @@ def empty_kb(
|
|||
|
||||
|
||||
@registry.misc("spacy.CandidateGenerator.v1")
|
||||
def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
|
||||
def create_candidates_all() -> Callable[
|
||||
[KnowledgeBase, Iterator[SpanGroup]],
|
||||
Iterator[Iterable[Iterable[Candidate]]],
|
||||
]:
|
||||
return get_candidates
|
||||
|
||||
|
||||
@registry.misc("spacy.CandidateBatchGenerator.v1")
|
||||
def create_candidates_batch() -> Callable[
|
||||
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
|
||||
]:
|
||||
return get_candidates_batch
|
||||
|
||||
|
||||
def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
|
||||
"""
|
||||
Return candidate entities for a given mention and fetching appropriate entries from the index.
|
||||
kb (KnowledgeBase): Knowledge base to query.
|
||||
mention (Span): Entity mention for which to identify candidates.
|
||||
RETURNS (Iterable[InMemoryCandidate]): Identified candidates.
|
||||
"""
|
||||
return kb.get_candidates(mention)
|
||||
|
||||
|
||||
def get_candidates_batch(
|
||||
kb: KnowledgeBase, mentions: SpanGroup
|
||||
) -> Iterable[Iterable[Candidate]]:
|
||||
def get_candidates(
|
||||
kb: KnowledgeBase, mentions: Iterator[SpanGroup]
|
||||
) -> Iterator[Iterable[Iterable[Candidate]]]:
|
||||
"""
|
||||
Return candidate entities for the given mentions and fetching appropriate entries from the index.
|
||||
kb (KnowledgeBase): Knowledge base to query.
|
||||
mention (SpanGroup): Entity mentions for which to identify candidates.
|
||||
RETURNS (Iterable[Iterable[InMemoryCandidate]]): Identified candidates.
|
||||
mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance.
|
||||
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
|
||||
"""
|
||||
return kb.get_candidates_batch(mentions)
|
||||
return kb.get_candidates(mentions)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any
|
||||
from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any, Iterator
|
||||
from typing import cast
|
||||
from numpy import dtype
|
||||
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
|
||||
|
@ -9,10 +9,9 @@ import random
|
|||
from thinc.api import CosineDistance, Model, Optimizer, Config
|
||||
from thinc.api import set_dropout_rate
|
||||
|
||||
from ..tokens import SpanGroup
|
||||
from ..kb import KnowledgeBase, Candidate
|
||||
from ..ml import empty_kb
|
||||
from ..tokens import Doc, Span, SpanGroup
|
||||
from ..tokens import Doc, SpanGroup
|
||||
from .pipe import deserialize_config
|
||||
from .trainable_pipe import TrainablePipe
|
||||
from ..language import Language
|
||||
|
@ -57,11 +56,9 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
|||
"incl_context": True,
|
||||
"entity_vector_length": 64,
|
||||
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
|
||||
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
|
||||
"overwrite": False,
|
||||
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
|
||||
"use_gold_ents": True,
|
||||
"candidates_batch_size": 1,
|
||||
"threshold": None,
|
||||
"save_activations": False,
|
||||
},
|
||||
|
@ -81,14 +78,10 @@ def make_entity_linker(
|
|||
incl_prior: bool,
|
||||
incl_context: bool,
|
||||
entity_vector_length: int,
|
||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||
get_candidates_batch: Callable[
|
||||
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
|
||||
],
|
||||
get_candidates: Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]],
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
use_gold_ents: bool,
|
||||
candidates_batch_size: int,
|
||||
threshold: Optional[float] = None,
|
||||
save_activations: bool,
|
||||
):
|
||||
|
@ -102,15 +95,12 @@ def make_entity_linker(
|
|||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||
incl_context (bool): Whether or not to include the local context in the model.
|
||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
|
||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||
get_candidates_batch (
|
||||
Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
|
||||
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
||||
get_candidates (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
|
||||
Function producing a list of candidates per document, given a certain knowledge base and several textual
|
||||
documents with textual mentions.
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||
component must provide entity annotations.
|
||||
candidates_batch_size (int): Size of batches for entity candidate generation.
|
||||
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
|
||||
prediction is discarded. If None, predictions are not filtered by any threshold.
|
||||
save_activations (bool): save model activations in Doc when annotating.
|
||||
|
@ -147,11 +137,9 @@ def make_entity_linker(
|
|||
incl_context=incl_context,
|
||||
entity_vector_length=entity_vector_length,
|
||||
get_candidates=get_candidates,
|
||||
get_candidates_batch=get_candidates_batch,
|
||||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
use_gold_ents=use_gold_ents,
|
||||
candidates_batch_size=candidates_batch_size,
|
||||
threshold=threshold,
|
||||
save_activations=save_activations,
|
||||
)
|
||||
|
@ -185,14 +173,10 @@ class EntityLinker(TrainablePipe):
|
|||
incl_prior: bool,
|
||||
incl_context: bool,
|
||||
entity_vector_length: int,
|
||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||
get_candidates_batch: Callable[
|
||||
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
|
||||
],
|
||||
get_candidates: Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]],
|
||||
overwrite: bool = False,
|
||||
scorer: Optional[Callable] = entity_linker_score,
|
||||
use_gold_ents: bool,
|
||||
candidates_batch_size: int,
|
||||
threshold: Optional[float] = None,
|
||||
save_activations: bool = False,
|
||||
) -> None:
|
||||
|
@ -207,17 +191,13 @@ class EntityLinker(TrainablePipe):
|
|||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||
incl_context (bool): Whether or not to include the local context in the model.
|
||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
|
||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||
get_candidates_batch (
|
||||
Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]],
|
||||
Iterable[Candidate]]
|
||||
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
||||
get_candidates (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
|
||||
Function producing a list of candidates per document, given a certain knowledge base and several textual
|
||||
documents with textual mentions.
|
||||
overwrite (bool): Whether to overwrite existing non-empty annotations.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||
component must provide entity annotations.
|
||||
candidates_batch_size (int): Size of batches for entity candidate generation.
|
||||
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the
|
||||
threshold, prediction is discarded. If None, predictions are not filtered by any threshold.
|
||||
DOCS: https://spacy.io/api/entitylinker#init
|
||||
|
@ -240,7 +220,6 @@ class EntityLinker(TrainablePipe):
|
|||
self.incl_prior = incl_prior
|
||||
self.incl_context = incl_context
|
||||
self.get_candidates = get_candidates
|
||||
self.get_candidates_batch = get_candidates_batch
|
||||
self.cfg: Dict[str, Any] = {"overwrite": overwrite}
|
||||
self.distance = CosineDistance(normalize=False)
|
||||
# how many neighbour sentences to take into account
|
||||
|
@ -248,13 +227,9 @@ class EntityLinker(TrainablePipe):
|
|||
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
||||
self.scorer = scorer
|
||||
self.use_gold_ents = use_gold_ents
|
||||
self.candidates_batch_size = candidates_batch_size
|
||||
self.threshold = threshold
|
||||
self.save_activations = save_activations
|
||||
|
||||
if candidates_batch_size < 1:
|
||||
raise ValueError(Errors.E1044)
|
||||
|
||||
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
|
||||
"""Define the KB of this pipe by providing a function that will
|
||||
create it using this object's vocab."""
|
||||
|
@ -331,11 +306,12 @@ class EntityLinker(TrainablePipe):
|
|||
|
||||
If one isn't present, then the update step needs to be skipped.
|
||||
"""
|
||||
|
||||
for eg in examples:
|
||||
for ent in eg.predicted.ents:
|
||||
candidates = list(self.get_candidates(self.kb, ent))
|
||||
if candidates:
|
||||
# todo continue here: fix get_candidates_call
|
||||
for candidates_for_doc in self.get_candidates(
|
||||
self.kb, (SpanGroup(doc=eg.predicted, spans=eg.predicted.ents) for eg in examples)
|
||||
):
|
||||
for candidates_for_mention in candidates_for_doc:
|
||||
if list(candidates_for_mention):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -464,59 +440,72 @@ class EntityLinker(TrainablePipe):
|
|||
}
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
for doc in docs:
|
||||
|
||||
docs = list(docs)
|
||||
# Determine which entities are to be ignored due to labels_discard.
|
||||
valid_ent_idx_per_doc = (
|
||||
[
|
||||
idx
|
||||
for idx in range(len(doc.ents))
|
||||
if doc.ents[idx].label_ not in self.labels_discard
|
||||
]
|
||||
for doc in docs
|
||||
if len(doc) and len(doc.ents)
|
||||
)
|
||||
|
||||
# Call candidate generator.
|
||||
all_ent_cands = self.get_candidates(
|
||||
self.kb,
|
||||
(
|
||||
SpanGroup(
|
||||
doc,
|
||||
spans=[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)],
|
||||
)
|
||||
for doc in docs
|
||||
if len(doc) and len(doc.ents)
|
||||
),
|
||||
)
|
||||
|
||||
for doc_idx, doc in enumerate(docs):
|
||||
doc_ents: List[Ints1d] = []
|
||||
doc_scores: List[Floats1d] = []
|
||||
if len(doc) == 0:
|
||||
if len(doc) == 0 or len(doc.ents) == 0:
|
||||
docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
|
||||
docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
|
||||
continue
|
||||
sentences = [s for s in doc.sents]
|
||||
doc_ent_cands = list(next(all_ent_cands))
|
||||
|
||||
# Loop over entities in batches.
|
||||
for ent_idx in range(0, len(doc.ents), self.candidates_batch_size):
|
||||
ent_batch = doc.ents[ent_idx : ent_idx + self.candidates_batch_size]
|
||||
# Looping over candidate entities for this doc. (TODO: rewrite)
|
||||
for ent_cand_idx, ent in enumerate(doc.ents):
|
||||
sent_index = sentences.index(ent.sent)
|
||||
assert sent_index >= 0
|
||||
|
||||
# Look up candidate entities.
|
||||
valid_ent_idx = [
|
||||
idx
|
||||
for idx in range(len(ent_batch))
|
||||
if ent_batch[idx].label_ not in self.labels_discard
|
||||
]
|
||||
|
||||
batch_candidates = list(
|
||||
self.get_candidates_batch(
|
||||
self.kb,
|
||||
SpanGroup(doc, spans=[ent_batch[idx] for idx in valid_ent_idx]),
|
||||
if self.incl_context:
|
||||
# get n_neighbour sentences, clipped to the length of the document
|
||||
start_sentence = max(0, sent_index - self.n_sents)
|
||||
end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
|
||||
start_token = sentences[start_sentence].start
|
||||
end_token = sentences[end_sentence].end
|
||||
sent_doc = doc[start_token:end_token].as_doc()
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
sentence_encoding = self.model.predict([sent_doc])[0]
|
||||
sentence_encoding_t = sentence_encoding.T
|
||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
||||
entity_count += 1
|
||||
if ent.label_ in self.labels_discard:
|
||||
# ignoring this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
self._add_activations(
|
||||
doc_scores=doc_scores,
|
||||
doc_ents=doc_ents,
|
||||
scores=[0.0],
|
||||
ents=[0],
|
||||
)
|
||||
if self.candidates_batch_size > 1
|
||||
else [
|
||||
self.get_candidates(self.kb, ent_batch[idx])
|
||||
for idx in valid_ent_idx
|
||||
]
|
||||
)
|
||||
|
||||
# Looping through each entity in batch (TODO: rewrite)
|
||||
for j, ent in enumerate(ent_batch):
|
||||
sent_index = sentences.index(ent.sent)
|
||||
assert sent_index >= 0
|
||||
|
||||
if self.incl_context:
|
||||
# get n_neighbour sentences, clipped to the length of the document
|
||||
start_sentence = max(0, sent_index - self.n_sents)
|
||||
end_sentence = min(
|
||||
len(sentences) - 1, sent_index + self.n_sents
|
||||
)
|
||||
start_token = sentences[start_sentence].start
|
||||
end_token = sentences[end_sentence].end
|
||||
sent_doc = doc[start_token:end_token].as_doc()
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
sentence_encoding = self.model.predict([sent_doc])[0]
|
||||
sentence_encoding_t = sentence_encoding.T
|
||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
||||
entity_count += 1
|
||||
if ent.label_ in self.labels_discard:
|
||||
# ignoring this entity - setting to NIL
|
||||
else:
|
||||
candidates = list(doc_ent_cands[ent_cand_idx])
|
||||
if not candidates:
|
||||
# no prediction possible for this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
self._add_activations(
|
||||
doc_scores=doc_scores,
|
||||
|
@ -524,65 +513,56 @@ class EntityLinker(TrainablePipe):
|
|||
scores=[0.0],
|
||||
ents=[0],
|
||||
)
|
||||
elif len(candidates) == 1 and self.threshold is None:
|
||||
# shortcut for efficiency reasons: take the 1 candidate
|
||||
final_kb_ids.append(candidates[0].entity_)
|
||||
self._add_activations(
|
||||
doc_scores=doc_scores,
|
||||
doc_ents=doc_ents,
|
||||
scores=[1.0],
|
||||
ents=[candidates[0].entity],
|
||||
)
|
||||
else:
|
||||
candidates = list(batch_candidates[j])
|
||||
if not candidates:
|
||||
# no prediction possible for this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
self._add_activations(
|
||||
doc_scores=doc_scores,
|
||||
doc_ents=doc_ents,
|
||||
scores=[0.0],
|
||||
ents=[0],
|
||||
random.shuffle(candidates)
|
||||
# set all prior probabilities to 0 if incl_prior=False
|
||||
scores = prior_probs = xp.asarray(
|
||||
[
|
||||
c.prior_prob if self.incl_prior else 0.0
|
||||
for c in candidates
|
||||
]
|
||||
)
|
||||
# add in similarity from the context
|
||||
if self.incl_context:
|
||||
entity_encodings = xp.asarray(
|
||||
[c.entity_vector for c in candidates]
|
||||
)
|
||||
elif len(candidates) == 1 and self.threshold is None:
|
||||
# shortcut for efficiency reasons: take the 1 candidate
|
||||
final_kb_ids.append(candidates[0].entity_)
|
||||
self._add_activations(
|
||||
doc_scores=doc_scores,
|
||||
doc_ents=doc_ents,
|
||||
scores=[1.0],
|
||||
ents=[candidates[0].entity],
|
||||
)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
# set all prior probabilities to 0 if incl_prior=False
|
||||
prior_probs = xp.asarray([c.prior_prob for c in candidates])
|
||||
if not self.incl_prior:
|
||||
prior_probs = xp.asarray([0.0 for _ in candidates])
|
||||
scores = prior_probs
|
||||
# add in similarity from the context
|
||||
if self.incl_context:
|
||||
entity_encodings = xp.asarray(
|
||||
[c.entity_vector for c in candidates]
|
||||
)
|
||||
entity_norm = xp.linalg.norm(entity_encodings, axis=1)
|
||||
if len(entity_encodings) != len(prior_probs):
|
||||
raise RuntimeError(
|
||||
Errors.E147.format(
|
||||
method="predict",
|
||||
msg="vectors not of equal length",
|
||||
)
|
||||
entity_norm = xp.linalg.norm(entity_encodings, axis=1)
|
||||
if len(entity_encodings) != len(prior_probs):
|
||||
raise RuntimeError(
|
||||
Errors.E147.format(
|
||||
method="predict",
|
||||
msg="vectors not of equal length",
|
||||
)
|
||||
# cosine similarity
|
||||
sims = xp.dot(entity_encodings, sentence_encoding_t) / (
|
||||
sentence_norm * entity_norm
|
||||
)
|
||||
if sims.shape != prior_probs.shape:
|
||||
raise ValueError(Errors.E161)
|
||||
scores = prior_probs + sims - (prior_probs * sims)
|
||||
final_kb_ids.append(
|
||||
candidates[scores.argmax().item()].entity_
|
||||
if self.threshold is None
|
||||
or scores.max() >= self.threshold
|
||||
else EntityLinker.NIL
|
||||
)
|
||||
self._add_activations(
|
||||
doc_scores=doc_scores,
|
||||
doc_ents=doc_ents,
|
||||
scores=scores,
|
||||
ents=[c.entity for c in candidates],
|
||||
# cosine similarity
|
||||
sims = xp.dot(entity_encodings, sentence_encoding_t) / (
|
||||
sentence_norm * entity_norm
|
||||
)
|
||||
if sims.shape != prior_probs.shape:
|
||||
raise ValueError(Errors.E161)
|
||||
scores = prior_probs + sims - (prior_probs * sims)
|
||||
final_kb_ids.append(
|
||||
candidates[scores.argmax().item()].entity_
|
||||
if self.threshold is None or scores.max() >= self.threshold
|
||||
else EntityLinker.NIL
|
||||
)
|
||||
self._add_activations(
|
||||
doc_scores=doc_scores,
|
||||
doc_ents=doc_ents,
|
||||
scores=scores,
|
||||
ents=[c.entity for c in candidates],
|
||||
)
|
||||
|
||||
self._add_doc_activations(
|
||||
docs_scores=docs_scores,
|
||||
docs_ents=docs_ents,
|
||||
|
@ -594,6 +574,7 @@ class EntityLinker(TrainablePipe):
|
|||
method="predict", msg="result variables not of equal length"
|
||||
)
|
||||
raise RuntimeError(err)
|
||||
|
||||
return {
|
||||
KNOWLEDGE_BASE_IDS: final_kb_ids,
|
||||
"ents": docs_ents,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, Iterable, Dict, Any, cast
|
||||
from typing import Callable, Iterable, Dict, Any, cast, Iterator
|
||||
|
||||
import pytest
|
||||
from numpy.testing import assert_equal
|
||||
|
@ -15,7 +15,7 @@ from spacy.pipeline import EntityLinker, TrainablePipe
|
|||
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||
from spacy.scorer import Scorer
|
||||
from spacy.tests.util import make_tempdir
|
||||
from spacy.tokens import Span, Doc
|
||||
from spacy.tokens import Doc, Span, SpanGroup
|
||||
from spacy.training import Example
|
||||
from spacy.util import ensure_path
|
||||
from spacy.vocab import Vocab
|
||||
|
@ -462,16 +462,17 @@ def test_candidate_generation(nlp):
|
|||
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
||||
|
||||
# test the size of the relevant candidates
|
||||
assert len(get_candidates(mykb, douglas_ent)) == 2
|
||||
assert len(get_candidates(mykb, adam_ent)) == 1
|
||||
assert len(get_candidates(mykb, Adam_ent)) == 0 # default case sensitive
|
||||
assert len(get_candidates(mykb, shrubbery_ent)) == 0
|
||||
adam_ent_cands = next(get_candidates(mykb, SpanGroup(doc=doc, spans=[adam_ent])))[0]
|
||||
assert len(adam_ent_cands) == 1
|
||||
assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) == 2
|
||||
assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0 # default case sensitive
|
||||
assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0]) == 0
|
||||
|
||||
# test the content of the candidates
|
||||
assert get_candidates(mykb, adam_ent)[0].entity_ == "Q2"
|
||||
assert get_candidates(mykb, adam_ent)[0].alias_ == "adam"
|
||||
assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
|
||||
assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)
|
||||
assert adam_ent_cands[0].entity_ == "Q2"
|
||||
assert adam_ent_cands[0].alias_ == "adam"
|
||||
assert_almost_equal(adam_ent_cands[0].entity_freq, 12)
|
||||
assert_almost_equal(adam_ent_cands[0].prior_prob, 0.9)
|
||||
|
||||
|
||||
def test_el_pipe_configuration(nlp):
|
||||
|
@ -498,24 +499,16 @@ def test_el_pipe_configuration(nlp):
|
|||
assert doc[1].ent_kb_id_ == ""
|
||||
assert doc[2].ent_kb_id_ == "Q2"
|
||||
|
||||
def get_lowercased_candidates(kb, span):
|
||||
return kb.get_alias_candidates(span.text.lower())
|
||||
|
||||
def get_lowercased_candidates_batch(kb, spans):
|
||||
return [get_lowercased_candidates(kb, span) for span in spans]
|
||||
def get_lowercased_candidates(kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]):
|
||||
for mentions_for_doc in mentions:
|
||||
yield [kb.get_alias_candidates(ent_span.text.lower()) for ent_span in mentions_for_doc]
|
||||
|
||||
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
|
||||
def create_candidates() -> Callable[
|
||||
[InMemoryLookupKB, "Span"], Iterable[InMemoryCandidate]
|
||||
[InMemoryLookupKB, Iterator[SpanGroup]], Iterator[Iterable[Iterable[InMemoryCandidate]]]
|
||||
]:
|
||||
return get_lowercased_candidates
|
||||
|
||||
@registry.misc("spacy.LowercaseCandidateBatchGenerator.v1")
|
||||
def create_candidates_batch() -> Callable[
|
||||
[InMemoryLookupKB, Iterable["Span"]], Iterable[Iterable[InMemoryCandidate]]
|
||||
]:
|
||||
return get_lowercased_candidates_batch
|
||||
|
||||
# replace the pipe with a new one with with a different candidate generator
|
||||
entity_linker = nlp.replace_pipe(
|
||||
"entity_linker",
|
||||
|
@ -523,9 +516,6 @@ def test_el_pipe_configuration(nlp):
|
|||
config={
|
||||
"incl_context": False,
|
||||
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
|
||||
"get_candidates_batch": {
|
||||
"@misc": "spacy.LowercaseCandidateBatchGenerator.v1"
|
||||
},
|
||||
},
|
||||
)
|
||||
entity_linker.set_kb(create_kb)
|
||||
|
|
|
@ -10,7 +10,7 @@ import re
|
|||
from pathlib import Path
|
||||
import thinc
|
||||
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
|
||||
from thinc.api import ConfigValidationError, Model, constant as constant_schedule
|
||||
from thinc.api import ConfigValidationError, Model
|
||||
import functools
|
||||
import itertools
|
||||
import numpy
|
||||
|
|
|
@ -62,7 +62,7 @@ architectures and their arguments and hyperparameters.
|
|||
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
|
||||
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
|
||||
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ |
|
||||
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
|
||||
| `get_candidates` | Function that retrieves plausible candidates per entity mention in a given `SpanGroup`. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator). ~~Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]~~ |
|
||||
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `False`. ~~bool~~ |
|
||||
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
|
||||
| `save_activations` <Tag variant="new">4.0</Tag> | Save activations in `Doc` when annotating. Saved activations are `"ents"` and `"scores"`. ~~Union[bool, list[str]]~~ |
|
||||
|
|
|
@ -155,35 +155,13 @@ Get a list of all aliases in the knowledge base.
|
|||
|
||||
## InMemoryLookupKB.get_candidates {id="get_candidates",tag="method"}
|
||||
|
||||
Given a certain textual mention as input, retrieve a list of candidate entities
|
||||
of type [`InMemoryCandidate`](/api/kb#candidate). Wraps
|
||||
[`get_alias_candidates()`](/api/inmemorylookupkb#get_alias_candidates).
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from spacy.lang.en import English
|
||||
> nlp = English()
|
||||
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
|
||||
> candidates = kb.get_candidates(doc[0:2])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ------------------------------------------------------------------------------------ |
|
||||
| `mention` | The textual mention or alias. ~~Span~~ |
|
||||
| **RETURNS** | An iterable of relevant `InMemoryCandidate` objects. ~~Iterable[InMemoryCandidate]~~ |
|
||||
|
||||
## InMemoryLookupKB.get_candidates_batch {id="get_candidates_batch",tag="method"}
|
||||
|
||||
Same as [`get_candidates()`](/api/inmemorylookupkb#get_candidates), but for an
|
||||
arbitrary number of mentions. The [`EntityLinker`](/api/entitylinker) component
|
||||
will call `get_candidates_batch()` instead of `get_candidates()`, if the config
|
||||
parameter `candidates_batch_size` is greater or equal than 1.
|
||||
|
||||
The default implementation of `get_candidates_batch()` executes
|
||||
`get_candidates()` in a loop. We recommend implementing a more efficient way to
|
||||
retrieve candidates for multiple mentions at once, if performance is of concern
|
||||
to you.
|
||||
Given textual mentions for an arbitrary number of documents as input, retrieve a
|
||||
list of candidate entities of type [`InMemoryCandidate`](/api/kb#candidate) for
|
||||
each mention. The [`EntityLinker`](/api/entitylinker) component passes a
|
||||
generator yielding all mentions to retreive candidates for as
|
||||
[`SpanGroup`](/api/spangroup)) per document. The decision of how to batch
|
||||
candidate retrieval lookups over multiple documents is left up to the
|
||||
implementation of `KnowledgeBase.get_candidates()`.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
|
@ -192,13 +170,13 @@ to you.
|
|||
> from spacy.tokens import SpanGroup
|
||||
> nlp = English()
|
||||
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
|
||||
> candidates = kb.get_candidates(SpanGroup(doc, spans=[doc[0:2], doc[3:]])
|
||||
> candidates = kb.get_candidates([SpanGroup(doc, spans=[doc[0:2], doc[3:]]])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ------------------------------------------------------------------------------------------------------------ |
|
||||
| `mentions` | The textual mention or alias. ~~Iterable[Span]~~ |
|
||||
| **RETURNS** | An iterable of iterable with relevant `InMemoryCandidate` objects. ~~Iterable[Iterable[InMemoryCandidate]]~~ |
|
||||
| Name | Description |
|
||||
| ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `mentions` | The textual mention or alias. ~~Iterable[SpanGroup]~~ |
|
||||
| **RETURNS** | An iterator over iterables of iterables with relevant [`InMemoryCandidate`](/api/kb#candidate) objects (per mention and doc). ~~Iterator[Iterable[Iterable[InMemoryCandidate]]]~~ |
|
||||
|
||||
## InMemoryLookupKB.get_alias_candidates {id="get_alias_candidates",tag="method"}
|
||||
|
||||
|
|
|
@ -60,34 +60,13 @@ The length of the fixed-size entity vectors in the knowledge base.
|
|||
|
||||
## KnowledgeBase.get_candidates {id="get_candidates",tag="method"}
|
||||
|
||||
Given a certain textual mention as input, retrieve a list of candidate entities
|
||||
of type [`Candidate`](/api/kb#candidate).
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from spacy.lang.en import English
|
||||
> nlp = English()
|
||||
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
|
||||
> candidates = kb.get_candidates(doc[0:2])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | -------------------------------------------------------------------- |
|
||||
| `mention` | The textual mention or alias. ~~Span~~ |
|
||||
| **RETURNS** | An iterable of relevant `Candidate` objects. ~~Iterable[Candidate]~~ |
|
||||
|
||||
## KnowledgeBase.get_candidates_batch {id="get_candidates_batch",tag="method"}
|
||||
|
||||
Same as [`get_candidates()`](/api/kb#get_candidates), but for an arbitrary
|
||||
number of mentions. The [`EntityLinker`](/api/entitylinker) component will call
|
||||
`get_candidates_batch()` instead of `get_candidates()`, if the config parameter
|
||||
`candidates_batch_size` is greater or equal than 1.
|
||||
|
||||
The default implementation of `get_candidates_batch()` executes
|
||||
`get_candidates()` in a loop. We recommend implementing a more efficient way to
|
||||
retrieve candidates for multiple mentions at once, if performance is of concern
|
||||
to you.
|
||||
Given textual mentions for an arbitrary number of documents as input, retrieve a
|
||||
list of candidate entities of type [`Candidate`](/api/kb#candidate) for each
|
||||
mention. The [`EntityLinker`](/api/entitylinker) component passes a generator
|
||||
yielding all mentions to retreive candidates for as
|
||||
[`SpanGroup`](/api/spangroup)) per document. The decision of how to batch
|
||||
candidate retrieval lookups over multiple documents is left up to the
|
||||
implementation of `KnowledgeBase.get_candidates()`.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
|
@ -96,30 +75,13 @@ to you.
|
|||
> from spacy.tokens import SpanGroup
|
||||
> nlp = English()
|
||||
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
|
||||
> candidates = kb.get_candidates(SpanGroup(doc, spans=[doc[0:2], doc[3:]])
|
||||
> candidates = kb.get_candidates([SpanGroup(doc, spans=[doc[0:2], doc[3:]]])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | -------------------------------------------------------------------------------------------- |
|
||||
| `mentions` | The textual mention or alias. ~~SpanGroup~~ |
|
||||
| **RETURNS** | An iterable of iterable with relevant `Candidate` objects. ~~Iterable[Iterable[Candidate]]~~ |
|
||||
|
||||
## KnowledgeBase.get_alias_candidates {id="get_alias_candidates",tag="method"}
|
||||
|
||||
<Infobox variant="warning">
|
||||
This method is _not_ available from spaCy 3.5 onwards.
|
||||
</Infobox>
|
||||
|
||||
From spaCy 3.5 on `KnowledgeBase` is an abstract class (with
|
||||
[`InMemoryLookupKB`](/api/inmemorylookupkb) being a drop-in replacement) to
|
||||
allow more flexibility in customizing knowledge bases. Some of its methods were
|
||||
moved to [`InMemoryLookupKB`](/api/inmemorylookupkb) during this refactoring,
|
||||
one of those being `get_alias_candidates()`. This method is now available as
|
||||
[`InMemoryLookupKB.get_alias_candidates()`](/api/inmemorylookupkb#get_alias_candidates).
|
||||
Note:
|
||||
[`InMemoryLookupKB.get_candidates()`](/api/inmemorylookupkb#get_candidates)
|
||||
defaults to
|
||||
[`InMemoryLookupKB.get_alias_candidates()`](/api/inmemorylookupkb#get_alias_candidates).
|
||||
| Name | Description |
|
||||
| ----------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `mentions` | The textual mention or alias. ~~Iterable[SpanGroup]~~ |
|
||||
| **RETURNS** | An iterator over iterables of iterables with relevant `Candidate` objects (per mention and doc). ~~Iterator[Iterable[Iterable[Candidate]]]~~ |
|
||||
|
||||
## KnowledgeBase.get_vector {id="get_vector",tag="method"}
|
||||
|
||||
|
@ -193,11 +155,11 @@ Restore the state of the knowledge base from a given directory. Note that the
|
|||
|
||||
## InMemoryCandidate {id="candidate",tag="class"}
|
||||
|
||||
A `InMemoryCandidate` object refers to a textual mention (alias) that may or may
|
||||
not be resolved to a specific entity from a `KnowledgeBase`. This will be used
|
||||
as input for the entity linking algorithm which will disambiguate the various
|
||||
candidates to the correct one. Each candidate `(alias, entity)` pair is assigned
|
||||
to a certain prior probability.
|
||||
A `InMemoryCandidate` object refers to a textual mention that may or may not be
|
||||
resolved to a specific entity from a `KnowledgeBase`. This will be used as input
|
||||
for the entity linking algorithm which will disambiguate the various candidates
|
||||
to the correct one. Each candidate `(mention, entity)` pair is assigned to a
|
||||
certain prior probability.
|
||||
|
||||
### InMemoryCandidate.\_\_init\_\_ {id="candidate-init",tag="method"}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user