Improve type annotations (#4448)

This commit is contained in:
Jahongir Qurbonov 2024-09-01 01:48:11 +05:00 committed by GitHub
parent 6253d28143
commit d9ef60782a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 48 additions and 49 deletions

View File

@ -110,8 +110,6 @@ from .updates import (
from .users import get_contacts, get_me, resolve_peers, resolve_phone, resolve_username
Return = TypeVar("Return")
T = TypeVar("T")
AnyEvent = TypeVar("AnyEvent", bound=Event)
class Client:
@ -216,7 +214,7 @@ class Client:
datacenter: Optional[DataCenter] = None,
connector: Optional[Connector] = None,
) -> None:
assert isinstance(__package__, str)
assert __package__
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
self._sender: Optional[Sender] = None
@ -269,9 +267,9 @@ class Client:
def add_event_handler(
self,
handler: Callable[[AnyEvent], Awaitable[Any]],
handler: Callable[[Event], Awaitable[Any]],
/,
event_cls: Type[AnyEvent],
event_cls: Type[Event],
filter: Optional[FilterType] = None,
) -> None:
"""
@ -760,7 +758,7 @@ class Client:
return get_file_bytes(self, media)
def get_handler_filter(
self, handler: Callable[[AnyEvent], Awaitable[Any]], /
self, handler: Callable[[Event], Awaitable[Any]], /
) -> Optional[FilterType]:
"""
Get the filter associated to the given event handler.
@ -1036,9 +1034,9 @@ class Client:
return await is_authorized(self)
def on(
self, event_cls: Type[AnyEvent], /, filter: Optional[FilterType] = None
self, event_cls: Type[Event], /, filter: Optional[FilterType] = None
) -> Callable[
[Callable[[AnyEvent], Awaitable[Any]]], Callable[[AnyEvent], Awaitable[Any]]
[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]
]:
"""
Register the decorated function to be invoked when the provided event type occurs.
@ -1161,7 +1159,7 @@ class Client:
await read_message(self, chat, message_id)
def remove_event_handler(
self, handler: Callable[[AnyEvent], Awaitable[Any]], /
self, handler: Callable[[Event], Awaitable[Any]], /
) -> None:
"""
Remove the handler as a function to be called when events occur.
@ -1851,7 +1849,7 @@ class Client:
def set_handler_filter(
self,
handler: Callable[[AnyEvent], Awaitable[Any]],
handler: Callable[[Event], Awaitable[Any]],
/,
filter: Optional[FilterType] = None,
) -> None:

View File

@ -3,19 +3,17 @@ from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, TypeVar
from typing import TYPE_CHECKING, Any, Optional, Sequence, Type
from ...session import Gap
from ...tl import abcs
from ..events import Continue
from ..events import Event as EventBase
from ..events import Continue, Event
from ..events.filters import FilterType
from ..types import build_chat_map
if TYPE_CHECKING:
from .client import Client
Event = TypeVar("Event", bound=EventBase)
UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN = 300

View File

@ -10,7 +10,7 @@ if TYPE_CHECKING:
from ..client.client import Client
class Event(metaclass=NoPublicConstructor):
class Event(abc.ABC, metaclass=NoPublicConstructor):
"""
The base type of all events.
"""

View File

@ -1,12 +1,11 @@
import abc
import typing
from collections.abc import Callable
from inspect import isawaitable
from typing import Awaitable, TypeAlias
from ..event import Event
FilterType: TypeAlias = Callable[[Event], bool | Awaitable[bool]]
FilterType: TypeAlias = "Callable[[Event], bool | Awaitable[bool]] | Combinable"
class Combinable(abc.ABC):
@ -22,24 +21,18 @@ class Combinable(abc.ABC):
Multiple ``~`` will toggle between using :class:`Not` and not using it.
"""
def __or__(self, other: typing.Any) -> FilterType:
if not callable(other):
return NotImplemented
def __or__(self, other: FilterType) -> "Any":
lhs = self.filters if isinstance(self, Any) else (self,)
rhs = other.filters if isinstance(other, Any) else (other,)
return Any(*lhs, *rhs) # type: ignore [arg-type]
def __and__(self, other: typing.Any) -> FilterType:
if not callable(other):
return NotImplemented
return Any(*lhs, *rhs)
def __and__(self, other: FilterType) -> "All":
lhs = self.filters if isinstance(self, All) else (self,)
rhs = other.filters if isinstance(other, All) else (other,)
return All(*lhs, *rhs) # type: ignore [arg-type]
return All(*lhs, *rhs)
def __invert__(self) -> FilterType:
return self.filter if isinstance(self, Not) else Not(self) # type: ignore [return-value]
def __invert__(self) -> "Not | FilterType":
return self.filter if isinstance(self, Not) else Not(self)
@abc.abstractmethod
async def __call__(self, event: Event) -> bool:

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import abc
import weakref
from typing import TYPE_CHECKING, Optional, TypeAlias
@ -29,7 +30,7 @@ RawButtonType: TypeAlias = (
)
class Button:
class Button(abc.ABC):
"""
The button base type.

View File

@ -1,7 +1,9 @@
import abc
from .button import Button
class InlineButton(Button):
class InlineButton(Button, abc.ABC):
"""
Inline button base type.

View File

@ -1,14 +1,14 @@
from typing import Generic, Optional, TypeAlias, TypeVar
from typing import Optional, TypeAlias
from ...tl import abcs, types
from .buttons import Button, InlineButton
AnyButton = TypeVar("AnyButton", bound=Button)
AnyInlineButton = TypeVar("AnyInlineButton", bound=InlineButton)
def _build_keyboard_rows(
btns: list[AnyButton] | list[list[AnyButton]],
btns: list[Button]
| list[list[Button]]
| list[InlineButton]
| list[list[InlineButton]],
) -> list[abcs.KeyboardButtonRow]:
# list[button] -> list[list[button]]
# This does allow for "invalid" inputs (mixing lists and non-lists), but that's acceptable.
@ -24,12 +24,12 @@ def _build_keyboard_rows(
]
class Keyboard(Generic[AnyButton]):
class Keyboard:
__slots__ = ("_raw",)
def __init__(
self,
buttons: list[AnyButton] | list[list[AnyButton]],
buttons: list[Button] | list[list[Button]],
resize: bool,
single_use: bool,
selective: bool,
@ -46,12 +46,10 @@ class Keyboard(Generic[AnyButton]):
)
class InlineKeyboard(Generic[AnyInlineButton]):
class InlineKeyboard:
__slots__ = ("_raw",)
def __init__(
self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]]
) -> None:
def __init__(self, buttons: list[InlineButton] | list[list[InlineButton]]) -> None:
self._raw = types.ReplyInlineMarkup(rows=_build_keyboard_rows(buttons))

View File

@ -4,10 +4,15 @@ from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
if TYPE_CHECKING:
from typing import Protocol
from .serializable import Serializable
class Buffer(Protocol):
def __buffer__(self, flags: int, /) -> memoryview: ...
T = TypeVar("T", bound="Serializable")
SerializableType = TypeVar("SerializableType", bound="Serializable")
def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
@ -33,7 +38,7 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
class Reader:
__slots__ = ("_view", "_pos", "_len")
def __init__(self, buffer: bytes | bytearray | memoryview) -> None:
def __init__(self, buffer: "Buffer") -> None:
self._view = (
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
)
@ -74,7 +79,7 @@ class Reader:
_get_ty = staticmethod(_bootstrap_get_ty)
def read_serializable(self, cls: Type[T]) -> T:
def read_serializable(self, cls: Type[SerializableType]) -> SerializableType:
# Calls to this method likely need to ignore "type-abstract".
# See https://github.com/python/mypy/issues/4717.
# Unfortunately `typing.cast` would add a tiny amount of runtime overhead
@ -89,16 +94,20 @@ class Reader:
@functools.cache
def single_deserializer(cls: Type[T]) -> Callable[[bytes], T]:
def deserializer(body: bytes) -> T:
def single_deserializer(
cls: Type[SerializableType],
) -> Callable[[bytes], SerializableType]:
def deserializer(body: bytes) -> SerializableType:
return Reader(body).read_serializable(cls)
return deserializer
@functools.cache
def list_deserializer(cls: Type[T]) -> Callable[[bytes], list[T]]:
def deserializer(body: bytes) -> list[T]:
def list_deserializer(
cls: Type[SerializableType],
) -> Callable[[bytes], list[SerializableType]]:
def deserializer(body: bytes) -> list[SerializableType]:
reader = Reader(body)
vec_id, length = reader.read_fmt("<ii", 8)
assert vec_id == 0x1CB5C415 and length >= 0