mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
NEL confidence threshold (#11016)
* Add base for NEL abstention threshold mechanism. * Add abstention threshold to entity linker. Add test. * Fix entity linking tests. * Changed abstention default threshold from 0 to None. * Fix default values for abstention thresholds. * Fix mypy errors. * Replace assertion with raise of proper error code. * Simplify threshold check. Remove thresholding from EntityLinker_v1. * Rename test. * Update spacy/pipeline/entity_linker.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Update spacy/pipeline/entity_linker.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Make E1043 configurable. * Update docs. * Rephrase description in docs. Adjusting error code message. Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
59c763eec1
commit
e9eb59699f
|
@ -937,6 +937,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
|
||||
E1042 = ("Function was called with `{arg1}`={arg1_values} and "
|
||||
"`{arg2}`={arg2_values} but these arguments are conflicting.")
|
||||
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
|
||||
"{value}.")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -56,6 +56,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
|||
"overwrite": True,
|
||||
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
|
||||
"use_gold_ents": True,
|
||||
"threshold": None,
|
||||
},
|
||||
default_score_weights={
|
||||
"nel_micro_f": 1.0,
|
||||
|
@ -77,6 +78,7 @@ def make_entity_linker(
|
|||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
use_gold_ents: bool,
|
||||
threshold: Optional[float] = None,
|
||||
):
|
||||
"""Construct an EntityLinker component.
|
||||
|
||||
|
@ -91,6 +93,10 @@ def make_entity_linker(
|
|||
get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that
|
||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||
component must provide entity annotations.
|
||||
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
|
||||
prediction is discarded. If None, predictions are not filtered by any threshold.
|
||||
"""
|
||||
|
||||
if not model.attrs.get("include_span_maker", False):
|
||||
|
@ -121,6 +127,7 @@ def make_entity_linker(
|
|||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
use_gold_ents=use_gold_ents,
|
||||
threshold=threshold,
|
||||
)
|
||||
|
||||
|
||||
|
@ -156,6 +163,7 @@ class EntityLinker(TrainablePipe):
|
|||
overwrite: bool = BACKWARD_OVERWRITE,
|
||||
scorer: Optional[Callable] = entity_linker_score,
|
||||
use_gold_ents: bool,
|
||||
threshold: Optional[float] = None,
|
||||
) -> None:
|
||||
"""Initialize an entity linker.
|
||||
|
||||
|
@ -174,9 +182,20 @@ class EntityLinker(TrainablePipe):
|
|||
Scorer.score_links.
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||
component must provide entity annotations.
|
||||
|
||||
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the
|
||||
threshold, prediction is discarded. If None, predictions are not filtered by any threshold.
|
||||
DOCS: https://spacy.io/api/entitylinker#init
|
||||
"""
|
||||
|
||||
if threshold is not None and not (0 <= threshold <= 1):
|
||||
raise ValueError(
|
||||
Errors.E1043.format(
|
||||
range_start=0,
|
||||
range_end=1,
|
||||
value=threshold,
|
||||
)
|
||||
)
|
||||
|
||||
self.vocab = vocab
|
||||
self.model = model
|
||||
self.name = name
|
||||
|
@ -192,6 +211,7 @@ class EntityLinker(TrainablePipe):
|
|||
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
||||
self.scorer = scorer
|
||||
self.use_gold_ents = use_gold_ents
|
||||
self.threshold = threshold
|
||||
|
||||
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
|
||||
"""Define the KB of this pipe by providing a function that will
|
||||
|
@ -424,9 +444,8 @@ class EntityLinker(TrainablePipe):
|
|||
if not candidates:
|
||||
# no prediction possible for this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
elif len(candidates) == 1:
|
||||
elif len(candidates) == 1 and self.threshold is None:
|
||||
# shortcut for efficiency reasons: take the 1 candidate
|
||||
# TODO: thresholding
|
||||
final_kb_ids.append(candidates[0].entity_)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
|
@ -455,10 +474,11 @@ class EntityLinker(TrainablePipe):
|
|||
if sims.shape != prior_probs.shape:
|
||||
raise ValueError(Errors.E161)
|
||||
scores = prior_probs + sims - (prior_probs * sims)
|
||||
# TODO: thresholding
|
||||
best_index = scores.argmax().item()
|
||||
best_candidate = candidates[best_index]
|
||||
final_kb_ids.append(best_candidate.entity_)
|
||||
final_kb_ids.append(
|
||||
candidates[scores.argmax().item()].entity_
|
||||
if self.threshold is None or scores.max() >= self.threshold
|
||||
else EntityLinker.NIL
|
||||
)
|
||||
if not (len(final_kb_ids) == entity_count):
|
||||
err = Errors.E147.format(
|
||||
method="predict", msg="result variables not of equal length"
|
||||
|
|
|
@ -7,7 +7,7 @@ from pathlib import Path
|
|||
from itertools import islice
|
||||
import srsly
|
||||
import random
|
||||
from thinc.api import CosineDistance, Model, Optimizer, Config
|
||||
from thinc.api import CosineDistance, Model, Optimizer
|
||||
from thinc.api import set_dropout_rate
|
||||
import warnings
|
||||
|
||||
|
@ -20,7 +20,7 @@ from ...language import Language
|
|||
from ...vocab import Vocab
|
||||
from ...training import Example, validate_examples, validate_get_examples
|
||||
from ...errors import Errors, Warnings
|
||||
from ...util import SimpleFrozenList, registry
|
||||
from ...util import SimpleFrozenList
|
||||
from ... import util
|
||||
from ...scorer import Scorer
|
||||
|
||||
|
@ -70,7 +70,6 @@ class EntityLinker_v1(TrainablePipe):
|
|||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||
Scorer.score_links.
|
||||
|
||||
DOCS: https://spacy.io/api/entitylinker#init
|
||||
"""
|
||||
self.vocab = vocab
|
||||
|
@ -272,7 +271,6 @@ class EntityLinker_v1(TrainablePipe):
|
|||
final_kb_ids.append(self.NIL)
|
||||
elif len(candidates) == 1:
|
||||
# shortcut for efficiency reasons: take the 1 candidate
|
||||
# TODO: thresholding
|
||||
final_kb_ids.append(candidates[0].entity_)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
|
@ -301,7 +299,6 @@ class EntityLinker_v1(TrainablePipe):
|
|||
if sims.shape != prior_probs.shape:
|
||||
raise ValueError(Errors.E161)
|
||||
scores = prior_probs + sims - (prior_probs * sims)
|
||||
# TODO: thresholding
|
||||
best_index = scores.argmax().item()
|
||||
best_candidate = candidates[best_index]
|
||||
final_kb_ids.append(best_candidate.entity_)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, Iterable
|
||||
from typing import Callable, Iterable, Dict, Any
|
||||
|
||||
import pytest
|
||||
from numpy.testing import assert_equal
|
||||
|
@ -207,7 +207,7 @@ def test_no_entities():
|
|||
nlp.add_pipe("sentencizer", first=True)
|
||||
|
||||
# this will run the pipeline on the examples and shouldn't crash
|
||||
results = nlp.evaluate(train_examples)
|
||||
nlp.evaluate(train_examples)
|
||||
|
||||
|
||||
def test_partial_links():
|
||||
|
@ -1063,7 +1063,7 @@ def test_no_gold_ents(patterns):
|
|||
"entity_linker", config={"use_gold_ents": False}, last=True
|
||||
)
|
||||
entity_linker.set_kb(create_kb)
|
||||
assert entity_linker.use_gold_ents == False
|
||||
assert entity_linker.use_gold_ents is False
|
||||
|
||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||
for i in range(2):
|
||||
|
@ -1074,7 +1074,7 @@ def test_no_gold_ents(patterns):
|
|||
nlp.add_pipe("sentencizer", first=True)
|
||||
|
||||
# this will run the pipeline on the examples and shouldn't crash
|
||||
results = nlp.evaluate(train_examples)
|
||||
nlp.evaluate(train_examples)
|
||||
|
||||
|
||||
@pytest.mark.issue(9575)
|
||||
|
@ -1114,4 +1114,61 @@ def test_tokenization_mismatch():
|
|||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||
|
||||
nlp.add_pipe("sentencizer", first=True)
|
||||
results = nlp.evaluate(train_examples)
|
||||
nlp.evaluate(train_examples)
|
||||
|
||||
|
||||
# fmt: off
|
||||
@pytest.mark.parametrize(
|
||||
"meet_threshold,config",
|
||||
[
|
||||
(False, {"@architectures": "spacy.EntityLinker.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL}),
|
||||
(True, {"@architectures": "spacy.EntityLinker.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL}),
|
||||
],
|
||||
)
|
||||
# fmt: on
|
||||
def test_threshold(meet_threshold: bool, config: Dict[str, Any]):
|
||||
"""Tests abstention threshold.
|
||||
meet_threshold (bool): Whether to configure NEL setup so that confidence threshold is met.
|
||||
config (Dict[str, Any]): NEL architecture config.
|
||||
"""
|
||||
nlp = English()
|
||||
nlp.add_pipe("sentencizer")
|
||||
text = "Mahler's Symphony No. 8 was beautiful."
|
||||
entities = [(0, 6, "PERSON")]
|
||||
links = {(0, 6): {"Q7304": 1.0}}
|
||||
sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
|
||||
entity_id = "Q7304"
|
||||
doc = nlp(text)
|
||||
train_examples = [
|
||||
Example.from_dict(
|
||||
doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
|
||||
)
|
||||
]
|
||||
|
||||
def create_kb(vocab):
|
||||
# create artificial KB
|
||||
mykb = KnowledgeBase(vocab, entity_vector_length=3)
|
||||
mykb.add_entity(entity=entity_id, freq=12, entity_vector=[6, -4, 3])
|
||||
mykb.add_alias(
|
||||
alias="Mahler",
|
||||
entities=[entity_id],
|
||||
probabilities=[1 if meet_threshold else 0.01],
|
||||
)
|
||||
return mykb
|
||||
|
||||
# Create the Entity Linker component and add it to the pipeline
|
||||
entity_linker = nlp.add_pipe(
|
||||
"entity_linker",
|
||||
last=True,
|
||||
config={"threshold": 0.99, "model": config},
|
||||
)
|
||||
entity_linker.set_kb(create_kb) # type: ignore
|
||||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
|
||||
# Add a custom rule-based component to mimick NER
|
||||
ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
|
||||
ruler.add_patterns([{"label": "PERSON", "pattern": [{"LOWER": "mahler"}]}]) # type: ignore
|
||||
doc = nlp(text)
|
||||
|
||||
assert len(doc.ents) == 1
|
||||
assert doc.ents[0].kb_id_ == entity_id if meet_threshold else EntityLinker.NIL
|
||||
|
|
|
@ -47,12 +47,13 @@ architectures and their arguments and hyperparameters.
|
|||
> "model": DEFAULT_NEL_MODEL,
|
||||
> "entity_vector_length": 64,
|
||||
> "get_candidates": {'@misc': 'spacy.CandidateGenerator.v1'},
|
||||
> "threshold": None,
|
||||
> }
|
||||
> nlp.add_pipe("entity_linker", config=config)
|
||||
> ```
|
||||
|
||||
| Setting | Description |
|
||||
| ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
|
||||
| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ |
|
||||
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
|
||||
|
@ -63,6 +64,7 @@ architectures and their arguments and hyperparameters.
|
|||
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
|
||||
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
|
||||
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
|
||||
| `threshold` <Tag variant="new">3.4</Tag> | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the treshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ |
|
||||
|
||||
```python
|
||||
%%GITHUB_SPACY/spacy/pipeline/entity_linker.py
|
||||
|
@ -96,7 +98,7 @@ custom knowledge base, you should either call
|
|||
[`initialize`](/api/entitylinker#initialize) call.
|
||||
|
||||
| Name | Description |
|
||||
| ---------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `vocab` | The shared vocabulary. ~~Vocab~~ |
|
||||
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model~~ |
|
||||
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
||||
|
@ -109,6 +111,7 @@ custom knowledge base, you should either call
|
|||
| `incl_context` | Whether or not to include the local context in the model. ~~bool~~ |
|
||||
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
|
||||
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
|
||||
| `threshold` <Tag variant="new">3.4</Tag> | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the treshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ |
|
||||
|
||||
## EntityLinker.\_\_call\_\_ {#call tag="method"}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user