Improve type annotations (#4448)

This commit is contained in:
Jahongir Qurbonov 2024-09-01 01:48:11 +05:00 committed by GitHub
parent 6253d28143
commit d9ef60782a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 48 additions and 49 deletions

View File

@ -110,8 +110,6 @@ from .updates import (
from .users import get_contacts, get_me, resolve_peers, resolve_phone, resolve_username from .users import get_contacts, get_me, resolve_peers, resolve_phone, resolve_username
Return = TypeVar("Return") Return = TypeVar("Return")
T = TypeVar("T")
AnyEvent = TypeVar("AnyEvent", bound=Event)
class Client: class Client:
@ -216,7 +214,7 @@ class Client:
datacenter: Optional[DataCenter] = None, datacenter: Optional[DataCenter] = None,
connector: Optional[Connector] = None, connector: Optional[Connector] = None,
) -> None: ) -> None:
assert isinstance(__package__, str) assert __package__
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")]) base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
self._sender: Optional[Sender] = None self._sender: Optional[Sender] = None
@ -269,9 +267,9 @@ class Client:
def add_event_handler( def add_event_handler(
self, self,
handler: Callable[[AnyEvent], Awaitable[Any]], handler: Callable[[Event], Awaitable[Any]],
/, /,
event_cls: Type[AnyEvent], event_cls: Type[Event],
filter: Optional[FilterType] = None, filter: Optional[FilterType] = None,
) -> None: ) -> None:
""" """
@ -760,7 +758,7 @@ class Client:
return get_file_bytes(self, media) return get_file_bytes(self, media)
def get_handler_filter( def get_handler_filter(
self, handler: Callable[[AnyEvent], Awaitable[Any]], / self, handler: Callable[[Event], Awaitable[Any]], /
) -> Optional[FilterType]: ) -> Optional[FilterType]:
""" """
Get the filter associated to the given event handler. Get the filter associated to the given event handler.
@ -1036,9 +1034,9 @@ class Client:
return await is_authorized(self) return await is_authorized(self)
def on( def on(
self, event_cls: Type[AnyEvent], /, filter: Optional[FilterType] = None self, event_cls: Type[Event], /, filter: Optional[FilterType] = None
) -> Callable[ ) -> Callable[
[Callable[[AnyEvent], Awaitable[Any]]], Callable[[AnyEvent], Awaitable[Any]] [Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]
]: ]:
""" """
Register the decorated function to be invoked when the provided event type occurs. Register the decorated function to be invoked when the provided event type occurs.
@ -1161,7 +1159,7 @@ class Client:
await read_message(self, chat, message_id) await read_message(self, chat, message_id)
def remove_event_handler( def remove_event_handler(
self, handler: Callable[[AnyEvent], Awaitable[Any]], / self, handler: Callable[[Event], Awaitable[Any]], /
) -> None: ) -> None:
""" """
Remove the handler as a function to be called when events occur. Remove the handler as a function to be called when events occur.
@ -1851,7 +1849,7 @@ class Client:
def set_handler_filter( def set_handler_filter(
self, self,
handler: Callable[[AnyEvent], Awaitable[Any]], handler: Callable[[Event], Awaitable[Any]],
/, /,
filter: Optional[FilterType] = None, filter: Optional[FilterType] = None,
) -> None: ) -> None:

View File

@ -3,19 +3,17 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from inspect import isawaitable from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, TypeVar from typing import TYPE_CHECKING, Any, Optional, Sequence, Type
from ...session import Gap from ...session import Gap
from ...tl import abcs from ...tl import abcs
from ..events import Continue from ..events import Continue, Event
from ..events import Event as EventBase
from ..events.filters import FilterType from ..events.filters import FilterType
from ..types import build_chat_map from ..types import build_chat_map
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import Client from .client import Client
Event = TypeVar("Event", bound=EventBase)
UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN = 300 UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN = 300

View File

@ -10,7 +10,7 @@ if TYPE_CHECKING:
from ..client.client import Client from ..client.client import Client
class Event(metaclass=NoPublicConstructor): class Event(abc.ABC, metaclass=NoPublicConstructor):
""" """
The base type of all events. The base type of all events.
""" """

View File

@ -1,12 +1,11 @@
import abc import abc
import typing
from collections.abc import Callable from collections.abc import Callable
from inspect import isawaitable from inspect import isawaitable
from typing import Awaitable, TypeAlias from typing import Awaitable, TypeAlias
from ..event import Event from ..event import Event
FilterType: TypeAlias = Callable[[Event], bool | Awaitable[bool]] FilterType: TypeAlias = "Callable[[Event], bool | Awaitable[bool]] | Combinable"
class Combinable(abc.ABC): class Combinable(abc.ABC):
@ -22,24 +21,18 @@ class Combinable(abc.ABC):
Multiple ``~`` will toggle between using :class:`Not` and not using it. Multiple ``~`` will toggle between using :class:`Not` and not using it.
""" """
def __or__(self, other: typing.Any) -> FilterType: def __or__(self, other: FilterType) -> "Any":
if not callable(other):
return NotImplemented
lhs = self.filters if isinstance(self, Any) else (self,) lhs = self.filters if isinstance(self, Any) else (self,)
rhs = other.filters if isinstance(other, Any) else (other,) rhs = other.filters if isinstance(other, Any) else (other,)
return Any(*lhs, *rhs) # type: ignore [arg-type] return Any(*lhs, *rhs)
def __and__(self, other: typing.Any) -> FilterType:
if not callable(other):
return NotImplemented
def __and__(self, other: FilterType) -> "All":
lhs = self.filters if isinstance(self, All) else (self,) lhs = self.filters if isinstance(self, All) else (self,)
rhs = other.filters if isinstance(other, All) else (other,) rhs = other.filters if isinstance(other, All) else (other,)
return All(*lhs, *rhs) # type: ignore [arg-type] return All(*lhs, *rhs)
def __invert__(self) -> FilterType: def __invert__(self) -> "Not | FilterType":
return self.filter if isinstance(self, Not) else Not(self) # type: ignore [return-value] return self.filter if isinstance(self, Not) else Not(self)
@abc.abstractmethod @abc.abstractmethod
async def __call__(self, event: Event) -> bool: async def __call__(self, event: Event) -> bool:

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import abc
import weakref import weakref
from typing import TYPE_CHECKING, Optional, TypeAlias from typing import TYPE_CHECKING, Optional, TypeAlias
@ -29,7 +30,7 @@ RawButtonType: TypeAlias = (
) )
class Button: class Button(abc.ABC):
""" """
The button base type. The button base type.

View File

@ -1,7 +1,9 @@
import abc
from .button import Button from .button import Button
class InlineButton(Button): class InlineButton(Button, abc.ABC):
""" """
Inline button base type. Inline button base type.

View File

@ -1,14 +1,14 @@
from typing import Generic, Optional, TypeAlias, TypeVar from typing import Optional, TypeAlias
from ...tl import abcs, types from ...tl import abcs, types
from .buttons import Button, InlineButton from .buttons import Button, InlineButton
AnyButton = TypeVar("AnyButton", bound=Button)
AnyInlineButton = TypeVar("AnyInlineButton", bound=InlineButton)
def _build_keyboard_rows( def _build_keyboard_rows(
btns: list[AnyButton] | list[list[AnyButton]], btns: list[Button]
| list[list[Button]]
| list[InlineButton]
| list[list[InlineButton]],
) -> list[abcs.KeyboardButtonRow]: ) -> list[abcs.KeyboardButtonRow]:
# list[button] -> list[list[button]] # list[button] -> list[list[button]]
# This does allow for "invalid" inputs (mixing lists and non-lists), but that's acceptable. # This does allow for "invalid" inputs (mixing lists and non-lists), but that's acceptable.
@ -24,12 +24,12 @@ def _build_keyboard_rows(
] ]
class Keyboard(Generic[AnyButton]): class Keyboard:
__slots__ = ("_raw",) __slots__ = ("_raw",)
def __init__( def __init__(
self, self,
buttons: list[AnyButton] | list[list[AnyButton]], buttons: list[Button] | list[list[Button]],
resize: bool, resize: bool,
single_use: bool, single_use: bool,
selective: bool, selective: bool,
@ -46,12 +46,10 @@ class Keyboard(Generic[AnyButton]):
) )
class InlineKeyboard(Generic[AnyInlineButton]): class InlineKeyboard:
__slots__ = ("_raw",) __slots__ = ("_raw",)
def __init__( def __init__(self, buttons: list[InlineButton] | list[list[InlineButton]]) -> None:
self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]]
) -> None:
self._raw = types.ReplyInlineMarkup(rows=_build_keyboard_rows(buttons)) self._raw = types.ReplyInlineMarkup(rows=_build_keyboard_rows(buttons))

View File

@ -4,10 +4,15 @@ from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Protocol
from .serializable import Serializable from .serializable import Serializable
class Buffer(Protocol):
def __buffer__(self, flags: int, /) -> memoryview: ...
T = TypeVar("T", bound="Serializable")
SerializableType = TypeVar("SerializableType", bound="Serializable")
def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]: def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
@ -33,7 +38,7 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
class Reader: class Reader:
__slots__ = ("_view", "_pos", "_len") __slots__ = ("_view", "_pos", "_len")
def __init__(self, buffer: bytes | bytearray | memoryview) -> None: def __init__(self, buffer: "Buffer") -> None:
self._view = ( self._view = (
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
) )
@ -74,7 +79,7 @@ class Reader:
_get_ty = staticmethod(_bootstrap_get_ty) _get_ty = staticmethod(_bootstrap_get_ty)
def read_serializable(self, cls: Type[T]) -> T: def read_serializable(self, cls: Type[SerializableType]) -> SerializableType:
# Calls to this method likely need to ignore "type-abstract". # Calls to this method likely need to ignore "type-abstract".
# See https://github.com/python/mypy/issues/4717. # See https://github.com/python/mypy/issues/4717.
# Unfortunately `typing.cast` would add a tiny amount of runtime overhead # Unfortunately `typing.cast` would add a tiny amount of runtime overhead
@ -89,16 +94,20 @@ class Reader:
@functools.cache @functools.cache
def single_deserializer(cls: Type[T]) -> Callable[[bytes], T]: def single_deserializer(
def deserializer(body: bytes) -> T: cls: Type[SerializableType],
) -> Callable[[bytes], SerializableType]:
def deserializer(body: bytes) -> SerializableType:
return Reader(body).read_serializable(cls) return Reader(body).read_serializable(cls)
return deserializer return deserializer
@functools.cache @functools.cache
def list_deserializer(cls: Type[T]) -> Callable[[bytes], list[T]]: def list_deserializer(
def deserializer(body: bytes) -> list[T]: cls: Type[SerializableType],
) -> Callable[[bytes], list[SerializableType]]:
def deserializer(body: bytes) -> list[SerializableType]:
reader = Reader(body) reader = Reader(body)
vec_id, length = reader.read_fmt("<ii", 8) vec_id, length = reader.read_fmt("<ii", 8)
assert vec_id == 0x1CB5C415 and length >= 0 assert vec_id == 0x1CB5C415 and length >= 0