diff --git a/spacy/training/corpus.py b/spacy/training/corpus.py index 37af9e476..c15c31a2a 100644 --- a/spacy/training/corpus.py +++ b/spacy/training/corpus.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional, import srsly from .. import util -from ..compat import Protocol from ..errors import Errors, Warnings from ..tokens import Doc, DocBin from ..vocab import Vocab @@ -20,10 +19,6 @@ if TYPE_CHECKING: FILE_TYPE = ".spacy" -class ReaderProtocol(Protocol): - def __call__(self, nlp: "Language") -> Iterable[Example]: - pass - @util.registry.readers("spacy.Corpus.v1") def create_docbin_reader( @@ -32,7 +27,7 @@ def create_docbin_reader( max_length: int = 0, limit: int = 0, augmenter: Optional[Callable] = None, -) -> ReaderProtocol: +) -> Callable[["Language"], Iterable[Example]]: if path is None: raise ValueError(Errors.E913) util.logger.debug("Loading corpus from path: %s", path) @@ -51,7 +46,7 @@ def create_jsonl_reader( min_length: int = 0, max_length: int = 0, limit: int = 0, -) -> ReaderProtocol: +) -> Callable[["Language"], Iterable[Example]]: return JsonlCorpus(path, min_length=min_length, max_length=max_length, limit=limit) @@ -69,7 +64,7 @@ def create_plain_text_reader( path: Optional[Path], min_length: int = 0, max_length: int = 0, -) -> ReaderProtocol: +) -> Callable[["Language"], Iterable[Example]]: """Iterate Example objects from a file or directory of plain text UTF-8 files with one line per doc. @@ -150,7 +145,7 @@ class Corpus: self.augmenter = augmenter if augmenter is not None else dont_augment self.shuffle = shuffle - def __call__(self, nlp: "Language") -> Iterable[Example]: + def __call__(self, nlp: "Language") -> Iterator[Example]: """Yield examples from the data. nlp (Language): The current nlp object. @@ -188,7 +183,7 @@ class Corpus: def make_examples( self, nlp: "Language", reference_docs: Iterable[Doc] - ) -> Iterable[Example]: + ) -> Iterator[Example]: for reference in reference_docs: if len(reference) == 0: continue @@ -203,7 +198,7 @@ class Corpus: def make_examples_gold_preproc( self, nlp: "Language", reference_docs: Iterable[Doc] - ) -> Iterable[Example]: + ) -> Iterator[Example]: for reference in reference_docs: if reference.has_annotation("SENT_START"): ref_sents = [sent.as_doc() for sent in reference.sents] @@ -263,7 +258,7 @@ class JsonlCorpus: self.max_length = max_length self.limit = limit - def __call__(self, nlp: "Language") -> Iterable[Example]: + def __call__(self, nlp: "Language") -> Iterator[Example]: """Yield examples from the data. nlp (Language): The current nlp object. @@ -313,7 +308,7 @@ class PlainTextCorpus: self.min_length = min_length self.max_length = max_length - def __call__(self, nlp: "Language") -> Iterable[Example]: + def __call__(self, nlp: "Language") -> Iterator[Example]: """Yield examples from the data. nlp (Language): The current nlp object.