Reformat with black.

This commit is contained in:
Raphael Mitsch 2022-10-18 17:33:37 +02:00
parent 7c28424f47
commit d8183121f6
3 changed files with 20 additions and 14 deletions

View File

@ -107,6 +107,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
@registry.misc("spacy.CandidateAllGenerator.v1")
def create_candidates_all() -> Callable[
[KnowledgeBase, Generator[Iterable[Span], None, None]], Generator[Iterable[Iterable[Candidate]], None, None]
[KnowledgeBase, Generator[Iterable[Span], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None],
]:
return get_candidates_all

View File

@ -79,7 +79,7 @@ def make_entity_linker(
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_all: Callable[
[KnowledgeBase, Generator[Iterable[Span], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
Generator[Iterable[Iterable[Candidate]], None, None],
],
overwrite: bool,
scorer: Optional[Callable],
@ -180,7 +180,7 @@ class EntityLinker(TrainablePipe):
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_all: Callable[
[KnowledgeBase, Generator[Iterable[Span], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
Generator[Iterable[Iterable[Candidate]], None, None],
],
overwrite: bool = BACKWARD_OVERWRITE,
scorer: Optional[Callable] = entity_linker_score,
@ -456,19 +456,28 @@ class EntityLinker(TrainablePipe):
for idx in range(len(doc.ents))
if doc.ents[idx].label_ not in self.labels_discard
]
for doc in docs if len(doc.ents)
for doc in docs
if len(doc.ents)
)
# Call candidate generator.
if self.candidates_doc_mode:
all_ent_cands = self.get_candidates_all(
self.kb,
([doc.ents[idx] for idx in next(valid_ent_idx_per_doc)] for doc in docs if len(doc.ents))
(
[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)]
for doc in docs
if len(doc.ents)
),
)
else:
# Alternative: collect entities the old-fashioned way - by retrieving entities individually.
all_ent_cands = (
[self.get_candidates(self.kb, doc.ents[idx]) for idx in next(valid_ent_idx_per_doc)]
for doc in docs if len(doc.ents)
[
self.get_candidates(self.kb, doc.ents[idx])
for idx in next(valid_ent_idx_per_doc)
]
for doc in docs
if len(doc.ents)
)
for doc_idx, doc in enumerate(docs):
@ -485,9 +494,7 @@ class EntityLinker(TrainablePipe):
if self.incl_context:
# get n_neighbour sentences, clipped to the length of the document
start_sentence = max(0, sent_index - self.n_sents)
end_sentence = min(
len(sentences) - 1, sent_index + self.n_sents
)
end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc()
@ -536,8 +543,7 @@ class EntityLinker(TrainablePipe):
scores = prior_probs + sims - (prior_probs * sims)
final_kb_ids.append(
candidates[scores.argmax().item()].entity_
if self.threshold is None
or scores.max() >= self.threshold
if self.threshold is None or scores.max() >= self.threshold
else EntityLinker.NIL
)

View File

@ -515,7 +515,7 @@ def test_el_pipe_configuration(nlp):
@registry.misc("spacy.LowercaseCandidateAllGenerator.v1")
def create_candidates_batch() -> Callable[
[InMemoryLookupKB, Generator[Iterable["Span"], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
Generator[Iterable[Iterable[Candidate]], None, None],
]:
return get_lowercased_candidates_all
@ -683,7 +683,6 @@ def test_preserving_links_asdoc(nlp):
assert s_ent.kb_id_ == orig_kb_id
def test_preserving_links_ents(nlp):
"""Test that doc.ents preserves KB annotations"""
text = "She lives in Boston. He lives in Denver."