mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 08:12:24 +03:00
Format.
This commit is contained in:
parent
8b24f31b65
commit
d0abc321d8
|
@ -1,4 +1,14 @@
|
||||||
from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any, Iterator
|
from typing import (
|
||||||
|
Optional,
|
||||||
|
Iterable,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
List,
|
||||||
|
Any,
|
||||||
|
Iterator,
|
||||||
|
)
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from numpy import dtype
|
from numpy import dtype
|
||||||
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
|
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
|
||||||
|
@ -79,7 +89,9 @@ def make_entity_linker(
|
||||||
incl_prior: bool,
|
incl_prior: bool,
|
||||||
incl_context: bool,
|
incl_context: bool,
|
||||||
entity_vector_length: int,
|
entity_vector_length: int,
|
||||||
get_candidates: Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]],
|
get_candidates: Callable[
|
||||||
|
[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]
|
||||||
|
],
|
||||||
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||||
overwrite: bool,
|
overwrite: bool,
|
||||||
scorer: Optional[Callable],
|
scorer: Optional[Callable],
|
||||||
|
@ -177,7 +189,10 @@ class EntityLinker(TrainablePipe):
|
||||||
incl_prior: bool,
|
incl_prior: bool,
|
||||||
incl_context: bool,
|
incl_context: bool,
|
||||||
entity_vector_length: int,
|
entity_vector_length: int,
|
||||||
get_candidates: Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]],
|
get_candidates: Callable[
|
||||||
|
[KnowledgeBase, Iterator[SpanGroup]],
|
||||||
|
Iterator[Iterable[Iterable[Candidate]]],
|
||||||
|
],
|
||||||
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
scorer: Optional[Callable] = entity_linker_score,
|
scorer: Optional[Callable] = entity_linker_score,
|
||||||
|
@ -313,7 +328,8 @@ class EntityLinker(TrainablePipe):
|
||||||
If one isn't present, then the update step needs to be skipped.
|
If one isn't present, then the update step needs to be skipped.
|
||||||
"""
|
"""
|
||||||
for candidates_for_doc in self.get_candidates(
|
for candidates_for_doc in self.get_candidates(
|
||||||
self.kb, (SpanGroup(doc=eg.predicted, spans=eg.predicted.ents) for eg in examples)
|
self.kb,
|
||||||
|
(SpanGroup(doc=eg.predicted, spans=eg.predicted.ents) for eg in examples),
|
||||||
):
|
):
|
||||||
for candidates_for_mention in candidates_for_doc:
|
for candidates_for_mention in candidates_for_doc:
|
||||||
if list(candidates_for_mention):
|
if list(candidates_for_mention):
|
||||||
|
|
|
@ -467,9 +467,16 @@ def test_candidate_generation(nlp):
|
||||||
# test the size of the relevant candidates
|
# test the size of the relevant candidates
|
||||||
adam_ent_cands = next(get_candidates(mykb, SpanGroup(doc=doc, spans=[adam_ent])))[0]
|
adam_ent_cands = next(get_candidates(mykb, SpanGroup(doc=doc, spans=[adam_ent])))[0]
|
||||||
assert len(adam_ent_cands) == 1
|
assert len(adam_ent_cands) == 1
|
||||||
assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) == 2
|
assert (
|
||||||
assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0 # default case sensitive
|
len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) == 2
|
||||||
assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0]) == 0
|
)
|
||||||
|
assert (
|
||||||
|
len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0
|
||||||
|
) # default case sensitive
|
||||||
|
assert (
|
||||||
|
len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0])
|
||||||
|
== 0
|
||||||
|
)
|
||||||
|
|
||||||
# test the content of the candidates
|
# test the content of the candidates
|
||||||
assert adam_ent_cands[0].entity_ == "Q2"
|
assert adam_ent_cands[0].entity_ == "Q2"
|
||||||
|
@ -504,11 +511,15 @@ def test_el_pipe_configuration(nlp):
|
||||||
|
|
||||||
def get_lowercased_candidates(kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]):
|
def get_lowercased_candidates(kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]):
|
||||||
for mentions_for_doc in mentions:
|
for mentions_for_doc in mentions:
|
||||||
yield [kb.get_alias_candidates(ent_span.text.lower()) for ent_span in mentions_for_doc]
|
yield [
|
||||||
|
kb.get_alias_candidates(ent_span.text.lower())
|
||||||
|
for ent_span in mentions_for_doc
|
||||||
|
]
|
||||||
|
|
||||||
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
|
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
|
||||||
def create_candidates() -> Callable[
|
def create_candidates() -> Callable[
|
||||||
[InMemoryLookupKB, Iterator[SpanGroup]], Iterator[Iterable[Iterable[InMemoryCandidate]]]
|
[InMemoryLookupKB, Iterator[SpanGroup]],
|
||||||
|
Iterator[Iterable[Iterable[InMemoryCandidate]]],
|
||||||
]:
|
]:
|
||||||
return get_lowercased_candidates
|
return get_lowercased_candidates
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user