mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Fix EL failure with sentence-crossing entities (#12398)
* Add test reproducing EL failure in sentence-crossing entities. * Format. * Draft fix. * Format. * Fix case for len(ent.sents) == 1. * Format. * Format. * Format. * Fix mypy error. * Merge EL sentence crossing tests. * Remove unneeded sentencizer component. * Fix or ignore mypy issues in test. * Simplify ent.sents handling. * Format. Update assert in ent.sents handling. * Small rewrite --------- Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
2ce9a220db
commit
96b61d0671
|
@ -474,18 +474,24 @@ class EntityLinker(TrainablePipe):
|
|||
|
||||
# Looping through each entity in batch (TODO: rewrite)
|
||||
for j, ent in enumerate(ent_batch):
|
||||
sent_index = sentences.index(ent.sent)
|
||||
assert sent_index >= 0
|
||||
assert hasattr(ent, "sents")
|
||||
sents = list(ent.sents)
|
||||
sent_indices = (
|
||||
sentences.index(sents[0]),
|
||||
sentences.index(sents[-1]),
|
||||
)
|
||||
assert sent_indices[1] >= sent_indices[0] >= 0
|
||||
|
||||
if self.incl_context:
|
||||
# get n_neighbour sentences, clipped to the length of the document
|
||||
start_sentence = max(0, sent_index - self.n_sents)
|
||||
start_sentence = max(0, sent_indices[0] - self.n_sents)
|
||||
end_sentence = min(
|
||||
len(sentences) - 1, sent_index + self.n_sents
|
||||
len(sentences) - 1, sent_indices[1] + self.n_sents
|
||||
)
|
||||
start_token = sentences[start_sentence].start
|
||||
end_token = sentences[end_sentence].end
|
||||
sent_doc = doc[start_token:end_token].as_doc()
|
||||
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
sentence_encoding = self.model.predict([sent_doc])[0]
|
||||
sentence_encoding_t = sentence_encoding.T
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from typing import Callable, Iterable, Dict, Any
|
||||
from typing import Callable, Iterable, Dict, Any, Tuple
|
||||
|
||||
import pytest
|
||||
from numpy.testing import assert_equal
|
||||
|
||||
from spacy import registry, util
|
||||
from spacy import registry, util, Language
|
||||
from spacy.attrs import ENT_KB_ID
|
||||
from spacy.compat import pickle
|
||||
from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase
|
||||
|
@ -108,18 +108,23 @@ def test_issue7065():
|
|||
|
||||
|
||||
@pytest.mark.issue(7065)
|
||||
def test_issue7065_b():
|
||||
@pytest.mark.parametrize("entity_in_first_sentence", [True, False])
|
||||
def test_sentence_crossing_ents(entity_in_first_sentence: bool):
|
||||
"""Tests if NEL crashes if entities cross sentence boundaries and the first associated sentence doesn't have an
|
||||
entity.
|
||||
entity_in_prior_sentence (bool): Whether to include an entity in the first sentence associated with the
|
||||
sentence-crossing entity.
|
||||
"""
|
||||
# Test that the NEL doesn't crash when an entity crosses a sentence boundary
|
||||
nlp = English()
|
||||
vector_length = 3
|
||||
nlp.add_pipe("sentencizer")
|
||||
text = "Mahler 's Symphony No. 8 was beautiful."
|
||||
entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
|
||||
links = {
|
||||
(0, 6): {"Q7304": 1.0, "Q270853": 0.0},
|
||||
(10, 24): {"Q7304": 0.0, "Q270853": 1.0},
|
||||
}
|
||||
sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
|
||||
entities = [(10, 24, "WORK")]
|
||||
links = {(10, 24): {"Q7304": 0.0, "Q270853": 1.0}}
|
||||
if entity_in_first_sentence:
|
||||
entities.append((0, 6, "PERSON"))
|
||||
links[(0, 6)] = {"Q7304": 1.0, "Q270853": 0.0}
|
||||
sent_starts = [1, -1, 0, 0, 0, 1, 0, 0, 0]
|
||||
doc = nlp(text)
|
||||
example = Example.from_dict(
|
||||
doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
|
||||
|
@ -145,31 +150,14 @@ def test_issue7065_b():
|
|||
|
||||
# Create the Entity Linker component and add it to the pipeline
|
||||
entity_linker = nlp.add_pipe("entity_linker", last=True)
|
||||
entity_linker.set_kb(create_kb)
|
||||
entity_linker.set_kb(create_kb) # type: ignore
|
||||
# train the NEL pipe
|
||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||
for i in range(2):
|
||||
losses = {}
|
||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||
nlp.update(train_examples, sgd=optimizer)
|
||||
|
||||
# Add a custom rule-based component to mimick NER
|
||||
patterns = [
|
||||
{"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
|
||||
{
|
||||
"label": "WORK",
|
||||
"pattern": [
|
||||
{"LOWER": "symphony"},
|
||||
{"LOWER": "no"},
|
||||
{"LOWER": "."},
|
||||
{"LOWER": "8"},
|
||||
],
|
||||
},
|
||||
]
|
||||
ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
|
||||
ruler.add_patterns(patterns)
|
||||
# test the trained model - this should not throw E148
|
||||
doc = nlp(text)
|
||||
assert doc
|
||||
# This shouldn't crash.
|
||||
entity_linker.predict([example.reference]) # type: ignore
|
||||
|
||||
|
||||
def test_no_entities():
|
||||
|
|
Loading…
Reference in New Issue
Block a user