From 4a921766f128755fb733e899e6701599008184a5 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 13 Mar 2023 16:54:38 +0100 Subject: [PATCH] Remove prior_prob from supported properties in Candidate. Introduce KnowledgeBase.supports_prior_probs(). --- spacy/errors.py | 3 +++ spacy/kb/candidate.pyx | 7 ++----- spacy/kb/kb.pyx | 7 +++++++ spacy/kb/kb_in_memory.pyx | 3 +++ spacy/pipeline/entity_linker.py | 14 ++++++++------ 5 files changed, 23 insertions(+), 11 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 30446e7ea..0f8091e3a 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -209,7 +209,10 @@ class Warnings(metaclass=ErrorsWithCodes): "`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.") W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.") + # v4 warning strings W400 = ("`use_upper=False` is ignored, the upper layer is always enabled") + W401 = ("`incl_prior is True`, but the selected knowledge base type {kb_type} doesn't support prior probability " + "lookups.") class Errors(metaclass=ErrorsWithCodes): diff --git a/spacy/kb/candidate.pyx b/spacy/kb/candidate.pyx index ac19df671..9e4e9f321 100644 --- a/spacy/kb/candidate.pyx +++ b/spacy/kb/candidate.pyx @@ -41,11 +41,6 @@ cdef class Candidate: """RETURNS (vector[float]): Entity vector.""" raise NotImplementedError - @property - def prior_prob(self) -> float: - """RETURNS (List[float]): Entity vector.""" - raise NotImplementedError - cdef class InMemoryCandidate(Candidate): """Candidate for InMemoryLookupKB.""" @@ -89,6 +84,7 @@ cdef class InMemoryCandidate(Candidate): @property def prior_prob(self) -> float: + """RETURNS (float): Prior probability that this mention resolves to this entity.""" return self._prior_prob @property @@ -101,4 +97,5 @@ cdef class InMemoryCandidate(Candidate): @property def entity_freq(self) -> float: + """RETURNS (float): Entity frequency in KB corpus.""" return self._entity_freq diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index 7da312863..d10123e37 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -106,3 +106,10 @@ cdef class KnowledgeBase: raise NotImplementedError( Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) ) + + @property + def supports_prior_probs(self) -> bool: + """RETURNS (bool): Whether this KB type supports looking up prior probabilities for entity mentions.""" + raise NotImplementedError( + Errors.E1045.format(parent="KnowledgeBase", method="supports_prior_probs", name=self.__name__) + ) diff --git a/spacy/kb/kb_in_memory.pyx b/spacy/kb/kb_in_memory.pyx index 4ceb87888..e3b9dfcb3 100644 --- a/spacy/kb/kb_in_memory.pyx +++ b/spacy/kb/kb_in_memory.pyx @@ -283,6 +283,9 @@ cdef class InMemoryLookupKB(KnowledgeBase): return 0.0 + def supports_prior_probs(self) -> bool: + return True + def to_bytes(self, **kwargs): """Serialize the current state to a binary string. """ diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 39cff218a..caced9cfd 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -1,5 +1,5 @@ -from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any -from typing import cast +import warnings +from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any, cast from numpy import dtype from thinc.types import Floats1d, Floats2d, Ints1d, Ragged from pathlib import Path @@ -10,14 +10,13 @@ from thinc.api import CosineDistance, Model, Optimizer, Config from thinc.api import set_dropout_rate from ..kb import KnowledgeBase, Candidate -from ..ml import empty_kb from ..tokens import Doc, Span from .pipe import deserialize_config from .trainable_pipe import TrainablePipe from ..language import Language from ..vocab import Vocab from ..training import Example, validate_examples, validate_get_examples -from ..errors import Errors +from ..errors import Errors, Warnings from ..util import SimpleFrozenList, registry from .. import util from ..scorer import Scorer @@ -240,6 +239,8 @@ class EntityLinker(TrainablePipe): if candidates_batch_size < 1: raise ValueError(Errors.E1044) + if self.incl_prior and not self.kb.supports_prior_probs: + warnings.warn(Warnings.W401) def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): """Define the KB of this pipe by providing a function that will @@ -532,8 +533,9 @@ class EntityLinker(TrainablePipe): else: random.shuffle(candidates) # set all prior probabilities to 0 if incl_prior=False - prior_probs = xp.asarray([c.prior_prob for c in candidates]) - if not self.incl_prior: + if self.incl_prior and self.kb.supports_prior_probs: + prior_probs = xp.asarray([c.prior_prob for c in candidates]) # type: ignore + else: prior_probs = xp.asarray([0.0 for _ in candidates]) scores = prior_probs # add in similarity from the context