Fix use_gold_ents behaviour for EntityLinker (#13400)

* fix type annotation in docs

* only restore entities after loss calculation

* restore entities of sample in initialization

* rename overfitting function

* fix EL scorer

* Relax test

* fix formatting

* Update spacy/pipeline/entity_linker.py

Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com>

* rename to _ensure_ents

* further rename

* allow for scorer to be None

---------

Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com>
Sofie Van Landeghem authored on 2024-04-16 12:00:22 +02:00, committed by GitHub
parent 2e96797696
commit 2e2334632b
3 changed files with 145 additions and 27 deletions
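
For context, `use_gold_ents` controls whether the EntityLinker copies entities from the gold docs during training. A minimal sketch of how the component is configured (assumed setup for illustration, not part of this commit; KB creation and training data omitted):

import spacy

nlp = spacy.blank("en")
# With use_gold_ents=True the component trains on the gold entity spans;
# with False, an annotating component such as NER must set doc.ents first.
entity_linker = nlp.add_pipe(
    "entity_linker", last=True, config={"use_gold_ents": True}
)
# A knowledge base loader must be registered before initialization, e.g.:
# entity_linker.set_kb(create_kb)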

spacy/pipeline/entity_linker.py

@@ -11,7 +11,6 @@ from .. import util
 from ..errors import Errors
 from ..kb import Candidate, KnowledgeBase
 from ..language import Language
-from ..ml import empty_kb
 from ..scorer import Scorer
 from ..tokens import Doc, Span
 from ..training import Example, validate_examples, validate_get_examples
@@ -105,7 +104,7 @@ def make_entity_linker(
     ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
     generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
     scorer (Optional[Callable]): The scoring method.
-    use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
+    use_gold_ents (bool): Whether to copy entities from gold docs during training or not. If false, another
         component must provide entity annotations.
     candidates_batch_size (int): Size of batches for entity candidate generation.
     threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
@@ -235,7 +234,6 @@ class EntityLinker(TrainablePipe):
         self.cfg: Dict[str, Any] = {"overwrite": overwrite}
         self.distance = CosineDistance(normalize=False)
         self.kb = generate_empty_kb(self.vocab, entity_vector_length)
-        self.scorer = scorer
         self.use_gold_ents = use_gold_ents
         self.candidates_batch_size = candidates_batch_size
         self.threshold = threshold
@@ -243,6 +241,37 @@ class EntityLinker(TrainablePipe):
         if candidates_batch_size < 1:
             raise ValueError(Errors.E1044)
 
+        def _score_with_ents_set(examples: Iterable[Example], **kwargs):
+            # Because of how spaCy works, we can't just score immediately, because Language.evaluate
+            # calls pipe() on the predicted docs, which won't have entities if there is no NER in the pipeline.
+            if not scorer:
+                return scorer
+            if not self.use_gold_ents:
+                return scorer(examples, **kwargs)
+            else:
+                examples = self._ensure_ents(examples)
+                docs = self.pipe(
+                    (eg.predicted for eg in examples),
+                )
+                for eg, doc in zip(examples, docs):
+                    eg.predicted = doc
+                return scorer(examples, **kwargs)
+
+        self.scorer = _score_with_ents_set
+
+    def _ensure_ents(self, examples: Iterable[Example]) -> Iterable[Example]:
+        """If use_gold_ents is true, copy the gold entities onto (a copy of) eg.predicted."""
+        if not self.use_gold_ents:
+            return examples
+
+        new_examples = []
+        for eg in examples:
+            ents, _ = eg.get_aligned_ents_and_ner()
+            new_eg = eg.copy()
+            new_eg.predicted.ents = ents
+            new_examples.append(new_eg)
+        return new_examples
+
     def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
         """Define the KB of this pipe by providing a function that will
         create it using this object's vocab."""
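
To illustrate what the new `_ensure_ents` helper relies on, a hedged sketch with toy data (not commit code): `get_aligned_ents_and_ner` yields the gold entity spans aligned to `eg.predicted`, and assigning them to a copy leaves the caller's example untouched.

import spacy
from spacy.training import Example

nlp = spacy.blank("en")
doc = nlp.make_doc("Russ Cochran captured his first major title.")
eg = Example.from_dict(doc, {"entities": [(0, 12, "PERSON")]})

ents, _ = eg.get_aligned_ents_and_ner()  # gold spans aligned to the predicted doc
new_eg = eg.copy()
new_eg.predicted.ents = ents  # only the copy receives the gold entities
assert len(eg.predicted.ents) == 0  # the original example is left unmodified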
@@ -284,11 +313,9 @@ class EntityLinker(TrainablePipe):
         nO = self.kb.entity_vector_length
         doc_sample = []
         vector_sample = []
-        for eg in islice(get_examples(), 10):
+        examples = self._ensure_ents(islice(get_examples(), 10))
+        for eg in examples:
             doc = eg.x
-            if self.use_gold_ents:
-                ents, _ = eg.get_aligned_ents_and_ner()
-                doc.ents = ents
             doc_sample.append(doc)
             vector_sample.append(self.model.ops.alloc1f(nO))
         assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
@@ -354,31 +381,17 @@ class EntityLinker(TrainablePipe):
         losses.setdefault(self.name, 0.0)
         if not examples:
             return losses
+        examples = self._ensure_ents(examples)
         validate_examples(examples, "EntityLinker.update")
-        set_dropout_rate(self.model, drop)
-        docs = [eg.predicted for eg in examples]
-        # save to restore later
-        old_ents = [doc.ents for doc in docs]
-
-        for doc, ex in zip(docs, examples):
-            if self.use_gold_ents:
-                ents, _ = ex.get_aligned_ents_and_ner()
-                doc.ents = ents
-            else:
-                # only keep matching ents
-                doc.ents = ex.get_matching_ents()
 
         # make sure we have something to learn from, if not, short-circuit
         if not self.batch_has_learnable_example(examples):
             return losses
 
+        set_dropout_rate(self.model, drop)
+        docs = [eg.predicted for eg in examples]
         sentence_encodings, bp_context = self.model.begin_update(docs)
 
-        # now restore the ents
-        for doc, old in zip(docs, old_ents):
-            doc.ents = old
-
         loss, d_scores = self.get_loss(
             sentence_encodings=sentence_encodings, examples=examples
         )
@@ -386,11 +399,13 @@ class EntityLinker(TrainablePipe):
         if sgd is not None:
             self.finish_update(sgd)
         losses[self.name] += loss
         return losses
 
     def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
         validate_examples(examples, "EntityLinker.get_loss")
         entity_encodings = []
+        # We assume that get_loss is called with gold ents set in the examples if need be
         eidx = 0  # indices in gold entities to keep
         keep_ents = []  # indices in sentence_encodings to keep
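
Net effect of the changes above, as a hedged summary: `update` now calls `_ensure_ents` up front and trains on copies of the examples, so the old save-and-restore of `doc.ents` around `begin_update` is gone and the caller's docs are never temporarily mutated. A condensed, illustrative training sketch (mirrors the tests in the second changed file; the `create_kb` helper and the single training example are assumptions):

import spacy
from spacy.kb import InMemoryLookupKB
from spacy.training import Example

nlp = spacy.blank("en")

def create_kb(vocab):
    # artificial KB with a single candidate for "Russ Cochran"
    kb = InMemoryLookupKB(vocab, entity_vector_length=3)
    kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
    kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[1.0])
    return kb

entity_linker = nlp.add_pipe("entity_linker", config={"use_gold_ents": True})
entity_linker.set_kb(create_kb)

train_examples = [
    Example.from_dict(
        nlp.make_doc("Russ Cochran captured his first major title."),
        {
            "entities": [(0, 12, "PERSON")],
            "links": {(0, 12): {"Q2146908": 1.0}},
            "sent_starts": [1] + [-1] * 7,
        },
    )
]
optimizer = nlp.initialize(get_examples=lambda: train_examples)
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
# The gold entities were copied onto copies of the docs, so the caller's
# examples come back with their original (empty) predicted entities.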

spacy/tests/pipeline/test_entity_linker.py

@@ -717,7 +717,7 @@ GOLD_entities = ["Q2146908", "Q7381115", "Q7381115", "Q2146908"]
 # fmt: on
 
-def test_overfitting_IO():
+def test_overfitting_IO_gold_entities():
     # Simple test to try and quickly overfit the NEL component - ensuring the ML models work correctly
     nlp = English()
     vector_length = 3
@@ -744,7 +744,9 @@ def test_overfitting_IO():
         return mykb
 
     # Create the Entity Linker component and add it to the pipeline
-    entity_linker = nlp.add_pipe("entity_linker", last=True)
+    entity_linker = nlp.add_pipe(
+        "entity_linker", last=True, config={"use_gold_ents": True}
+    )
     assert isinstance(entity_linker, EntityLinker)
     entity_linker.set_kb(create_kb)
     assert "Q2146908" in entity_linker.vocab.strings
@@ -807,6 +809,107 @@ def test_overfitting_IO():
     assert_equal(batch_deps_1, batch_deps_2)
     assert_equal(batch_deps_1, no_batch_deps)
 
+    eval = nlp.evaluate(train_examples)
+    assert "nel_macro_p" in eval
+    assert "nel_macro_r" in eval
+    assert "nel_macro_f" in eval
+    assert "nel_micro_p" in eval
+    assert "nel_micro_r" in eval
+    assert "nel_micro_f" in eval
+    assert "nel_f_per_type" in eval
+    assert "PERSON" in eval["nel_f_per_type"]
+
+    assert eval["nel_macro_f"] > 0
+    assert eval["nel_micro_f"] > 0
+
+
+def test_overfitting_IO_with_ner():
+    # Simple test to try and overfit the NER and NEL component in combination - ensuring the ML models work correctly
+    nlp = English()
+    vector_length = 3
+    assert "Q2146908" not in nlp.vocab.strings
+
+    # Convert the texts to docs to make sure we have doc.ents set for the training examples
+    train_examples = []
+    for text, annotation in TRAIN_DATA:
+        doc = nlp(text)
+        train_examples.append(Example.from_dict(doc, annotation))
+
+    def create_kb(vocab):
+        # create artificial KB - assign same prior weight to the two russ cochran's
+        # Q2146908 (Russ Cochran): American golfer
+        # Q7381115 (Russ Cochran): publisher
+        mykb = InMemoryLookupKB(vocab, entity_vector_length=vector_length)
+        mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+        mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
+        mykb.add_alias(
+            alias="Russ Cochran",
+            entities=["Q2146908", "Q7381115"],
+            probabilities=[0.5, 0.5],
+        )
+        return mykb
+
+    # Create the NER and EL components and add them to the pipeline
+    ner = nlp.add_pipe("ner", first=True)
+    entity_linker = nlp.add_pipe(
+        "entity_linker", last=True, config={"use_gold_ents": False}
+    )
+    entity_linker.set_kb(create_kb)
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
+        for ent in annotations.get("entities"):
+            ner.add_label(ent[2])
+    optimizer = nlp.initialize()
+
+    # train the NER and NEL pipes
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+    assert losses["ner"] < 0.001
+    assert losses["entity_linker"] < 0.001
+
+    # adding additional components that are required for the entity_linker
+    nlp.add_pipe("sentencizer", first=True)
+
+    # test the trained model
+    test_text = "Russ Cochran captured his first major title with his son as caddie."
+    doc = nlp(test_text)
+    ents = doc.ents
+    assert len(ents) == 1
+    assert ents[0].text == "Russ Cochran"
+    assert ents[0].label_ == "PERSON"
+    assert ents[0].kb_id_ != "NIL"
+    # TODO: below assert is still flaky - EL doesn't properly overfit quite yet
+    # assert ents[0].kb_id_ == "Q2146908"
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        assert nlp2.pipe_names == nlp.pipe_names
+        doc2 = nlp2(test_text)
+        ents2 = doc2.ents
+        assert len(ents2) == 1
+        assert ents2[0].text == "Russ Cochran"
+        assert ents2[0].label_ == "PERSON"
+        assert ents2[0].kb_id_ != "NIL"
+
+    eval = nlp.evaluate(train_examples)
+    assert "nel_macro_f" in eval
+    assert "nel_micro_f" in eval
+    assert "ents_f" in eval
+    assert "nel_f_per_type" in eval
+    assert "ents_per_type" in eval
+    assert "PERSON" in eval["nel_f_per_type"]
+    assert "PERSON" in eval["ents_per_type"]
+    assert eval["nel_macro_f"] > 0
+    assert eval["nel_micro_f"] > 0
+    assert eval["ents_f"] > 0
+
+
 def test_kb_serialization():
     # Test that the KB can be used in a pipeline with a different vocab
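
A hedged sketch of what the new assertions exercise, continuing the illustrative training sketch shown after the first file's diff (same `nlp` and `train_examples`; the sentencizer line mirrors the test above):

# The EntityLinker needs sentence boundaries at inference time:
nlp.add_pipe("sentencizer", first=True)

# With the scorer fix, NEL metrics are reported even when no NER component is
# present, because the scorer sets the gold entities before piping the docs:
scores = nlp.evaluate(train_examples)
assert "nel_micro_f" in scores and "nel_macro_f" in scores
assert "PERSON" in scores["nel_f_per_type"]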

website/docs/api/entitylinker.mdx

@@ -61,7 +61,7 @@ architectures and their arguments and hyperparameters.
 | `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
 | `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
 | `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
-| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ |
+| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~ |
 | `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
 | `get_candidates_batch` <Tag variant="new">3.5</Tag> | Function that generates plausible candidates for a given batch of `Span` objects. Defaults to [CandidateBatchGenerator](/api/architectures#CandidateBatchGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]~~ |
 | `generate_empty_kb` <Tag variant="new">3.5.1</Tag> | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~ |
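
To illustrate the corrected `use_gold_ents` type (it is a boolean flag), a hedged sketch of overriding it when entities should come from an NER component rather than the gold docs:

import spacy

nlp = spacy.blank("en")
nlp.add_pipe("ner", first=True)
# use_gold_ents=False: entities must be set in the training data or by an
# annotating component in the pipeline, such as the NER added above.
nlp.add_pipe("entity_linker", last=True, config={"use_gold_ents": False})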