Fix type aliases.

This commit is contained in:
Raphael Mitsch 2024-02-01 14:51:49 +01:00
parent 4c7bd3026d
commit 7d6ae1b960
5 changed files with 9 additions and 14 deletions

View File

@ -2,7 +2,6 @@
from cymem.cymem cimport Pool
from libc.stdint cimport int64_t
from ..vocab cimport Vocab

View File

@ -8,7 +8,7 @@ from cymem.cymem cimport Pool
from ..errors import Errors
from ..tokens import SpanGroup
from ..util import SimpleFrozenList
from .typedefs cimport CandidatesForMention
from .candidate cimport Candidate
cdef class KnowledgeBase:
@ -19,6 +19,8 @@ cdef class KnowledgeBase:
DOCS: https://spacy.io/api/kb
"""
CandidatesForMentionT = Iterable[Candidate]
CandidatesForDocT = Iterable[CandidatesForMentionT]
def __init__(self, vocab: Vocab, entity_vector_length: int):
"""Create a KnowledgeBase."""
@ -32,14 +34,14 @@ cdef class KnowledgeBase:
self.entity_vector_length = entity_vector_length
self.mem = Pool()
def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[CandidatesForMention]]:
def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[CandidatesForDocT]:
"""
Return candidate entities for a specified Span mention. Each candidate defines at least the entity and the
entity's embedding vector. Depending on the KB implementation, further properties - such as the prior
probability of the specified mention text resolving to that entity - might be included.
If no candidates are found for a given mention, an empty list is returned.
mentions (Iterator[SpanGroup]): Mentions for which to get candidates.
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates.
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per mention/doc/doc batch.
"""
raise NotImplementedError(
Errors.E1045.format(

View File

@ -1,5 +0,0 @@
from typing import Iterable
from .candidate import Candidate
ctypedef Iterable[Candidate] CandidatesForMention

View File

@ -1 +0,0 @@
# cython: profile=False

View File

@ -1,6 +1,6 @@
import random
import warnings
from itertools import islice
from itertools import islice, tee
from pathlib import Path
from typing import (
Any,
@ -446,7 +446,7 @@ class EntityLinker(TrainablePipe):
if isinstance(docs, Doc):
docs = [docs]
docs = list(docs)
docs_iters = tee(docs, 2)
# Call candidate generator.
all_ent_cands = self.get_candidates(
@ -458,11 +458,11 @@ class EntityLinker(TrainablePipe):
ent for ent in doc.ents if ent.label_ not in self.labels_discard
],
)
for doc in docs
for doc in docs_iters[0]
),
)
for doc in docs:
for doc in docs_iters[1]:
doc_ents: List[Ints1d] = []
doc_scores: List[Floats1d] = []
if len(doc) == 0 or len(doc.ents) == 0: