Merge EL sentence crossing tests.

This commit is contained in:
Raphael Mitsch 2023-03-13 10:45:12 +01:00
parent 73c830c0e7
commit e4c02b35a6

View File

@ -1,9 +1,9 @@
from typing import Callable, Iterable, Dict, Any
from typing import Callable, Iterable, Dict, Any, Tuple
import pytest
from numpy.testing import assert_equal
from spacy import registry, util
from spacy import registry, util, Language
from spacy.attrs import ENT_KB_ID
from spacy.compat import pickle
from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase
@ -108,18 +108,24 @@ def test_issue7065():
@pytest.mark.issue(7065)
def test_issue7065_b():
@pytest.mark.parametrize("entity_in_first_sentence", [True, False])
def test_sentence_crossing_ents(entity_in_first_sentence: bool):
"""Tests if NEL crashes if entities cross sentence boundaries and the first associated sentence doesn't have an
entity.
entity_in_first_sentence (bool): Whether to include an entity in the first sentence associated with the
sentence-crossing entity.
"""
# Test that the NEL doesn't crash when an entity crosses a sentence boundary
nlp = English()
vector_length = 3
nlp.add_pipe("sentencizer")
text = "Mahler 's Symphony No. 8 was beautiful."
entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
links = {
(0, 6): {"Q7304": 1.0, "Q270853": 0.0},
(10, 24): {"Q7304": 0.0, "Q270853": 1.0},
}
sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
entities = [(10, 24, "WORK")]
links = {(10, 24): {"Q7304": 0.0, "Q270853": 1.0}}
if entity_in_first_sentence:
entities.append((0, 6, "PERSON"))
links[(0, 6)] = {"Q7304": 1.0, "Q270853": 0.0}
sent_starts = [1, -1, 0, 0, 0, 1, 0, 0, 0]
doc = nlp(text)
example = Example.from_dict(
doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
@ -152,24 +158,8 @@ def test_issue7065_b():
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
# Add a custom rule-based component to mimic NER
patterns = [
{"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
{
"label": "WORK",
"pattern": [
{"LOWER": "symphony"},
{"LOWER": "no"},
{"LOWER": "."},
{"LOWER": "8"},
],
},
]
ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
ruler.add_patterns(patterns)
# test the trained model - this should not throw E148
doc = nlp(text)
assert doc
# This shouldn't crash.
entity_linker.predict([example.reference])
def test_no_entities():
@ -1219,47 +1209,3 @@ def test_span_maker_forward_with_empty():
# just to get a model
span_maker = build_span_maker()
span_maker([doc1, doc2], False)
def test_sentence_crossing_ents():
    """Check that entity linking does not crash on a sentence-crossing entity.

    The only annotated entity spans two sentences, and the first of those
    sentences contains no other entity.
    """
    nlp = English()
    vector_length = 3
    nlp.add_pipe("sentencizer")
    text = "Mahler 's Symphony No. 8 was beautiful."

    # Gold annotations: a single linked entity plus explicit sentence starts.
    annotations = {
        "entities": [(10, 24, "WORK")],
        "links": {(10, 24): {"Q7304": 0.0, "Q270853": 1.0}},
        "sent_starts": [1, -1, 0, 0, 0, 1, 0, 0, 0],
    }
    example = Example.from_dict(nlp(text), annotations)

    # Sanity check: the entity really does belong to two sentences.
    crossing_ent = example.reference.ents[0]
    assert len(list(crossing_ent.sents)) == 2
    train_examples = [example]

    def create_kb(vocab):
        # Artificial KB holding one entity and one alias pointing at it.
        kb = InMemoryLookupKB(vocab, entity_vector_length=vector_length)
        kb.add_entity(entity="Q270853", freq=12, entity_vector=[9, 1, -7])
        kb.add_alias(alias="No. 8", entities=["Q270853"], probabilities=[1.0])
        return kb

    # Append the entity linker to the pipeline and give it the KB factory.
    entity_linker = nlp.add_pipe("entity_linker", last=True)
    entity_linker.set_kb(create_kb)

    # Run two training updates on the single example.
    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    for _ in range(2):
        nlp.update(train_examples, sgd=optimizer, losses={})

    # This shouldn't crash.
    entity_linker.predict([example.reference])