diff --git a/spacy/tokens/doc.pyi b/spacy/tokens/doc.pyi index 9d45960ab..353983fd0 100644 --- a/spacy/tokens/doc.pyi +++ b/spacy/tokens/doc.pyi @@ -1,16 +1,30 @@ -from typing import Callable, Protocol, Iterable, Iterator, Optional -from typing import Union, Tuple, List, Dict, Any, overload +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Protocol, + Sequence, + Tuple, + Union, + overload, +) + +import numpy as np from cymem.cymem import Pool from thinc.types import Floats1d, Floats2d, Ints2d -from .span import Span -from .token import Token -from ._dict_proxies import SpanGroups -from ._retokenize import Retokenizer + from ..lexeme import Lexeme from ..vocab import Vocab +from ._dict_proxies import SpanGroups +from ._retokenize import Retokenizer +from .span import Span +from .token import Token from .underscore import Underscore -from pathlib import Path -import numpy as np class DocMethod(Protocol): def __call__(self: Doc, *args: Any, **kwargs: Any) -> Any: ... # type: ignore[misc] @@ -119,7 +133,12 @@ class Doc: def text(self) -> str: ... @property def text_with_ws(self) -> str: ... - ents: Tuple[Span] + # Ideally the getter would output Tuple[Span] + # see https://github.com/python/mypy/issues/3004 + @property + def ents(self) -> Sequence[Span]: ... + @ents.setter + def ents(self, value: Sequence[Span]) -> None: ... def set_ents( self, entities: List[Span], diff --git a/spacy/training/corpus.py b/spacy/training/corpus.py index 086ad831c..f05d09bcb 100644 --- a/spacy/training/corpus.py +++ b/spacy/training/corpus.py @@ -1,16 +1,16 @@ -import warnings -from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable -from typing import Optional -from pathlib import Path import random +import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Protocol, Union + import srsly from .. import util +from ..errors import Errors, Warnings +from ..tokens import Doc, DocBin +from ..vocab import Vocab from .augment import dont_augment from .example import Example -from ..errors import Warnings, Errors -from ..tokens import DocBin, Doc -from ..vocab import Vocab if TYPE_CHECKING: # This lets us add type hints for mypy etc. without causing circular imports @@ -19,6 +19,11 @@ if TYPE_CHECKING: FILE_TYPE = ".spacy" +class ReaderProtocol(Protocol): + def __call__(self, nlp: "Language") -> Iterable[Example]: + pass + + @util.registry.readers("spacy.Corpus.v1") def create_docbin_reader( path: Optional[Path], @@ -26,7 +31,7 @@ def create_docbin_reader( max_length: int = 0, limit: int = 0, augmenter: Optional[Callable] = None, -) -> Callable[["Language"], Iterable[Example]]: +) -> ReaderProtocol: if path is None: raise ValueError(Errors.E913) util.logger.debug("Loading corpus from path: %s", path) @@ -45,7 +50,7 @@ def create_jsonl_reader( min_length: int = 0, max_length: int = 0, limit: int = 0, -) -> Callable[["Language"], Iterable[Example]]: +) -> ReaderProtocol: return JsonlCorpus(path, min_length=min_length, max_length=max_length, limit=limit) @@ -63,7 +68,7 @@ def create_plain_text_reader( path: Optional[Path], min_length: int = 0, max_length: int = 0, -) -> Callable[["Language"], Iterable[Doc]]: +) -> ReaderProtocol: """Iterate Example objects from a file or directory of plain text UTF-8 files with one line per doc. @@ -144,7 +149,7 @@ class Corpus: self.augmenter = augmenter if augmenter is not None else dont_augment self.shuffle = shuffle - def __call__(self, nlp: "Language") -> Iterator[Example]: + def __call__(self, nlp: "Language") -> Iterable[Example]: """Yield examples from the data. nlp (Language): The current nlp object. @@ -182,7 +187,7 @@ class Corpus: def make_examples( self, nlp: "Language", reference_docs: Iterable[Doc] - ) -> Iterator[Example]: + ) -> Iterable[Example]: for reference in reference_docs: if len(reference) == 0: continue @@ -197,7 +202,7 @@ class Corpus: def make_examples_gold_preproc( self, nlp: "Language", reference_docs: Iterable[Doc] - ) -> Iterator[Example]: + ) -> Iterable[Example]: for reference in reference_docs: if reference.has_annotation("SENT_START"): ref_sents = [sent.as_doc() for sent in reference.sents] @@ -210,7 +215,7 @@ class Corpus: def read_docbin( self, vocab: Vocab, locs: Iterable[Union[str, Path]] - ) -> Iterator[Doc]: + ) -> Iterable[Doc]: """Yield training examples as example dicts""" i = 0 for loc in locs: @@ -257,7 +262,7 @@ class JsonlCorpus: self.max_length = max_length self.limit = limit - def __call__(self, nlp: "Language") -> Iterator[Example]: + def __call__(self, nlp: "Language") -> Iterable[Example]: """Yield examples from the data. nlp (Language): The current nlp object. @@ -307,7 +312,7 @@ class PlainTextCorpus: self.min_length = min_length self.max_length = max_length - def __call__(self, nlp: "Language") -> Iterator[Example]: + def __call__(self, nlp: "Language") -> Iterable[Example]: """Yield examples from the data. nlp (Language): The current nlp object.