This commit is contained in:
Raphael Mitsch 2023-03-06 10:27:33 +01:00
parent 8b24f31b65
commit d0abc321d8
2 changed files with 36 additions and 9 deletions

View File

@ -1,4 +1,14 @@
from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any, Iterator from typing import (
Optional,
Iterable,
Callable,
Dict,
Sequence,
Union,
List,
Any,
Iterator,
)
from typing import cast from typing import cast
from numpy import dtype from numpy import dtype
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
@ -79,7 +89,9 @@ def make_entity_linker(
incl_prior: bool, incl_prior: bool,
incl_context: bool, incl_context: bool,
entity_vector_length: int, entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]], get_candidates: Callable[
[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool, overwrite: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
@ -177,7 +189,10 @@ class EntityLinker(TrainablePipe):
incl_prior: bool, incl_prior: bool,
incl_context: bool, incl_context: bool,
entity_vector_length: int, entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]], get_candidates: Callable[
[KnowledgeBase, Iterator[SpanGroup]],
Iterator[Iterable[Iterable[Candidate]]],
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool = False, overwrite: bool = False,
scorer: Optional[Callable] = entity_linker_score, scorer: Optional[Callable] = entity_linker_score,
@ -313,7 +328,8 @@ class EntityLinker(TrainablePipe):
If one isn't present, then the update step needs to be skipped. If one isn't present, then the update step needs to be skipped.
""" """
for candidates_for_doc in self.get_candidates( for candidates_for_doc in self.get_candidates(
self.kb, (SpanGroup(doc=eg.predicted, spans=eg.predicted.ents) for eg in examples) self.kb,
(SpanGroup(doc=eg.predicted, spans=eg.predicted.ents) for eg in examples),
): ):
for candidates_for_mention in candidates_for_doc: for candidates_for_mention in candidates_for_doc:
if list(candidates_for_mention): if list(candidates_for_mention):

View File

@ -467,9 +467,16 @@ def test_candidate_generation(nlp):
# test the size of the relevant candidates # test the size of the relevant candidates
adam_ent_cands = next(get_candidates(mykb, SpanGroup(doc=doc, spans=[adam_ent])))[0] adam_ent_cands = next(get_candidates(mykb, SpanGroup(doc=doc, spans=[adam_ent])))[0]
assert len(adam_ent_cands) == 1 assert len(adam_ent_cands) == 1
assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) == 2 assert (
assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0 # default case sensitive len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) == 2
assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0]) == 0 )
assert (
len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0
) # default case sensitive
assert (
len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0])
== 0
)
# test the content of the candidates # test the content of the candidates
assert adam_ent_cands[0].entity_ == "Q2" assert adam_ent_cands[0].entity_ == "Q2"
@ -504,11 +511,15 @@ def test_el_pipe_configuration(nlp):
def get_lowercased_candidates(kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]): def get_lowercased_candidates(kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]):
for mentions_for_doc in mentions: for mentions_for_doc in mentions:
yield [kb.get_alias_candidates(ent_span.text.lower()) for ent_span in mentions_for_doc] yield [
kb.get_alias_candidates(ent_span.text.lower())
for ent_span in mentions_for_doc
]
@registry.misc("spacy.LowercaseCandidateGenerator.v1") @registry.misc("spacy.LowercaseCandidateGenerator.v1")
def create_candidates() -> Callable[ def create_candidates() -> Callable[
[InMemoryLookupKB, Iterator[SpanGroup]], Iterator[Iterable[Iterable[InMemoryCandidate]]] [InMemoryLookupKB, Iterator[SpanGroup]],
Iterator[Iterable[Iterable[InMemoryCandidate]]],
]: ]:
return get_lowercased_candidates return get_lowercased_candidates