Fix EL test.

This commit is contained in:
Raphael Mitsch 2022-12-12 14:04:34 +01:00
parent 77680421b4
commit cb640abe81
2 changed files with 25 additions and 9 deletions

View File

@ -246,14 +246,18 @@ cdef class InMemoryLookupKB(KnowledgeBase):
alias_index = <int64_t>self._alias_index.get(alias_hash) alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index] alias_entry = self._aliases_table[alias_index]
return [Candidate(retrieve_string_from_hash=self.vocab.strings.__getitem__, return [
entity_hash=self._entries[entry_index].entity_hash, Candidate(
entity_freq=self._entries[entry_index].freq, retrieve_string_from_hash=self.vocab.strings.__getitem__,
entity_vector=self._vectors_table[self._entries[entry_index].vector_index], entity_hash=self._entries[entry_index].entity_hash,
alias_hash=alias_hash, entity_freq=self._entries[entry_index].freq,
prior_prob=prior_prob) entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs) alias_hash=alias_hash,
if entry_index != 0] prior_prob=prior_prob
)
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
if entry_index != 0
]
def get_vector(self, str entity): def get_vector(self, str entity):
cdef hash_t entity_hash = self.vocab.strings[entity] cdef hash_t entity_hash = self.vocab.strings[entity]

View File

@ -1199,7 +1199,19 @@ def test_threshold(meet_threshold: bool, config: Dict[str, Any]):
entity_linker = nlp.add_pipe( entity_linker = nlp.add_pipe(
"entity_linker", "entity_linker",
last=True, last=True,
config={"threshold": None if meet_threshold else 1.0, "model": config}, config={
**(
{"threshold": None}
if meet_threshold
else {
"threshold": 1.0,
# Prior for candidate may be 1.0, rendering the our test setting with threshold 1.0 useless
# otherwise.
"incl_prior": False,
}
),
"model": config,
},
) )
entity_linker.set_kb(create_kb) # type: ignore entity_linker.set_kb(create_kb) # type: ignore
nlp.initialize(get_examples=lambda: train_examples) nlp.initialize(get_examples=lambda: train_examples)