Fix EL failure with sentence-crossing entities (#12398)

* Add test reproducing EL failure in sentence-crossing entities.

* Format.

* Draft fix.

* Format.

* Fix case for len(ent.sents) == 1.

* Format.

* Format.

* Format.

* Fix mypy error.

* Merge EL sentence crossing tests.

* Remove unneeded sentencizer component.

* Fix or ignore mypy issues in test.

* Simplify ent.sents handling.

* Format. Update assert in ent.sents handling.

* Small rewrite

---------

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
Raphael Mitsch 2023-03-14 22:02:49 +01:00 committed by GitHub
parent 2ce9a220db
commit 96b61d0671
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 35 deletions

View File

@ -474,18 +474,24 @@ class EntityLinker(TrainablePipe):
# Looping through each entity in batch (TODO: rewrite) # Looping through each entity in batch (TODO: rewrite)
for j, ent in enumerate(ent_batch): for j, ent in enumerate(ent_batch):
sent_index = sentences.index(ent.sent) assert hasattr(ent, "sents")
assert sent_index >= 0 sents = list(ent.sents)
sent_indices = (
sentences.index(sents[0]),
sentences.index(sents[-1]),
)
assert sent_indices[1] >= sent_indices[0] >= 0
if self.incl_context: if self.incl_context:
# get n_neighbour sentences, clipped to the length of the document # get n_neighbour sentences, clipped to the length of the document
start_sentence = max(0, sent_index - self.n_sents) start_sentence = max(0, sent_indices[0] - self.n_sents)
end_sentence = min( end_sentence = min(
len(sentences) - 1, sent_index + self.n_sents len(sentences) - 1, sent_indices[1] + self.n_sents
) )
start_token = sentences[start_sentence].start start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc() sent_doc = doc[start_token:end_token].as_doc()
# currently, the context is the same for each entity in a sentence (should be refined) # currently, the context is the same for each entity in a sentence (should be refined)
sentence_encoding = self.model.predict([sent_doc])[0] sentence_encoding = self.model.predict([sent_doc])[0]
sentence_encoding_t = sentence_encoding.T sentence_encoding_t = sentence_encoding.T

View File

@ -1,9 +1,9 @@
from typing import Callable, Iterable, Dict, Any from typing import Callable, Iterable, Dict, Any, Tuple
import pytest import pytest
from numpy.testing import assert_equal from numpy.testing import assert_equal
from spacy import registry, util from spacy import registry, util, Language
from spacy.attrs import ENT_KB_ID from spacy.attrs import ENT_KB_ID
from spacy.compat import pickle from spacy.compat import pickle
from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase
@ -108,18 +108,23 @@ def test_issue7065():
@pytest.mark.issue(7065) @pytest.mark.issue(7065)
def test_issue7065_b(): @pytest.mark.parametrize("entity_in_first_sentence", [True, False])
def test_sentence_crossing_ents(entity_in_first_sentence: bool):
"""Tests if NEL crashes if entities cross sentence boundaries and the first associated sentence doesn't have an
entity.
entity_in_first_sentence (bool): Whether to include an entity in the first sentence associated with the
sentence-crossing entity.
"""
# Test that the NEL doesn't crash when an entity crosses a sentence boundary # Test that the NEL doesn't crash when an entity crosses a sentence boundary
nlp = English() nlp = English()
vector_length = 3 vector_length = 3
nlp.add_pipe("sentencizer")
text = "Mahler 's Symphony No. 8 was beautiful." text = "Mahler 's Symphony No. 8 was beautiful."
entities = [(0, 6, "PERSON"), (10, 24, "WORK")] entities = [(10, 24, "WORK")]
links = { links = {(10, 24): {"Q7304": 0.0, "Q270853": 1.0}}
(0, 6): {"Q7304": 1.0, "Q270853": 0.0}, if entity_in_first_sentence:
(10, 24): {"Q7304": 0.0, "Q270853": 1.0}, entities.append((0, 6, "PERSON"))
} links[(0, 6)] = {"Q7304": 1.0, "Q270853": 0.0}
sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0] sent_starts = [1, -1, 0, 0, 0, 1, 0, 0, 0]
doc = nlp(text) doc = nlp(text)
example = Example.from_dict( example = Example.from_dict(
doc, {"entities": entities, "links": links, "sent_starts": sent_starts} doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
@ -145,31 +150,14 @@ def test_issue7065_b():
# Create the Entity Linker component and add it to the pipeline # Create the Entity Linker component and add it to the pipeline
entity_linker = nlp.add_pipe("entity_linker", last=True) entity_linker = nlp.add_pipe("entity_linker", last=True)
entity_linker.set_kb(create_kb) entity_linker.set_kb(create_kb) # type: ignore
# train the NEL pipe # train the NEL pipe
optimizer = nlp.initialize(get_examples=lambda: train_examples) optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(2): for i in range(2):
losses = {} nlp.update(train_examples, sgd=optimizer)
nlp.update(train_examples, sgd=optimizer, losses=losses)
# Add a custom rule-based component to mimic NER # This shouldn't crash.
patterns = [ entity_linker.predict([example.reference]) # type: ignore
{"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
{
"label": "WORK",
"pattern": [
{"LOWER": "symphony"},
{"LOWER": "no"},
{"LOWER": "."},
{"LOWER": "8"},
],
},
]
ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
ruler.add_patterns(patterns)
# test the trained model - this should not throw E148
doc = nlp(text)
assert doc
def test_no_entities(): def test_no_entities():