diff --git a/spacy/tests/pipeline/test_el.py b/spacy/tests/pipeline/test_el.py index 068a228d8..78ee0f358 100644 --- a/spacy/tests/pipeline/test_el.py +++ b/spacy/tests/pipeline/test_el.py @@ -63,3 +63,20 @@ def test_kb_invalid_combination(): with pytest.raises(ValueError): mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]) + +def test_candidate_generation(): + """Test correct candidate generation""" + mykb = KnowledgeBase() + + # adding entities + mykb.add_entity(entity_id="Q1", prob=0.9) + mykb.add_entity(entity_id="Q2", prob=0.2) + mykb.add_entity(entity_id="Q3", prob=0.5) + + # adding aliases + mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2]) + mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) + + # test the size of the relevant candidates + assert(len(mykb.get_candidates("douglas")) == 2) + assert(len(mykb.get_candidates("adam")) == 1)