From cb640abe8157c2daa70cf8096b712541cea36a49 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 12 Dec 2022 14:04:34 +0100 Subject: [PATCH] Fix EL test. --- spacy/kb/kb_in_memory.pyx | 20 ++++++++++++-------- spacy/tests/pipeline/test_entity_linker.py | 14 +++++++++++++- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/spacy/kb/kb_in_memory.pyx b/spacy/kb/kb_in_memory.pyx index 2b245d76f..97ae08e1e 100644 --- a/spacy/kb/kb_in_memory.pyx +++ b/spacy/kb/kb_in_memory.pyx @@ -246,14 +246,18 @@ cdef class InMemoryLookupKB(KnowledgeBase): alias_index = self._alias_index.get(alias_hash) alias_entry = self._aliases_table[alias_index] - return [Candidate(retrieve_string_from_hash=self.vocab.strings.__getitem__, - entity_hash=self._entries[entry_index].entity_hash, - entity_freq=self._entries[entry_index].freq, - entity_vector=self._vectors_table[self._entries[entry_index].vector_index], - alias_hash=alias_hash, - prior_prob=prior_prob) - for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs) - if entry_index != 0] + return [ + Candidate( + retrieve_string_from_hash=self.vocab.strings.__getitem__, + entity_hash=self._entries[entry_index].entity_hash, + entity_freq=self._entries[entry_index].freq, + entity_vector=self._vectors_table[self._entries[entry_index].vector_index], + alias_hash=alias_hash, + prior_prob=prior_prob + ) + for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs) + if entry_index != 0 + ] def get_vector(self, str entity): cdef hash_t entity_hash = self.vocab.strings[entity] diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 4997631f3..c6030be41 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -1199,7 +1199,19 @@ def test_threshold(meet_threshold: bool, config: Dict[str, Any]): entity_linker = nlp.add_pipe( "entity_linker", last=True, - config={"threshold": None if meet_threshold else 1.0, "model": config}, + config={ + **( + {"threshold": None} + if meet_threshold + else { + "threshold": 1.0, + # Prior for candidate may be 1.0, rendering the our test setting with threshold 1.0 useless + # otherwise. + "incl_prior": False, + } + ), + "model": config, + }, ) entity_linker.set_kb(create_kb) # type: ignore nlp.initialize(get_examples=lambda: train_examples)