Fix type aliases.

This commit is contained in:
Raphael Mitsch 2024-02-01 14:51:49 +01:00
parent 4c7bd3026d
commit 7d6ae1b960
5 changed files with 9 additions and 14 deletions

View File

@ -2,7 +2,6 @@
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from libc.stdint cimport int64_t from libc.stdint cimport int64_t
from ..vocab cimport Vocab from ..vocab cimport Vocab

View File

@ -8,7 +8,7 @@ from cymem.cymem cimport Pool
from ..errors import Errors from ..errors import Errors
from ..tokens import SpanGroup from ..tokens import SpanGroup
from ..util import SimpleFrozenList from ..util import SimpleFrozenList
from .typedefs cimport CandidatesForMention from .candidate cimport Candidate
cdef class KnowledgeBase: cdef class KnowledgeBase:
@ -19,6 +19,8 @@ cdef class KnowledgeBase:
DOCS: https://spacy.io/api/kb DOCS: https://spacy.io/api/kb
""" """
CandidatesForMentionT = Iterable[Candidate]
CandidatesForDocT = Iterable[CandidatesForMentionT]
def __init__(self, vocab: Vocab, entity_vector_length: int): def __init__(self, vocab: Vocab, entity_vector_length: int):
"""Create a KnowledgeBase.""" """Create a KnowledgeBase."""
@ -32,14 +34,14 @@ cdef class KnowledgeBase:
self.entity_vector_length = entity_vector_length self.entity_vector_length = entity_vector_length
self.mem = Pool() self.mem = Pool()
def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[CandidatesForMention]]: def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[CandidatesForDocT]:
""" """
Return candidate entities for a specified Span mention. Each candidate defines at least the entity and the Return candidate entities for a specified Span mention. Each candidate defines at least the entity and the
entity's embedding vector. Depending on the KB implementation, further properties - such as the prior entity's embedding vector. Depending on the KB implementation, further properties - such as the prior
probability of the specified mention text resolving to that entity - might be included. probability of the specified mention text resolving to that entity - might be included.
If no candidates are found for a given mention, an empty list is returned. If no candidates are found for a given mention, an empty list is returned.
mentions (Iterator[SpanGroup]): Mentions for which to get candidates. mentions (Iterator[SpanGroup]): Mentions for which to get candidates.
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates. RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per mention/doc/doc batch.
""" """
raise NotImplementedError( raise NotImplementedError(
Errors.E1045.format( Errors.E1045.format(

View File

@ -1,5 +0,0 @@
from typing import Iterable
from .candidate import Candidate
ctypedef Iterable[Candidate] CandidatesForMention

View File

@ -1 +0,0 @@
# cython: profile=False

View File

@ -1,6 +1,6 @@
import random import random
import warnings import warnings
from itertools import islice from itertools import islice, tee
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any, Any,
@ -446,7 +446,7 @@ class EntityLinker(TrainablePipe):
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
docs = list(docs) docs_iters = tee(docs, 2)
# Call candidate generator. # Call candidate generator.
all_ent_cands = self.get_candidates( all_ent_cands = self.get_candidates(
@ -458,11 +458,11 @@ class EntityLinker(TrainablePipe):
ent for ent in doc.ents if ent.label_ not in self.labels_discard ent for ent in doc.ents if ent.label_ not in self.labels_discard
], ],
) )
for doc in docs for doc in docs_iters[0]
), ),
) )
for doc in docs: for doc in docs_iters[1]:
doc_ents: List[Ints1d] = [] doc_ents: List[Ints1d] = []
doc_scores: List[Floats1d] = [] doc_scores: List[Floats1d] = []
if len(doc) == 0 or len(doc.ents) == 0: if len(doc) == 0 or len(doc.ents) == 0: