Add test for multiple docs with multiple entities.

This commit is contained in:
Raphael Mitsch 2024-02-19 10:53:12 +01:00
parent e83a988a42
commit eef3de098f

View File

@@ -476,6 +476,86 @@ def test_candidate_generation(nlp):
assert_almost_equal(adam_ent_cands[0].prior_prob, 0.9)
def test_candidate_generation_multiple_docs(nlp):
    """Test correct candidate generation with multiple docs.

    Two docs contain the same four single-token mentions in a different order.
    Every mention is looked up in both docs with one `get_candidates_v2` call;
    the materialized result is indexed as `result[doc][span][candidate]`.
    """
    mykb = InMemoryLookupKB(nlp.vocab, entity_vector_length=1)
    docs = [nlp("douglas adam Adam shrubbery"), nlp("shrubbery Adam douglas adam")]
    douglas_ents = [docs[0][0:1], docs[1][2:3]]
    adam_ents = [docs[0][1:2], docs[1][3:4]]
    Adam_ents = [docs[0][2:3], docs[1][1:2]]
    shrubbery_ents = [docs[0][3:4], docs[1][0:1]]

    # adding entities
    mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
    mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
    mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])

    # adding aliases
    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

    def candidates_for(ents):
        """Look up one mention per doc, passing one single-span group per doc."""
        return list(
            get_candidates_v2(
                mykb,
                [SpanGroup(doc=doc, spans=[ent]) for doc, ent in zip(docs, ents)],
            )
        )

    # test the size of the relevant candidates: the outer length is the number
    # of docs, the innermost length is the number of candidates for the mention
    adam_ent_cands = candidates_for(adam_ents)
    assert len(adam_ent_cands) == 2
    # "adam" is an alias of Q2 only
    assert len(adam_ent_cands[0][0]) == 1 and len(adam_ent_cands[1][0]) == 1

    douglas_ent_cands = candidates_for(douglas_ents)
    assert len(douglas_ent_cands) == 2
    # "douglas" is an alias of both Q2 and Q3; checking only the outer length
    # would pass even if no candidate had been generated at all
    for doc_cands in douglas_ent_cands:
        assert len(doc_cands[0]) == 2
        assert {cand.entity_id_ for cand in doc_cands[0]} == {"Q2", "Q3"}

    Adam_ent_cands = candidates_for(Adam_ents)
    assert len(Adam_ent_cands) == 2
    assert (
        len(Adam_ent_cands[0][0]) == 0 and len(Adam_ent_cands[1][0]) == 0
    )  # default case sensitive

    # "shrubbery" was never registered as an alias
    shrubbery_ents_cands = candidates_for(shrubbery_ents)
    assert len(shrubbery_ents_cands) == 2
    assert len(shrubbery_ents_cands[0][0]) == 0 and len(shrubbery_ents_cands[1][0]) == 0

    # test the content of the candidates
    assert (
        adam_ent_cands[0][0][0].entity_id_ == adam_ent_cands[1][0][0].entity_id_ == "Q2"
    )
    assert adam_ent_cands[0][0][0].alias == adam_ent_cands[1][0][0].alias == "adam"
    assert_almost_equal(adam_ent_cands[0][0][0].entity_freq, 12)
    assert_almost_equal(adam_ent_cands[1][0][0].entity_freq, 12)
    assert_almost_equal(adam_ent_cands[0][0][0].prior_prob, 0.9)
    assert_almost_equal(adam_ent_cands[1][0][0].prior_prob, 0.9)
def test_el_pipe_configuration(nlp):
"""Test correct candidate generation as part of the EL pipe"""
nlp.add_pipe("sentencizer")