diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index 9aac71d40..455bcc3b1 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -107,6 +107,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: @registry.misc("spacy.CandidateAllGenerator.v1") def create_candidates_all() -> Callable[ - [KnowledgeBase, Generator[Iterable[Span], None, None]], Generator[Iterable[Iterable[Candidate]], None, None] + [KnowledgeBase, Generator[Iterable[Span], None, None]], + Generator[Iterable[Iterable[Candidate]], None, None], ]: return get_candidates_all diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 4d3baf2f3..eb546b3a0 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -79,7 +79,7 @@ def make_entity_linker( get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates_all: Callable[ [KnowledgeBase, Generator[Iterable[Span], None, None]], - Generator[Iterable[Iterable[Candidate]], None, None] + Generator[Iterable[Iterable[Candidate]], None, None], ], overwrite: bool, scorer: Optional[Callable], @@ -180,7 +180,7 @@ class EntityLinker(TrainablePipe): get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates_all: Callable[ [KnowledgeBase, Generator[Iterable[Span], None, None]], - Generator[Iterable[Iterable[Candidate]], None, None] + Generator[Iterable[Iterable[Candidate]], None, None], ], overwrite: bool = BACKWARD_OVERWRITE, scorer: Optional[Callable] = entity_linker_score, @@ -456,19 +456,28 @@ class EntityLinker(TrainablePipe): for idx in range(len(doc.ents)) if doc.ents[idx].label_ not in self.labels_discard ] - for doc in docs if len(doc.ents) + for doc in docs + if len(doc.ents) ) # Call candidate generator. if self.candidates_doc_mode: all_ent_cands = self.get_candidates_all( self.kb, - ([doc.ents[idx] for idx in next(valid_ent_idx_per_doc)] for doc in docs if len(doc.ents)) + ( + [doc.ents[idx] for idx in next(valid_ent_idx_per_doc)] + for doc in docs + if len(doc.ents) + ), ) else: # Alternative: collect entities the old-fashioned way - by retrieving entities individually. all_ent_cands = ( - [self.get_candidates(self.kb, doc.ents[idx]) for idx in next(valid_ent_idx_per_doc)] - for doc in docs if len(doc.ents) + [ + self.get_candidates(self.kb, doc.ents[idx]) + for idx in next(valid_ent_idx_per_doc) + ] + for doc in docs + if len(doc.ents) ) for doc_idx, doc in enumerate(docs): @@ -485,9 +494,7 @@ class EntityLinker(TrainablePipe): if self.incl_context: # get n_neighbour sentences, clipped to the length of the document start_sentence = max(0, sent_index - self.n_sents) - end_sentence = min( - len(sentences) - 1, sent_index + self.n_sents - ) + end_sentence = min(len(sentences) - 1, sent_index + self.n_sents) start_token = sentences[start_sentence].start end_token = sentences[end_sentence].end sent_doc = doc[start_token:end_token].as_doc() @@ -536,8 +543,7 @@ class EntityLinker(TrainablePipe): scores = prior_probs + sims - (prior_probs * sims) final_kb_ids.append( candidates[scores.argmax().item()].entity_ - if self.threshold is None - or scores.max() >= self.threshold + if self.threshold is None or scores.max() >= self.threshold else EntityLinker.NIL ) diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index a579b0fac..877a4c5ce 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -515,7 +515,7 @@ def test_el_pipe_configuration(nlp): @registry.misc("spacy.LowercaseCandidateAllGenerator.v1") def create_candidates_batch() -> Callable[ [InMemoryLookupKB, Generator[Iterable["Span"], None, None]], - Generator[Iterable[Iterable[Candidate]], None, None] + Generator[Iterable[Iterable[Candidate]], None, None], ]: return get_lowercased_candidates_all @@ -683,7 +683,6 @@ def test_preserving_links_asdoc(nlp): assert s_ent.kb_id_ == orig_kb_id - def test_preserving_links_ents(nlp): """Test that doc.ents preserves KB annotations""" text = "She lives in Boston. He lives in Denver."