From d0abc321d8922951fb42bbbe8a6aa9d9927287a1 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 6 Mar 2023 10:27:33 +0100 Subject: [PATCH] Format. --- spacy/pipeline/entity_linker.py | 24 ++++++++++++++++++---- spacy/tests/pipeline/test_entity_linker.py | 21 ++++++++++++++----- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 1cfec87ed..cf3d7839e 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -1,4 +1,14 @@ -from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any, Iterator +from typing import ( + Optional, + Iterable, + Callable, + Dict, + Sequence, + Union, + List, + Any, + Iterator, +) from typing import cast from numpy import dtype from thinc.types import Floats1d, Floats2d, Ints1d, Ragged @@ -79,7 +89,9 @@ def make_entity_linker( incl_prior: bool, incl_context: bool, entity_vector_length: int, - get_candidates: Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]], + get_candidates: Callable[ + [KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]] + ], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool, scorer: Optional[Callable], @@ -177,7 +189,10 @@ class EntityLinker(TrainablePipe): incl_prior: bool, incl_context: bool, entity_vector_length: int, - get_candidates: Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]], + get_candidates: Callable[ + [KnowledgeBase, Iterator[SpanGroup]], + Iterator[Iterable[Iterable[Candidate]]], + ], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool = False, scorer: Optional[Callable] = entity_linker_score, @@ -313,7 +328,8 @@ class EntityLinker(TrainablePipe): If one isn't present, then the update step needs to be skipped. """ for candidates_for_doc in self.get_candidates( - self.kb, (SpanGroup(doc=eg.predicted, spans=eg.predicted.ents) for eg in examples) + self.kb, + (SpanGroup(doc=eg.predicted, spans=eg.predicted.ents) for eg in examples), ): for candidates_for_mention in candidates_for_doc: if list(candidates_for_mention): diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 90e10ef6c..e84dc7382 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -467,9 +467,16 @@ def test_candidate_generation(nlp): # test the size of the relevant candidates adam_ent_cands = next(get_candidates(mykb, SpanGroup(doc=doc, spans=[adam_ent])))[0] assert len(adam_ent_cands) == 1 - assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) == 2 - assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0 # default case sensitive - assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0]) == 0 + assert ( + len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) == 2 + ) + assert ( + len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0 + ) # default case sensitive + assert ( + len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0]) + == 0 + ) # test the content of the candidates assert adam_ent_cands[0].entity_ == "Q2" @@ -504,11 +511,15 @@ def test_el_pipe_configuration(nlp): def get_lowercased_candidates(kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]): for mentions_for_doc in mentions: - yield [kb.get_alias_candidates(ent_span.text.lower()) for ent_span in mentions_for_doc] + yield [ + kb.get_alias_candidates(ent_span.text.lower()) + for ent_span in mentions_for_doc + ] @registry.misc("spacy.LowercaseCandidateGenerator.v1") def create_candidates() -> Callable[ - [InMemoryLookupKB, Iterator[SpanGroup]], Iterator[Iterable[Iterable[InMemoryCandidate]]] + [InMemoryLookupKB, Iterator[SpanGroup]], + Iterator[Iterable[Iterable[InMemoryCandidate]]], ]: return get_lowercased_candidates