From 7d6ae1b960e7a3a09739ed359b55c344ce2fe0c6 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Thu, 1 Feb 2024 14:51:49 +0100 Subject: [PATCH] Fix type aliases. --- spacy/kb/kb.pxd | 1 - spacy/kb/kb.pyx | 8 +++++--- spacy/kb/typedefs.pxd | 5 ----- spacy/kb/typedefs.pyx | 1 - spacy/pipeline/entity_linker.py | 8 ++++---- 5 files changed, 9 insertions(+), 14 deletions(-) delete mode 100644 spacy/kb/typedefs.pxd delete mode 100644 spacy/kb/typedefs.pyx diff --git a/spacy/kb/kb.pxd b/spacy/kb/kb.pxd index 263469546..c7652bca8 100644 --- a/spacy/kb/kb.pxd +++ b/spacy/kb/kb.pxd @@ -2,7 +2,6 @@ from cymem.cymem cimport Pool from libc.stdint cimport int64_t - from ..vocab cimport Vocab diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index ac78d3f0f..0e2ba7c65 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -8,7 +8,7 @@ from cymem.cymem cimport Pool from ..errors import Errors from ..tokens import SpanGroup from ..util import SimpleFrozenList -from .typedefs cimport CandidatesForMention +from .candidate cimport Candidate cdef class KnowledgeBase: @@ -19,6 +19,8 @@ cdef class KnowledgeBase: DOCS: https://spacy.io/api/kb """ + CandidatesForMentionT = Iterable[Candidate] + CandidatesForDocT = Iterable[CandidatesForMentionT] def __init__(self, vocab: Vocab, entity_vector_length: int): """Create a KnowledgeBase.""" @@ -32,14 +34,14 @@ cdef class KnowledgeBase: self.entity_vector_length = entity_vector_length self.mem = Pool() - def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[CandidatesForMention]]: + def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[CandidatesForDocT]: """ Return candidate entities for a specified Span mention. Each candidate defines at least the entity and the entity's embedding vector. Depending on the KB implementation, further properties - such as the prior probability of the specified mention text resolving to that entity - might be included. If no candidates are found for a given mention, an empty list is returned. mentions (Iterator[SpanGroup]): Mentions for which to get candidates. - RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates. + RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per mention/doc/doc batch. """ raise NotImplementedError( Errors.E1045.format( diff --git a/spacy/kb/typedefs.pxd b/spacy/kb/typedefs.pxd deleted file mode 100644 index 4588b422e..000000000 --- a/spacy/kb/typedefs.pxd +++ /dev/null @@ -1,5 +0,0 @@ -from typing import Iterable - -from .candidate import Candidate - -ctypedef Iterable[Candidate] CandidatesForMention diff --git a/spacy/kb/typedefs.pyx b/spacy/kb/typedefs.pyx deleted file mode 100644 index 61bf62038..000000000 --- a/spacy/kb/typedefs.pyx +++ /dev/null @@ -1 +0,0 @@ -# cython: profile=False diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index c8f088820..07534c523 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -1,6 +1,6 @@ import random import warnings -from itertools import islice +from itertools import islice, tee from pathlib import Path from typing import ( Any, @@ -446,7 +446,7 @@ class EntityLinker(TrainablePipe): if isinstance(docs, Doc): docs = [docs] - docs = list(docs) + docs_iters = tee(docs, 2) # Call candidate generator. all_ent_cands = self.get_candidates( @@ -458,11 +458,11 @@ class EntityLinker(TrainablePipe): ent for ent in doc.ents if ent.label_ not in self.labels_discard ], ) - for doc in docs + for doc in docs_iters[0] ), ) - for doc in docs: + for doc in docs_iters[1]: doc_ents: List[Ints1d] = [] doc_scores: List[Floats1d] = [] if len(doc) == 0 or len(doc.ents) == 0: