revert changes to corpus.py

This commit is contained in:
svlandeg 2023-07-07 10:18:05 +02:00
parent 350b8dd644
commit 75514c4bc4

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional,
import srsly
from .. import util
-from ..compat import Protocol
from ..errors import Errors, Warnings
from ..tokens import Doc, DocBin
from ..vocab import Vocab
@@ -20,10 +19,6 @@ if TYPE_CHECKING:
FILE_TYPE = ".spacy"
-class ReaderProtocol(Protocol):
-def __call__(self, nlp: "Language") -> Iterable[Example]:
-pass
@util.registry.readers("spacy.Corpus.v1")
def create_docbin_reader(
@@ -32,7 +27,7 @@ def create_docbin_reader(
max_length: int = 0,
limit: int = 0,
augmenter: Optional[Callable] = None,
-) -> ReaderProtocol:
+) -> Callable[["Language"], Iterable[Example]]:
if path is None:
raise ValueError(Errors.E913)
util.logger.debug("Loading corpus from path: %s", path)
@@ -51,7 +46,7 @@ def create_jsonl_reader(
min_length: int = 0,
max_length: int = 0,
limit: int = 0,
-) -> ReaderProtocol:
+) -> Callable[["Language"], Iterable[Example]]:
return JsonlCorpus(path, min_length=min_length, max_length=max_length, limit=limit)
@@ -69,7 +64,7 @@ def create_plain_text_reader(
path: Optional[Path],
min_length: int = 0,
max_length: int = 0,
-) -> ReaderProtocol:
+) -> Callable[["Language"], Iterable[Example]]:
"""Iterate Example objects from a file or directory of plain text
UTF-8 files with one line per doc.
@@ -150,7 +145,7 @@ class Corpus:
self.augmenter = augmenter if augmenter is not None else dont_augment
self.shuffle = shuffle
-def __call__(self, nlp: "Language") -> Iterable[Example]:
+def __call__(self, nlp: "Language") -> Iterator[Example]:
"""Yield examples from the data.
nlp (Language): The current nlp object.
@@ -188,7 +183,7 @@ class Corpus:
def make_examples(
self, nlp: "Language", reference_docs: Iterable[Doc]
-) -> Iterable[Example]:
+) -> Iterator[Example]:
for reference in reference_docs:
if len(reference) == 0:
continue
@@ -203,7 +198,7 @@ class Corpus:
def make_examples_gold_preproc(
self, nlp: "Language", reference_docs: Iterable[Doc]
-) -> Iterable[Example]:
+) -> Iterator[Example]:
for reference in reference_docs:
if reference.has_annotation("SENT_START"):
ref_sents = [sent.as_doc() for sent in reference.sents]
@@ -263,7 +258,7 @@ class JsonlCorpus:
self.max_length = max_length
self.limit = limit
-def __call__(self, nlp: "Language") -> Iterable[Example]:
+def __call__(self, nlp: "Language") -> Iterator[Example]:
"""Yield examples from the data.
nlp (Language): The current nlp object.
@@ -313,7 +308,7 @@ class PlainTextCorpus:
self.min_length = min_length
self.max_length = max_length
-def __call__(self, nlp: "Language") -> Iterable[Example]:
+def __call__(self, nlp: "Language") -> Iterator[Example]:
"""Yield examples from the data.
nlp (Language): The current nlp object.