mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Fix EL failure with sentence-crossing entities (#12398)
* Add test reproducing EL failure in sentence-crossing entities. * Format. * Draft fix. * Format. * Fix case for len(ent.sents) == 1. * Format. * Format. * Format. * Fix mypy error. * Merge EL sentence crossing tests. * Remove unneeded sentencizer component. * Fix or ignore mypy issues in test. * Simplify ent.sents handling. * Format. Update assert in ent.sents handling. * Small rewrite --------- Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
2ce9a220db
commit
96b61d0671
|
@ -474,18 +474,24 @@ class EntityLinker(TrainablePipe):
|
||||||
|
|
||||||
# Looping through each entity in batch (TODO: rewrite)
|
# Looping through each entity in batch (TODO: rewrite)
|
||||||
for j, ent in enumerate(ent_batch):
|
for j, ent in enumerate(ent_batch):
|
||||||
sent_index = sentences.index(ent.sent)
|
assert hasattr(ent, "sents")
|
||||||
assert sent_index >= 0
|
sents = list(ent.sents)
|
||||||
|
sent_indices = (
|
||||||
|
sentences.index(sents[0]),
|
||||||
|
sentences.index(sents[-1]),
|
||||||
|
)
|
||||||
|
assert sent_indices[1] >= sent_indices[0] >= 0
|
||||||
|
|
||||||
if self.incl_context:
|
if self.incl_context:
|
||||||
# get n_neighbour sentences, clipped to the length of the document
|
# get n_neighbour sentences, clipped to the length of the document
|
||||||
start_sentence = max(0, sent_index - self.n_sents)
|
start_sentence = max(0, sent_indices[0] - self.n_sents)
|
||||||
end_sentence = min(
|
end_sentence = min(
|
||||||
len(sentences) - 1, sent_index + self.n_sents
|
len(sentences) - 1, sent_indices[1] + self.n_sents
|
||||||
)
|
)
|
||||||
start_token = sentences[start_sentence].start
|
start_token = sentences[start_sentence].start
|
||||||
end_token = sentences[end_sentence].end
|
end_token = sentences[end_sentence].end
|
||||||
sent_doc = doc[start_token:end_token].as_doc()
|
sent_doc = doc[start_token:end_token].as_doc()
|
||||||
|
|
||||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||||
sentence_encoding = self.model.predict([sent_doc])[0]
|
sentence_encoding = self.model.predict([sent_doc])[0]
|
||||||
sentence_encoding_t = sentence_encoding.T
|
sentence_encoding_t = sentence_encoding.T
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
from typing import Callable, Iterable, Dict, Any
|
from typing import Callable, Iterable, Dict, Any, Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from numpy.testing import assert_equal
|
from numpy.testing import assert_equal
|
||||||
|
|
||||||
from spacy import registry, util
|
from spacy import registry, util, Language
|
||||||
from spacy.attrs import ENT_KB_ID
|
from spacy.attrs import ENT_KB_ID
|
||||||
from spacy.compat import pickle
|
from spacy.compat import pickle
|
||||||
from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase
|
from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase
|
||||||
|
@ -108,18 +108,23 @@ def test_issue7065():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.issue(7065)
|
@pytest.mark.issue(7065)
|
||||||
def test_issue7065_b():
|
@pytest.mark.parametrize("entity_in_first_sentence", [True, False])
|
||||||
|
def test_sentence_crossing_ents(entity_in_first_sentence: bool):
|
||||||
|
"""Tests if NEL crashes if entities cross sentence boundaries and the first associated sentence doesn't have an
|
||||||
|
entity.
|
||||||
|
entity_in_prior_sentence (bool): Whether to include an entity in the first sentence associated with the
|
||||||
|
sentence-crossing entity.
|
||||||
|
"""
|
||||||
# Test that the NEL doesn't crash when an entity crosses a sentence boundary
|
# Test that the NEL doesn't crash when an entity crosses a sentence boundary
|
||||||
nlp = English()
|
nlp = English()
|
||||||
vector_length = 3
|
vector_length = 3
|
||||||
nlp.add_pipe("sentencizer")
|
|
||||||
text = "Mahler 's Symphony No. 8 was beautiful."
|
text = "Mahler 's Symphony No. 8 was beautiful."
|
||||||
entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
|
entities = [(10, 24, "WORK")]
|
||||||
links = {
|
links = {(10, 24): {"Q7304": 0.0, "Q270853": 1.0}}
|
||||||
(0, 6): {"Q7304": 1.0, "Q270853": 0.0},
|
if entity_in_first_sentence:
|
||||||
(10, 24): {"Q7304": 0.0, "Q270853": 1.0},
|
entities.append((0, 6, "PERSON"))
|
||||||
}
|
links[(0, 6)] = {"Q7304": 1.0, "Q270853": 0.0}
|
||||||
sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
|
sent_starts = [1, -1, 0, 0, 0, 1, 0, 0, 0]
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
example = Example.from_dict(
|
example = Example.from_dict(
|
||||||
doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
|
doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
|
||||||
|
@ -145,31 +150,14 @@ def test_issue7065_b():
|
||||||
|
|
||||||
# Create the Entity Linker component and add it to the pipeline
|
# Create the Entity Linker component and add it to the pipeline
|
||||||
entity_linker = nlp.add_pipe("entity_linker", last=True)
|
entity_linker = nlp.add_pipe("entity_linker", last=True)
|
||||||
entity_linker.set_kb(create_kb)
|
entity_linker.set_kb(create_kb) # type: ignore
|
||||||
# train the NEL pipe
|
# train the NEL pipe
|
||||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
losses = {}
|
nlp.update(train_examples, sgd=optimizer)
|
||||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
|
||||||
|
|
||||||
# Add a custom rule-based component to mimick NER
|
# This shouldn't crash.
|
||||||
patterns = [
|
entity_linker.predict([example.reference]) # type: ignore
|
||||||
{"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
|
|
||||||
{
|
|
||||||
"label": "WORK",
|
|
||||||
"pattern": [
|
|
||||||
{"LOWER": "symphony"},
|
|
||||||
{"LOWER": "no"},
|
|
||||||
{"LOWER": "."},
|
|
||||||
{"LOWER": "8"},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
|
|
||||||
ruler.add_patterns(patterns)
|
|
||||||
# test the trained model - this should not throw E148
|
|
||||||
doc = nlp(text)
|
|
||||||
assert doc
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_entities():
|
def test_no_entities():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user