diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index f6f6f6fd0..493cbe1fa 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -476,6 +476,86 @@ def test_candidate_generation(nlp): assert_almost_equal(adam_ent_cands[0].prior_prob, 0.9) +def test_candidate_generation_multiple_docs(nlp): + """Test correct candidate generation with multiple docs.""" + mykb = InMemoryLookupKB(nlp.vocab, entity_vector_length=1) + docs = [nlp("douglas adam Adam shrubbery"), nlp("shrubbery Adam douglas adam")] + + douglas_ents = [docs[0][0:1], docs[1][2:3]] + adam_ents = [docs[0][1:2], docs[1][3:4]] + Adam_ents = [docs[0][2:3], docs[1][1:2]] + shrubbery_ents = [docs[0][3:4], docs[1][0:1]] + + # adding entities + mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) + mykb.add_entity(entity="Q2", freq=12, entity_vector=[2]) + mykb.add_entity(entity="Q3", freq=5, entity_vector=[3]) + + # adding aliases + mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]) + mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) + + # test the size of the relevant candidates + adam_ent_cands = list( + get_candidates_v2( + mykb, + [ + SpanGroup(doc=docs[0], spans=[adam_ents[0]]), + SpanGroup(doc=docs[1], spans=[adam_ents[1]]), + ], + ) + ) + assert len(adam_ent_cands) == 2 + assert ( + len( + list( + get_candidates_v2( + mykb, + [ + SpanGroup(doc=docs[0], spans=[douglas_ents[0]]), + SpanGroup(doc=docs[1], spans=[douglas_ents[1]]), + ], + ) + ) + ) + == 2 + ) + Adam_ent_cands = list( + get_candidates_v2( + mykb, + [ + SpanGroup(doc=docs[0], spans=[Adam_ents[0]]), + SpanGroup(doc=docs[1], spans=[Adam_ents[1]]), + ], + ) + ) + assert len(Adam_ent_cands) == 2 + assert ( + len(Adam_ent_cands[0][0]) == 0 and len(Adam_ent_cands[1][0]) == 0 + ) # default case sensitive + shrubbery_ents_cands = list( + get_candidates_v2( + mykb, + [ + SpanGroup(doc=docs[0], spans=[shrubbery_ents[0]]), + SpanGroup(doc=docs[1], spans=[shrubbery_ents[1]]), + ], + ) + ) + assert len(shrubbery_ents_cands) == 2 + assert len(shrubbery_ents_cands[0][0]) == 0 and len(shrubbery_ents_cands[1][0]) == 0 + + # test the content of the candidates + assert ( + adam_ent_cands[0][0][0].entity_id_ == adam_ent_cands[1][0][0].entity_id_ == "Q2" + ) + assert adam_ent_cands[0][0][0].alias == adam_ent_cands[1][0][0].alias == "adam" + assert_almost_equal(adam_ent_cands[0][0][0].entity_freq, 12) + assert_almost_equal(adam_ent_cands[1][0][0].entity_freq, 12) + assert_almost_equal(adam_ent_cands[0][0][0].prior_prob, 0.9) + assert_almost_equal(adam_ent_cands[1][0][0].prior_prob, 0.9) + + def test_el_pipe_configuration(nlp): """Test correct candidate generation as part of the EL pipe""" nlp.add_pipe("sentencizer")