mirror of
https://github.com/explosion/spaCy.git
synced 2025-05-28 09:43:17 +03:00
Reformat with black.
This commit is contained in:
parent
7c28424f47
commit
d8183121f6
|
@ -107,6 +107,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
|
|||
|
||||
@registry.misc("spacy.CandidateAllGenerator.v1")
|
||||
def create_candidates_all() -> Callable[
|
||||
[KnowledgeBase, Generator[Iterable[Span], None, None]], Generator[Iterable[Iterable[Candidate]], None, None]
|
||||
[KnowledgeBase, Generator[Iterable[Span], None, None]],
|
||||
Generator[Iterable[Iterable[Candidate]], None, None],
|
||||
]:
|
||||
return get_candidates_all
|
||||
|
|
|
@ -79,7 +79,7 @@ def make_entity_linker(
|
|||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||
get_candidates_all: Callable[
|
||||
[KnowledgeBase, Generator[Iterable[Span], None, None]],
|
||||
Generator[Iterable[Iterable[Candidate]], None, None]
|
||||
Generator[Iterable[Iterable[Candidate]], None, None],
|
||||
],
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
|
@ -180,7 +180,7 @@ class EntityLinker(TrainablePipe):
|
|||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||
get_candidates_all: Callable[
|
||||
[KnowledgeBase, Generator[Iterable[Span], None, None]],
|
||||
Generator[Iterable[Iterable[Candidate]], None, None]
|
||||
Generator[Iterable[Iterable[Candidate]], None, None],
|
||||
],
|
||||
overwrite: bool = BACKWARD_OVERWRITE,
|
||||
scorer: Optional[Callable] = entity_linker_score,
|
||||
|
@ -456,19 +456,28 @@ class EntityLinker(TrainablePipe):
|
|||
for idx in range(len(doc.ents))
|
||||
if doc.ents[idx].label_ not in self.labels_discard
|
||||
]
|
||||
for doc in docs if len(doc.ents)
|
||||
for doc in docs
|
||||
if len(doc.ents)
|
||||
)
|
||||
# Call candidate generator.
|
||||
if self.candidates_doc_mode:
|
||||
all_ent_cands = self.get_candidates_all(
|
||||
self.kb,
|
||||
([doc.ents[idx] for idx in next(valid_ent_idx_per_doc)] for doc in docs if len(doc.ents))
|
||||
(
|
||||
[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)]
|
||||
for doc in docs
|
||||
if len(doc.ents)
|
||||
),
|
||||
)
|
||||
else:
|
||||
# Alternative: collect entities the old-fashioned way - by retrieving entities individually.
|
||||
all_ent_cands = (
|
||||
[self.get_candidates(self.kb, doc.ents[idx]) for idx in next(valid_ent_idx_per_doc)]
|
||||
for doc in docs if len(doc.ents)
|
||||
[
|
||||
self.get_candidates(self.kb, doc.ents[idx])
|
||||
for idx in next(valid_ent_idx_per_doc)
|
||||
]
|
||||
for doc in docs
|
||||
if len(doc.ents)
|
||||
)
|
||||
|
||||
for doc_idx, doc in enumerate(docs):
|
||||
|
@ -485,9 +494,7 @@ class EntityLinker(TrainablePipe):
|
|||
if self.incl_context:
|
||||
# get n_neighbour sentences, clipped to the length of the document
|
||||
start_sentence = max(0, sent_index - self.n_sents)
|
||||
end_sentence = min(
|
||||
len(sentences) - 1, sent_index + self.n_sents
|
||||
)
|
||||
end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
|
||||
start_token = sentences[start_sentence].start
|
||||
end_token = sentences[end_sentence].end
|
||||
sent_doc = doc[start_token:end_token].as_doc()
|
||||
|
@ -536,8 +543,7 @@ class EntityLinker(TrainablePipe):
|
|||
scores = prior_probs + sims - (prior_probs * sims)
|
||||
final_kb_ids.append(
|
||||
candidates[scores.argmax().item()].entity_
|
||||
if self.threshold is None
|
||||
or scores.max() >= self.threshold
|
||||
if self.threshold is None or scores.max() >= self.threshold
|
||||
else EntityLinker.NIL
|
||||
)
|
||||
|
||||
|
|
|
@ -515,7 +515,7 @@ def test_el_pipe_configuration(nlp):
|
|||
@registry.misc("spacy.LowercaseCandidateAllGenerator.v1")
|
||||
def create_candidates_batch() -> Callable[
|
||||
[InMemoryLookupKB, Generator[Iterable["Span"], None, None]],
|
||||
Generator[Iterable[Iterable[Candidate]], None, None]
|
||||
Generator[Iterable[Iterable[Candidate]], None, None],
|
||||
]:
|
||||
return get_lowercased_candidates_all
|
||||
|
||||
|
@ -683,7 +683,6 @@ def test_preserving_links_asdoc(nlp):
|
|||
assert s_ent.kb_id_ == orig_kb_id
|
||||
|
||||
|
||||
|
||||
def test_preserving_links_ents(nlp):
|
||||
"""Test that doc.ents preserves KB annotations"""
|
||||
text = "She lives in Boston. He lives in Denver."
|
||||
|
|
Loading…
Reference in New Issue
Block a user