Remove prior_prob from supported properties in Candidate. Introduce KnowledgeBase.supports_prior_probs().

This commit is contained in:
Raphael Mitsch 2023-03-13 16:54:38 +01:00
parent 6adc15178f
commit 4a921766f1
5 changed files with 23 additions and 11 deletions

View File

@ -209,7 +209,10 @@ class Warnings(metaclass=ErrorsWithCodes):
"`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.")
W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.")
# v4 warning strings
W400 = ("`use_upper=False` is ignored, the upper layer is always enabled")
W401 = ("`incl_prior is True`, but the selected knowledge base type {kb_type} doesn't support prior probability "
"lookups.")
class Errors(metaclass=ErrorsWithCodes):

View File

@ -41,11 +41,6 @@ cdef class Candidate:
"""RETURNS (vector[float]): Entity vector."""
raise NotImplementedError
@property
def prior_prob(self) -> float:
    """RETURNS (float): Prior probability that this mention resolves to this entity."""
    # Not implemented on this base type; the visible InMemoryCandidate override supplies the value.
    raise NotImplementedError
cdef class InMemoryCandidate(Candidate):
"""Candidate for InMemoryLookupKB."""
@ -89,6 +84,7 @@ cdef class InMemoryCandidate(Candidate):
@property
def prior_prob(self) -> float:
    """Prior probability of this candidate.

    RETURNS (float): Prior probability that this mention resolves to this entity.
    """
    # Read-only view over the value stored on the instance.
    stored_prior = self._prior_prob
    return stored_prior
@property
@ -101,4 +97,5 @@ cdef class InMemoryCandidate(Candidate):
@property
def entity_freq(self) -> float:
    """Frequency of this candidate's entity.

    RETURNS (float): Entity frequency in KB corpus.
    """
    # Read-only view over the value stored on the instance.
    stored_freq = self._entity_freq
    return stored_freq

View File

@ -106,3 +106,10 @@ cdef class KnowledgeBase:
raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__)
)
@property
def supports_prior_probs(self) -> bool:
    """RETURNS (bool): Whether this KB type supports looking up prior probabilities for entity mentions."""
    # The base class has no answer of its own, so it always raises.
    # The class name is taken from type(self): instances do not expose `__name__`,
    # so `self.__name__` would raise AttributeError before this error could be built.
    raise NotImplementedError(
        Errors.E1045.format(
            parent="KnowledgeBase",
            method="supports_prior_probs",
            name=type(self).__name__,
        )
    )

View File

@ -283,6 +283,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
return 0.0
@property
def supports_prior_probs(self) -> bool:
    """RETURNS (bool): Whether this KB type supports looking up prior probabilities for entity mentions."""
    # Declared as a property, like the KnowledgeBase declaration it overrides.
    # Callers read it as an attribute (`kb.supports_prior_probs`); a plain method
    # would give them a bound method object, which is always truthy.
    return True
def to_bytes(self, **kwargs):
"""Serialize the current state to a binary string.
"""

View File

@ -1,5 +1,5 @@
from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any
from typing import cast
import warnings
from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any, cast
from numpy import dtype
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
from pathlib import Path
@ -10,14 +10,13 @@ from thinc.api import CosineDistance, Model, Optimizer, Config
from thinc.api import set_dropout_rate
from ..kb import KnowledgeBase, Candidate
from ..ml import empty_kb
from ..tokens import Doc, Span
from .pipe import deserialize_config
from .trainable_pipe import TrainablePipe
from ..language import Language
from ..vocab import Vocab
from ..training import Example, validate_examples, validate_get_examples
from ..errors import Errors
from ..errors import Errors, Warnings
from ..util import SimpleFrozenList, registry
from .. import util
from ..scorer import Scorer
@ -240,6 +239,8 @@ class EntityLinker(TrainablePipe):
if candidates_batch_size < 1:
raise ValueError(Errors.E1044)
if self.incl_prior and not self.kb.supports_prior_probs:
warnings.warn(Warnings.W401)
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
"""Define the KB of this pipe by providing a function that will
@ -532,8 +533,9 @@ class EntityLinker(TrainablePipe):
else:
random.shuffle(candidates)
# set all prior probabilities to 0 if incl_prior=False or the KB doesn't support prior probabilities
prior_probs = xp.asarray([c.prior_prob for c in candidates])
if not self.incl_prior:
if self.incl_prior and self.kb.supports_prior_probs:
prior_probs = xp.asarray([c.prior_prob for c in candidates]) # type: ignore
else:
prior_probs = xp.asarray([0.0 for _ in candidates])
scores = prior_probs
# add in similarity from the context