fix: mypy issues

This commit is contained in:
Basile Dura 2023-05-31 10:57:11 +02:00
parent 967ce504fd
commit 9cd17d7962
No known key found for this signature in database
2 changed files with 49 additions and 25 deletions

View File

@ -1,16 +1,30 @@
from typing import Callable, Protocol, Iterable, Iterator, Optional from pathlib import Path
from typing import Union, Tuple, List, Dict, Any, overload from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Protocol,
Sequence,
Tuple,
Union,
overload,
)
import numpy as np
from cymem.cymem import Pool from cymem.cymem import Pool
from thinc.types import Floats1d, Floats2d, Ints2d from thinc.types import Floats1d, Floats2d, Ints2d
from .span import Span
from .token import Token
from ._dict_proxies import SpanGroups
from ._retokenize import Retokenizer
from ..lexeme import Lexeme from ..lexeme import Lexeme
from ..vocab import Vocab from ..vocab import Vocab
from ._dict_proxies import SpanGroups
from ._retokenize import Retokenizer
from .span import Span
from .token import Token
from .underscore import Underscore from .underscore import Underscore
from pathlib import Path
import numpy as np
class DocMethod(Protocol): class DocMethod(Protocol):
def __call__(self: Doc, *args: Any, **kwargs: Any) -> Any: ... # type: ignore[misc] def __call__(self: Doc, *args: Any, **kwargs: Any) -> Any: ... # type: ignore[misc]
@ -119,7 +133,12 @@ class Doc:
def text(self) -> str: ... def text(self) -> str: ...
@property @property
def text_with_ws(self) -> str: ... def text_with_ws(self) -> str: ...
ents: Tuple[Span] # Ideally the getter would output Tuple[Span]
# see https://github.com/python/mypy/issues/3004
@property
def ents(self) -> Sequence[Span]: ...
@ents.setter
def ents(self, value: Sequence[Span]) -> None: ...
def set_ents( def set_ents(
self, self,
entities: List[Span], entities: List[Span],

View File

@ -1,16 +1,16 @@
import warnings
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
from typing import Optional
from pathlib import Path
import random import random
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Protocol, Union
import srsly import srsly
from .. import util from .. import util
from ..errors import Errors, Warnings
from ..tokens import Doc, DocBin
from ..vocab import Vocab
from .augment import dont_augment from .augment import dont_augment
from .example import Example from .example import Example
from ..errors import Warnings, Errors
from ..tokens import DocBin, Doc
from ..vocab import Vocab
if TYPE_CHECKING: if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports # This lets us add type hints for mypy etc. without causing circular imports
@ -19,6 +19,11 @@ if TYPE_CHECKING:
FILE_TYPE = ".spacy" FILE_TYPE = ".spacy"
class ReaderProtocol(Protocol):
def __call__(self, nlp: "Language") -> Iterable[Example]:
pass
@util.registry.readers("spacy.Corpus.v1") @util.registry.readers("spacy.Corpus.v1")
def create_docbin_reader( def create_docbin_reader(
path: Optional[Path], path: Optional[Path],
@ -26,7 +31,7 @@ def create_docbin_reader(
max_length: int = 0, max_length: int = 0,
limit: int = 0, limit: int = 0,
augmenter: Optional[Callable] = None, augmenter: Optional[Callable] = None,
) -> Callable[["Language"], Iterable[Example]]: ) -> ReaderProtocol:
if path is None: if path is None:
raise ValueError(Errors.E913) raise ValueError(Errors.E913)
util.logger.debug("Loading corpus from path: %s", path) util.logger.debug("Loading corpus from path: %s", path)
@ -45,7 +50,7 @@ def create_jsonl_reader(
min_length: int = 0, min_length: int = 0,
max_length: int = 0, max_length: int = 0,
limit: int = 0, limit: int = 0,
) -> Callable[["Language"], Iterable[Example]]: ) -> ReaderProtocol:
return JsonlCorpus(path, min_length=min_length, max_length=max_length, limit=limit) return JsonlCorpus(path, min_length=min_length, max_length=max_length, limit=limit)
@ -63,7 +68,7 @@ def create_plain_text_reader(
path: Optional[Path], path: Optional[Path],
min_length: int = 0, min_length: int = 0,
max_length: int = 0, max_length: int = 0,
) -> Callable[["Language"], Iterable[Doc]]: ) -> ReaderProtocol:
"""Iterate Example objects from a file or directory of plain text """Iterate Example objects from a file or directory of plain text
UTF-8 files with one line per doc. UTF-8 files with one line per doc.
@ -144,7 +149,7 @@ class Corpus:
self.augmenter = augmenter if augmenter is not None else dont_augment self.augmenter = augmenter if augmenter is not None else dont_augment
self.shuffle = shuffle self.shuffle = shuffle
def __call__(self, nlp: "Language") -> Iterator[Example]: def __call__(self, nlp: "Language") -> Iterable[Example]:
"""Yield examples from the data. """Yield examples from the data.
nlp (Language): The current nlp object. nlp (Language): The current nlp object.
@ -182,7 +187,7 @@ class Corpus:
def make_examples( def make_examples(
self, nlp: "Language", reference_docs: Iterable[Doc] self, nlp: "Language", reference_docs: Iterable[Doc]
) -> Iterator[Example]: ) -> Iterable[Example]:
for reference in reference_docs: for reference in reference_docs:
if len(reference) == 0: if len(reference) == 0:
continue continue
@ -197,7 +202,7 @@ class Corpus:
def make_examples_gold_preproc( def make_examples_gold_preproc(
self, nlp: "Language", reference_docs: Iterable[Doc] self, nlp: "Language", reference_docs: Iterable[Doc]
) -> Iterator[Example]: ) -> Iterable[Example]:
for reference in reference_docs: for reference in reference_docs:
if reference.has_annotation("SENT_START"): if reference.has_annotation("SENT_START"):
ref_sents = [sent.as_doc() for sent in reference.sents] ref_sents = [sent.as_doc() for sent in reference.sents]
@ -210,7 +215,7 @@ class Corpus:
def read_docbin( def read_docbin(
self, vocab: Vocab, locs: Iterable[Union[str, Path]] self, vocab: Vocab, locs: Iterable[Union[str, Path]]
) -> Iterator[Doc]: ) -> Iterable[Doc]:
"""Yield training examples as example dicts""" """Yield training examples as example dicts"""
i = 0 i = 0
for loc in locs: for loc in locs:
@ -257,7 +262,7 @@ class JsonlCorpus:
self.max_length = max_length self.max_length = max_length
self.limit = limit self.limit = limit
def __call__(self, nlp: "Language") -> Iterator[Example]: def __call__(self, nlp: "Language") -> Iterable[Example]:
"""Yield examples from the data. """Yield examples from the data.
nlp (Language): The current nlp object. nlp (Language): The current nlp object.
@ -307,7 +312,7 @@ class PlainTextCorpus:
self.min_length = min_length self.min_length = min_length
self.max_length = max_length self.max_length = max_length
def __call__(self, nlp: "Language") -> Iterator[Example]: def __call__(self, nlp: "Language") -> Iterable[Example]:
"""Yield examples from the data. """Yield examples from the data.
nlp (Language): The current nlp object. nlp (Language): The current nlp object.