mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Fix use_gold_ents behaviour for EntityLinker (#13400)
* fix type annotation in docs * only restore entities after loss calculation * restore entities of sample in initialization * rename overfitting function * fix EL scorer * Relax test * fix formatting * Update spacy/pipeline/entity_linker.py Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com> * rename to _ensure_ents * further rename * allow for scorer to be None --------- Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com>
This commit is contained in:
parent
2e96797696
commit
2e2334632b
|
@ -11,7 +11,6 @@ from .. import util
|
|||
from ..errors import Errors
|
||||
from ..kb import Candidate, KnowledgeBase
|
||||
from ..language import Language
|
||||
from ..ml import empty_kb
|
||||
from ..scorer import Scorer
|
||||
from ..tokens import Doc, Span
|
||||
from ..training import Example, validate_examples, validate_get_examples
|
||||
|
@ -105,7 +104,7 @@ def make_entity_linker(
|
|||
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
||||
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs during training or not. If false, another
|
||||
component must provide entity annotations.
|
||||
candidates_batch_size (int): Size of batches for entity candidate generation.
|
||||
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
|
||||
|
@ -235,7 +234,6 @@ class EntityLinker(TrainablePipe):
|
|||
self.cfg: Dict[str, Any] = {"overwrite": overwrite}
|
||||
self.distance = CosineDistance(normalize=False)
|
||||
self.kb = generate_empty_kb(self.vocab, entity_vector_length)
|
||||
self.scorer = scorer
|
||||
self.use_gold_ents = use_gold_ents
|
||||
self.candidates_batch_size = candidates_batch_size
|
||||
self.threshold = threshold
|
||||
|
@ -243,6 +241,37 @@ class EntityLinker(TrainablePipe):
|
|||
if candidates_batch_size < 1:
|
||||
raise ValueError(Errors.E1044)
|
||||
|
||||
def _score_with_ents_set(examples: Iterable[Example], **kwargs):
|
||||
# Because of how spaCy works, we can't just score immediately, because Language.evaluate
|
||||
# calls pipe() on the predicted docs, which won't have entities if there is no NER in the pipeline.
|
||||
if not scorer:
|
||||
return scorer
|
||||
if not self.use_gold_ents:
|
||||
return scorer(examples, **kwargs)
|
||||
else:
|
||||
examples = self._ensure_ents(examples)
|
||||
docs = self.pipe(
|
||||
(eg.predicted for eg in examples),
|
||||
)
|
||||
for eg, doc in zip(examples, docs):
|
||||
eg.predicted = doc
|
||||
return scorer(examples, **kwargs)
|
||||
|
||||
self.scorer = _score_with_ents_set
|
||||
|
||||
def _ensure_ents(self, examples: Iterable[Example]) -> Iterable[Example]:
|
||||
"""If use_gold_ents is true, set the gold entities to (a copy of) eg.predicted."""
|
||||
if not self.use_gold_ents:
|
||||
return examples
|
||||
|
||||
new_examples = []
|
||||
for eg in examples:
|
||||
ents, _ = eg.get_aligned_ents_and_ner()
|
||||
new_eg = eg.copy()
|
||||
new_eg.predicted.ents = ents
|
||||
new_examples.append(new_eg)
|
||||
return new_examples
|
||||
|
||||
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
|
||||
"""Define the KB of this pipe by providing a function that will
|
||||
create it using this object's vocab."""
|
||||
|
@ -284,11 +313,9 @@ class EntityLinker(TrainablePipe):
|
|||
nO = self.kb.entity_vector_length
|
||||
doc_sample = []
|
||||
vector_sample = []
|
||||
for eg in islice(get_examples(), 10):
|
||||
examples = self._ensure_ents(islice(get_examples(), 10))
|
||||
for eg in examples:
|
||||
doc = eg.x
|
||||
if self.use_gold_ents:
|
||||
ents, _ = eg.get_aligned_ents_and_ner()
|
||||
doc.ents = ents
|
||||
doc_sample.append(doc)
|
||||
vector_sample.append(self.model.ops.alloc1f(nO))
|
||||
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||
|
@ -354,31 +381,17 @@ class EntityLinker(TrainablePipe):
|
|||
losses.setdefault(self.name, 0.0)
|
||||
if not examples:
|
||||
return losses
|
||||
examples = self._ensure_ents(examples)
|
||||
validate_examples(examples, "EntityLinker.update")
|
||||
|
||||
set_dropout_rate(self.model, drop)
|
||||
docs = [eg.predicted for eg in examples]
|
||||
# save to restore later
|
||||
old_ents = [doc.ents for doc in docs]
|
||||
|
||||
for doc, ex in zip(docs, examples):
|
||||
if self.use_gold_ents:
|
||||
ents, _ = ex.get_aligned_ents_and_ner()
|
||||
doc.ents = ents
|
||||
else:
|
||||
# only keep matching ents
|
||||
doc.ents = ex.get_matching_ents()
|
||||
|
||||
# make sure we have something to learn from, if not, short-circuit
|
||||
if not self.batch_has_learnable_example(examples):
|
||||
return losses
|
||||
|
||||
set_dropout_rate(self.model, drop)
|
||||
docs = [eg.predicted for eg in examples]
|
||||
sentence_encodings, bp_context = self.model.begin_update(docs)
|
||||
|
||||
# now restore the ents
|
||||
for doc, old in zip(docs, old_ents):
|
||||
doc.ents = old
|
||||
|
||||
loss, d_scores = self.get_loss(
|
||||
sentence_encodings=sentence_encodings, examples=examples
|
||||
)
|
||||
|
@ -386,11 +399,13 @@ class EntityLinker(TrainablePipe):
|
|||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
losses[self.name] += loss
|
||||
|
||||
return losses
|
||||
|
||||
def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
|
||||
validate_examples(examples, "EntityLinker.get_loss")
|
||||
entity_encodings = []
|
||||
# We assume that get_loss is called with gold ents set in the examples if need be
|
||||
eidx = 0 # indices in gold entities to keep
|
||||
keep_ents = [] # indices in sentence_encodings to keep
|
||||
|
||||
|
|
|
@ -717,7 +717,7 @@ GOLD_entities = ["Q2146908", "Q7381115", "Q7381115", "Q2146908"]
|
|||
# fmt: on
|
||||
|
||||
|
||||
def test_overfitting_IO():
|
||||
def test_overfitting_IO_gold_entities():
|
||||
# Simple test to try and quickly overfit the NEL component - ensuring the ML models work correctly
|
||||
nlp = English()
|
||||
vector_length = 3
|
||||
|
@ -744,7 +744,9 @@ def test_overfitting_IO():
|
|||
return mykb
|
||||
|
||||
# Create the Entity Linker component and add it to the pipeline
|
||||
entity_linker = nlp.add_pipe("entity_linker", last=True)
|
||||
entity_linker = nlp.add_pipe(
|
||||
"entity_linker", last=True, config={"use_gold_ents": True}
|
||||
)
|
||||
assert isinstance(entity_linker, EntityLinker)
|
||||
entity_linker.set_kb(create_kb)
|
||||
assert "Q2146908" in entity_linker.vocab.strings
|
||||
|
@ -807,6 +809,107 @@ def test_overfitting_IO():
|
|||
assert_equal(batch_deps_1, batch_deps_2)
|
||||
assert_equal(batch_deps_1, no_batch_deps)
|
||||
|
||||
eval = nlp.evaluate(train_examples)
|
||||
assert "nel_macro_p" in eval
|
||||
assert "nel_macro_r" in eval
|
||||
assert "nel_macro_f" in eval
|
||||
assert "nel_micro_p" in eval
|
||||
assert "nel_micro_r" in eval
|
||||
assert "nel_micro_f" in eval
|
||||
assert "nel_f_per_type" in eval
|
||||
assert "PERSON" in eval["nel_f_per_type"]
|
||||
|
||||
assert eval["nel_macro_f"] > 0
|
||||
assert eval["nel_micro_f"] > 0
|
||||
|
||||
|
||||
def test_overfitting_IO_with_ner():
|
||||
# Simple test to try and overfit the NER and NEL component in combination - ensuring the ML models work correctly
|
||||
nlp = English()
|
||||
vector_length = 3
|
||||
assert "Q2146908" not in nlp.vocab.strings
|
||||
|
||||
# Convert the texts to docs to make sure we have doc.ents set for the training examples
|
||||
train_examples = []
|
||||
for text, annotation in TRAIN_DATA:
|
||||
doc = nlp(text)
|
||||
train_examples.append(Example.from_dict(doc, annotation))
|
||||
|
||||
def create_kb(vocab):
|
||||
# create artificial KB - assign same prior weight to the two russ cochran's
|
||||
# Q2146908 (Russ Cochran): American golfer
|
||||
# Q7381115 (Russ Cochran): publisher
|
||||
mykb = InMemoryLookupKB(vocab, entity_vector_length=vector_length)
|
||||
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
|
||||
mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
|
||||
mykb.add_alias(
|
||||
alias="Russ Cochran",
|
||||
entities=["Q2146908", "Q7381115"],
|
||||
probabilities=[0.5, 0.5],
|
||||
)
|
||||
return mykb
|
||||
|
||||
# Create the NER and EL components and add them to the pipeline
|
||||
ner = nlp.add_pipe("ner", first=True)
|
||||
entity_linker = nlp.add_pipe(
|
||||
"entity_linker", last=True, config={"use_gold_ents": False}
|
||||
)
|
||||
entity_linker.set_kb(create_kb)
|
||||
|
||||
train_examples = []
|
||||
for text, annotations in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||
for ent in annotations.get("entities"):
|
||||
ner.add_label(ent[2])
|
||||
optimizer = nlp.initialize()
|
||||
|
||||
# train the NER and NEL pipes
|
||||
for i in range(50):
|
||||
losses = {}
|
||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||
assert losses["ner"] < 0.001
|
||||
assert losses["entity_linker"] < 0.001
|
||||
|
||||
# adding additional components that are required for the entity_linker
|
||||
nlp.add_pipe("sentencizer", first=True)
|
||||
|
||||
# test the trained model
|
||||
test_text = "Russ Cochran captured his first major title with his son as caddie."
|
||||
doc = nlp(test_text)
|
||||
ents = doc.ents
|
||||
assert len(ents) == 1
|
||||
assert ents[0].text == "Russ Cochran"
|
||||
assert ents[0].label_ == "PERSON"
|
||||
assert ents[0].kb_id_ != "NIL"
|
||||
|
||||
# TODO: below assert is still flaky - EL doesn't properly overfit quite yet
|
||||
# assert ents[0].kb_id_ == "Q2146908"
|
||||
|
||||
# Also test the results are still the same after IO
|
||||
with make_tempdir() as tmp_dir:
|
||||
nlp.to_disk(tmp_dir)
|
||||
nlp2 = util.load_model_from_path(tmp_dir)
|
||||
assert nlp2.pipe_names == nlp.pipe_names
|
||||
doc2 = nlp2(test_text)
|
||||
ents2 = doc2.ents
|
||||
assert len(ents2) == 1
|
||||
assert ents2[0].text == "Russ Cochran"
|
||||
assert ents2[0].label_ == "PERSON"
|
||||
assert ents2[0].kb_id_ != "NIL"
|
||||
|
||||
eval = nlp.evaluate(train_examples)
|
||||
assert "nel_macro_f" in eval
|
||||
assert "nel_micro_f" in eval
|
||||
assert "ents_f" in eval
|
||||
assert "nel_f_per_type" in eval
|
||||
assert "ents_per_type" in eval
|
||||
assert "PERSON" in eval["nel_f_per_type"]
|
||||
assert "PERSON" in eval["ents_per_type"]
|
||||
|
||||
assert eval["nel_macro_f"] > 0
|
||||
assert eval["nel_micro_f"] > 0
|
||||
assert eval["ents_f"] > 0
|
||||
|
||||
|
||||
def test_kb_serialization():
|
||||
# Test that the KB can be used in a pipeline with a different vocab
|
||||
|
|
|
@ -61,7 +61,7 @@ architectures and their arguments and hyperparameters.
|
|||
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
|
||||
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
|
||||
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
|
||||
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ |
|
||||
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~ |
|
||||
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
|
||||
| `get_candidates_batch` <Tag variant="new">3.5</Tag> | Function that generates plausible candidates for a given batch of `Span` objects. Defaults to [CandidateBatchGenerator](/api/architectures#CandidateBatchGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]~~ |
|
||||
| `generate_empty_kb` <Tag variant="new">3.5.1</Tag> | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~ |
|
||||
|
|
Loading…
Reference in New Issue
Block a user