fix: mypy issues

This commit is contained in:
Basile Dura 2023-05-31 10:57:11 +02:00
parent 967ce504fd
commit 9cd17d7962
No known key found for this signature in database
2 changed files with 49 additions and 25 deletions

View File

@ -1,16 +1,30 @@
from typing import Callable, Protocol, Iterable, Iterator, Optional
from typing import Union, Tuple, List, Dict, Any, overload
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Protocol,
Sequence,
Tuple,
Union,
overload,
)
import numpy as np
from cymem.cymem import Pool
from thinc.types import Floats1d, Floats2d, Ints2d
from .span import Span
from .token import Token
from ._dict_proxies import SpanGroups
from ._retokenize import Retokenizer
from ..lexeme import Lexeme
from ..vocab import Vocab
from ._dict_proxies import SpanGroups
from ._retokenize import Retokenizer
from .span import Span
from .token import Token
from .underscore import Underscore
from pathlib import Path
import numpy as np
class DocMethod(Protocol):
def __call__(self: Doc, *args: Any, **kwargs: Any) -> Any: ... # type: ignore[misc]
@ -119,7 +133,12 @@ class Doc:
def text(self) -> str: ...
@property
def text_with_ws(self) -> str: ...
ents: Tuple[Span]
# Ideally the getter would output Tuple[Span]
# see https://github.com/python/mypy/issues/3004
@property
def ents(self) -> Sequence[Span]: ...
@ents.setter
def ents(self, value: Sequence[Span]) -> None: ...
def set_ents(
self,
entities: List[Span],

View File

@ -1,16 +1,16 @@
import warnings
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
from typing import Optional
from pathlib import Path
import random
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Protocol, Union
import srsly
from .. import util
from ..errors import Errors, Warnings
from ..tokens import Doc, DocBin
from ..vocab import Vocab
from .augment import dont_augment
from .example import Example
from ..errors import Warnings, Errors
from ..tokens import DocBin, Doc
from ..vocab import Vocab
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
@ -19,6 +19,11 @@ if TYPE_CHECKING:
FILE_TYPE = ".spacy"
class ReaderProtocol(Protocol):
def __call__(self, nlp: "Language") -> Iterable[Example]:
pass
@util.registry.readers("spacy.Corpus.v1")
def create_docbin_reader(
path: Optional[Path],
@ -26,7 +31,7 @@ def create_docbin_reader(
max_length: int = 0,
limit: int = 0,
augmenter: Optional[Callable] = None,
) -> Callable[["Language"], Iterable[Example]]:
) -> ReaderProtocol:
if path is None:
raise ValueError(Errors.E913)
util.logger.debug("Loading corpus from path: %s", path)
@ -45,7 +50,7 @@ def create_jsonl_reader(
min_length: int = 0,
max_length: int = 0,
limit: int = 0,
) -> Callable[["Language"], Iterable[Example]]:
) -> ReaderProtocol:
return JsonlCorpus(path, min_length=min_length, max_length=max_length, limit=limit)
@ -63,7 +68,7 @@ def create_plain_text_reader(
path: Optional[Path],
min_length: int = 0,
max_length: int = 0,
) -> Callable[["Language"], Iterable[Doc]]:
) -> ReaderProtocol:
"""Iterate Example objects from a file or directory of plain text
UTF-8 files with one line per doc.
@ -144,7 +149,7 @@ class Corpus:
self.augmenter = augmenter if augmenter is not None else dont_augment
self.shuffle = shuffle
def __call__(self, nlp: "Language") -> Iterator[Example]:
def __call__(self, nlp: "Language") -> Iterable[Example]:
"""Yield examples from the data.
nlp (Language): The current nlp object.
@ -182,7 +187,7 @@ class Corpus:
def make_examples(
self, nlp: "Language", reference_docs: Iterable[Doc]
) -> Iterator[Example]:
) -> Iterable[Example]:
for reference in reference_docs:
if len(reference) == 0:
continue
@ -197,7 +202,7 @@ class Corpus:
def make_examples_gold_preproc(
self, nlp: "Language", reference_docs: Iterable[Doc]
) -> Iterator[Example]:
) -> Iterable[Example]:
for reference in reference_docs:
if reference.has_annotation("SENT_START"):
ref_sents = [sent.as_doc() for sent in reference.sents]
@ -210,7 +215,7 @@ class Corpus:
def read_docbin(
self, vocab: Vocab, locs: Iterable[Union[str, Path]]
) -> Iterator[Doc]:
) -> Iterable[Doc]:
"""Yield training examples as example dicts"""
i = 0
for loc in locs:
@ -257,7 +262,7 @@ class JsonlCorpus:
self.max_length = max_length
self.limit = limit
def __call__(self, nlp: "Language") -> Iterator[Example]:
def __call__(self, nlp: "Language") -> Iterable[Example]:
"""Yield examples from the data.
nlp (Language): The current nlp object.
@ -307,7 +312,7 @@ class PlainTextCorpus:
self.min_length = min_length
self.max_length = max_length
def __call__(self, nlp: "Language") -> Iterator[Example]:
def __call__(self, nlp: "Language") -> Iterable[Example]:
"""Yield examples from the data.
nlp (Language): The current nlp object.