Update typing hints (#10109)

* Improve typing hints for Matcher.__call__

* Add typing hints for DependencyMatcher

* Add typing hints to underscore extensions

* Update Doc.tensor type (requires numpy 1.21)

* Fix typing hints for Language.component decorator

* Use generic np.ndarray type in Doc to avoid numpy version update

* Fix mypy errors

* Fix cyclic import caused by Underscore typing hints

* Use Literal type from spacy.compat

* Update matcher.pyi import format

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
Eduard Zorita 2022-01-28 16:59:54 +01:00 committed by GitHub
parent 09734c56fc
commit 30cf9d6a05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 127 additions and 25 deletions

View File

@ -522,7 +522,7 @@ class Language:
requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False,
func: Optional["Pipe"] = None,
) -> Callable:
) -> Callable[..., Any]:
"""Register a new pipeline component. Can be used for stateless function
components that don't require a separate factory. Can be used as a
decorator on a function or classmethod, or called as a function with the

View File

@ -0,0 +1,66 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .matcher import Matcher
from ..vocab import Vocab
from ..tokens.doc import Doc
from ..tokens.span import Span
class DependencyMatcher:
"""Match dependency parse tree based on pattern rules."""
_patterns: Dict[str, List[Any]]
_raw_patterns: Dict[str, List[Any]]
_tokens_to_key: Dict[str, List[Any]]
_root: Dict[str, List[Any]]
_tree: Dict[str, List[Any]]
_callbacks: Dict[
Any, Callable[[DependencyMatcher, Doc, int, List[Tuple[int, List[int]]]], Any]
]
_ops: Dict[str, Any]
vocab: Vocab
_matcher: Matcher
def __init__(self, vocab: Vocab, *, validate: bool = ...) -> None: ...
def __reduce__(
self,
) -> Tuple[
Callable[
[Vocab, Dict[str, Any], Dict[str, Callable[..., Any]]], DependencyMatcher
],
Tuple[
Vocab,
Dict[str, List[Any]],
Dict[
str,
Callable[
[DependencyMatcher, Doc, int, List[Tuple[int, List[int]]]], Any
],
],
],
None,
None,
]: ...
def __len__(self) -> int: ...
def __contains__(self, key: Union[str, int]) -> bool: ...
def add(
self,
key: Union[str, int],
patterns: List[List[Dict[str, Any]]],
*,
on_match: Optional[
Callable[[DependencyMatcher, Doc, int, List[Tuple[int, List[int]]]], Any]
] = ...
) -> None: ...
def has_key(self, key: Union[str, int]) -> bool: ...
def get(
self, key: Union[str, int], default: Optional[Any] = ...
) -> Tuple[
Optional[
Callable[[DependencyMatcher, Doc, int, List[Tuple[int, List[int]]]], Any]
],
List[List[Dict[str, Any]]],
]: ...
def remove(self, key: Union[str, int]) -> None: ...
def __call__(self, doclike: Union[Doc, Span]) -> List[Tuple[int, List[int]]]: ...
def unpickle_matcher(
vocab: Vocab, patterns: Dict[str, Any], callbacks: Dict[str, Callable[..., Any]]
) -> DependencyMatcher: ...

View File

@ -1,4 +1,6 @@
from typing import Any, List, Dict, Tuple, Optional, Callable, Union, Iterator, Iterable
from typing import Any, List, Dict, Tuple, Optional, Callable, Union
from typing import Iterator, Iterable, overload
from ..compat import Literal
from ..vocab import Vocab
from ..tokens import Doc, Span
@ -31,12 +33,22 @@ class Matcher:
) -> Union[
Iterator[Tuple[Tuple[Doc, Any], Any]], Iterator[Tuple[Doc, Any]], Iterator[Doc]
]: ...
@overload
def __call__(
self,
doclike: Union[Doc, Span],
*,
as_spans: bool = ...,
as_spans: Literal[False] = ...,
allow_missing: bool = ...,
with_alignments: bool = ...
) -> Union[List[Tuple[int, int, int]], List[Span]]: ...
) -> List[Tuple[int, int, int]]: ...
@overload
def __call__(
self,
doclike: Union[Doc, Span],
*,
as_spans: Literal[True],
allow_missing: bool = ...,
with_alignments: bool = ...
) -> List[Span]: ...
def _normalize_key(self, key: Any) -> Any: ...

View File

@ -1,6 +1,6 @@
from typing import List, Tuple, Union, Optional, Callable, Any, Dict
from . import Matcher
from typing import List, Tuple, Union, Optional, Callable, Any, Dict, overload
from ..compat import Literal
from .matcher import Matcher
from ..vocab import Vocab
from ..tokens import Doc, Span
@ -21,9 +21,17 @@ class PhraseMatcher:
] = ...,
) -> None: ...
def remove(self, key: str) -> None: ...
@overload
def __call__(
self,
doclike: Union[Doc, Span],
*,
as_spans: bool = ...,
) -> Union[List[Tuple[int, int, int]], List[Span]]: ...
as_spans: Literal[False] = ...,
) -> List[Tuple[int, int, int]]: ...
@overload
def __call__(
self,
doclike: Union[Doc, Span],
*,
as_spans: Literal[True],
) -> List[Span]: ...

View File

@ -10,7 +10,7 @@ from ..lexeme import Lexeme
from ..vocab import Vocab
from .underscore import Underscore
from pathlib import Path
import numpy
import numpy as np
class DocMethod(Protocol):
def __call__(self: Doc, *args: Any, **kwargs: Any) -> Any: ... # type: ignore[misc]
@ -26,7 +26,7 @@ class Doc:
user_hooks: Dict[str, Callable[..., Any]]
user_token_hooks: Dict[str, Callable[..., Any]]
user_span_hooks: Dict[str, Callable[..., Any]]
tensor: numpy.ndarray
tensor: np.ndarray[Any, np.dtype[np.float_]]
user_data: Dict[str, Any]
has_unknown_spaces: bool
_context: Any
@ -144,7 +144,7 @@ class Doc:
) -> Doc: ...
def to_array(
self, py_attr_ids: Union[int, str, List[Union[int, str]]]
) -> numpy.ndarray: ...
) -> np.ndarray[Any, np.dtype[np.float_]]: ...
@staticmethod
def from_docs(
docs: List[Doc],

View File

@ -1,17 +1,31 @@
from typing import Dict, Any
from typing import Dict, Any, List, Optional, Tuple, Union, TYPE_CHECKING
import functools
import copy
from ..errors import Errors
if TYPE_CHECKING:
from .doc import Doc
from .span import Span
from .token import Token
class Underscore:
mutable_types = (dict, list, set)
doc_extensions: Dict[Any, Any] = {}
span_extensions: Dict[Any, Any] = {}
token_extensions: Dict[Any, Any] = {}
_extensions: Dict[str, Any]
_obj: Union["Doc", "Span", "Token"]
_start: Optional[int]
_end: Optional[int]
def __init__(self, extensions, obj, start=None, end=None):
def __init__(
self,
extensions: Dict[str, Any],
obj: Union["Doc", "Span", "Token"],
start: Optional[int] = None,
end: Optional[int] = None,
):
object.__setattr__(self, "_extensions", extensions)
object.__setattr__(self, "_obj", obj)
# Assumption is that for doc values, _start and _end will both be None
@ -23,12 +37,12 @@ class Underscore:
object.__setattr__(self, "_start", start)
object.__setattr__(self, "_end", end)
def __dir__(self):
def __dir__(self) -> List[str]:
# Hack to enable autocomplete on custom extensions
extensions = list(self._extensions.keys())
return ["set", "get", "has"] + extensions
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
if name not in self._extensions:
raise AttributeError(Errors.E046.format(name=name))
default, method, getter, setter = self._extensions[name]
@ -56,7 +70,7 @@ class Underscore:
return new_default
return default
def __setattr__(self, name, value):
def __setattr__(self, name: str, value: Any):
if name not in self._extensions:
raise AttributeError(Errors.E047.format(name=name))
default, method, getter, setter = self._extensions[name]
@ -65,28 +79,30 @@ class Underscore:
else:
self._doc.user_data[self._get_key(name)] = value
def set(self, name, value):
def set(self, name: str, value: Any):
return self.__setattr__(name, value)
def get(self, name):
def get(self, name: str) -> Any:
return self.__getattr__(name)
def has(self, name):
def has(self, name: str) -> bool:
return name in self._extensions
def _get_key(self, name):
def _get_key(self, name: str) -> Tuple[str, str, Optional[int], Optional[int]]:
return ("._.", name, self._start, self._end)
@classmethod
def get_state(cls):
def get_state(cls) -> Tuple[Dict[Any, Any], Dict[Any, Any], Dict[Any, Any]]:
return cls.token_extensions, cls.span_extensions, cls.doc_extensions
@classmethod
def load_state(cls, state):
def load_state(
cls, state: Tuple[Dict[Any, Any], Dict[Any, Any], Dict[Any, Any]]
) -> None:
cls.token_extensions, cls.span_extensions, cls.doc_extensions = state
def get_ext_args(**kwargs):
def get_ext_args(**kwargs: Any):
"""Validate and convert arguments. Reused in Doc, Token and Span."""
default = kwargs.get("default")
getter = kwargs.get("getter")