Fix Entity Linker with tokenization mismatches (fix #9575) (#10457)

* Add failing test

* Partial fix for issue

This kind of works. The issue with token length mismatches is gone. The
problem is that when you get empty lists of encodings to compare, it
fails because the sizes are not the same, even though they're both zero:
(0, 3) vs (0,). Not sure why that happens...

* Short circuit on empties

* Remove spurious check

The check here isn't needed now that the short circuit is fixed.

* Update spacy/tests/pipeline/test_entity_linker.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Use "eg", not "example"

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
Paul O'Leary McCann 2022-05-24 03:42:26 +09:00 committed by GitHub
parent 1d34aa2b3d
commit 6be09bbd07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 7 deletions

View File

@ -234,10 +234,11 @@ class EntityLinker(TrainablePipe):
nO = self.kb.entity_vector_length nO = self.kb.entity_vector_length
doc_sample = [] doc_sample = []
vector_sample = [] vector_sample = []
for example in islice(get_examples(), 10): for eg in islice(get_examples(), 10):
doc = example.x doc = eg.x
if self.use_gold_ents: if self.use_gold_ents:
doc.ents = example.y.ents ents, _ = eg.get_aligned_ents_and_ner()
doc.ents = ents
doc_sample.append(doc) doc_sample.append(doc)
vector_sample.append(self.model.ops.alloc1f(nO)) vector_sample.append(self.model.ops.alloc1f(nO))
assert len(doc_sample) > 0, Errors.E923.format(name=self.name) assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
@ -312,7 +313,8 @@ class EntityLinker(TrainablePipe):
for doc, ex in zip(docs, examples): for doc, ex in zip(docs, examples):
if self.use_gold_ents: if self.use_gold_ents:
doc.ents = ex.reference.ents ents, _ = ex.get_aligned_ents_and_ner()
doc.ents = ents
else: else:
# only keep matching ents # only keep matching ents
doc.ents = ex.get_matching_ents() doc.ents = ex.get_matching_ents()
@ -345,7 +347,7 @@ class EntityLinker(TrainablePipe):
for eg in examples: for eg in examples:
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
for ent in eg.reference.ents: for ent in eg.get_matching_ents():
kb_id = kb_ids[ent.start] kb_id = kb_ids[ent.start]
if kb_id: if kb_id:
entity_encoding = self.kb.get_vector(kb_id) entity_encoding = self.kb.get_vector(kb_id)
@ -356,7 +358,11 @@ class EntityLinker(TrainablePipe):
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
selected_encodings = sentence_encodings[keep_ents] selected_encodings = sentence_encodings[keep_ents]
# If the entity encodings list is empty, then # if there are no matches, short circuit
if not keep_ents:
out = self.model.ops.alloc2f(*sentence_encodings.shape)
return 0, out
if selected_encodings.shape != entity_encodings.shape: if selected_encodings.shape != entity_encodings.shape:
err = Errors.E147.format( err = Errors.E147.format(
method="get_loss", msg="gold entities do not match up" method="get_loss", msg="gold entities do not match up"

View File

@ -14,7 +14,7 @@ from spacy.pipeline.legacy import EntityLinker_v1
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
from spacy.scorer import Scorer from spacy.scorer import Scorer
from spacy.tests.util import make_tempdir from spacy.tests.util import make_tempdir
from spacy.tokens import Span from spacy.tokens import Span, Doc
from spacy.training import Example from spacy.training import Example
from spacy.util import ensure_path from spacy.util import ensure_path
from spacy.vocab import Vocab from spacy.vocab import Vocab
@ -1075,3 +1075,32 @@ def test_no_gold_ents(patterns):
# this will run the pipeline on the examples and shouldn't crash # this will run the pipeline on the examples and shouldn't crash
results = nlp.evaluate(train_examples) results = nlp.evaluate(train_examples)
@pytest.mark.issue(9575)
def test_tokenization_mismatch():
    """Regression test for issue 9575.

    Builds an Example from two docs that cover the same text but split it
    into a different number of tokens, then initializes, updates and
    evaluates a pipeline containing an entity linker on that example.
    The test has no assertions: it passes as long as none of those steps
    raises.
    """
    nlp = English()

    # include a matching entity so that update isn't skipped
    two_token_doc = Doc(
        nlp.vocab,
        words=["Kirby", "123456"],
        spaces=[True, False],
        ents=["B-CHARACTER", "B-CARDINAL"],
    )
    three_token_doc = Doc(
        nlp.vocab,
        words=["Kirby", "123", "456"],
        spaces=[True, False, False],
        ents=["B-CHARACTER", "B-CARDINAL", "B-CARDINAL"],
    )
    train_examples = [Example(two_token_doc, three_token_doc)]

    vector_length = 3

    def create_kb(vocab):
        # create placeholder KB with a single entity and one alias for it
        kb = KnowledgeBase(vocab, entity_vector_length=vector_length)
        kb.add_entity(entity="Q613241", freq=12, entity_vector=[6, -4, 3])
        kb.add_alias("Kirby", ["Q613241"], [0.9])
        return kb

    entity_linker = nlp.add_pipe("entity_linker", last=True)
    entity_linker.set_kb(create_kb)

    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    for _ in range(2):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)

    # the sentencizer is only added after training, right before evaluation
    nlp.add_pipe("sentencizer", first=True)
    results = nlp.evaluate(train_examples)