[V2] Upgrade ruff and mypy version, format files (#4474)

This commit is contained in:
Jahongir Qurbonov 2024-10-06 23:05:11 +05:00 committed by GitHub
parent 918f719ab2
commit 86d41e1f06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
67 changed files with 177 additions and 118 deletions

1
.gitignore vendored
View File

@ -4,6 +4,7 @@ __pycache__/
*.egg-info/ *.egg-info/
.pytest_cache/ .pytest_cache/
.mypy_cache/ .mypy_cache/
.ruff_cache/
dist/ dist/
dist-doc/ dist-doc/
build/ build/

View File

@ -28,10 +28,8 @@ dynamic = ["version"]
[project.optional-dependencies] [project.optional-dependencies]
cryptg = ["cryptg~=0.4"] cryptg = ["cryptg~=0.4"]
dev = [ dev = [
"isort~=5.12", "mypy~=1.11.2",
"black~=23.3.0", "ruff~=0.6.8",
"mypy~=1.3",
"ruff~=0.0.292",
"pytest~=7.3", "pytest~=7.3",
"pytest-asyncio~=0.21", "pytest-asyncio~=0.21",
] ]
@ -55,6 +53,13 @@ backend-path = ["build_backend"]
version = {attr = "telethon.version.__version__"} version = {attr = "telethon.version.__version__"}
[tool.ruff] [tool.ruff]
exclude = ["doc"]
[tool.ruff.lint]
select = ["F", "E", "W", "I", "N", "ANN"]
ignore = [ ignore = [
"E501", # formatter takes care of lines that are too long besides documentation "E501", # formatter takes care of lines that are too long besides documentation
"ANN101", # Missing type annotation for `self` in method
"ANN102", # Missing type annotation for `cls` in classmethod
"ANN401", # Dynamically typed expressions (typing.Any) are not type checked
] ]

View File

@ -18,7 +18,7 @@ class InlineResults(metaclass=NoPublicConstructor):
bot: abcs.InputUser, bot: abcs.InputUser,
query: str, query: str,
peer: Optional[PeerRef], peer: Optional[PeerRef],
): ) -> None:
self._client = client self._client = client
self._bot = bot self._bot = bot
self._query = query self._query = query

View File

@ -29,7 +29,7 @@ class ParticipantList(AsyncList[Participant]):
self, self,
client: Client, client: Client,
peer: ChannelRef | GroupRef, peer: ChannelRef | GroupRef,
): ) -> None:
super().__init__() super().__init__()
self._client = client self._client = client
self._peer = peer self._peer = peer
@ -106,7 +106,7 @@ class RecentActionList(AsyncList[RecentAction]):
self, self,
client: Client, client: Client,
peer: ChannelRef | GroupRef, peer: ChannelRef | GroupRef,
): ) -> None:
super().__init__() super().__init__()
self._client = client self._client = client
self._peer = peer self._peer = peer
@ -148,7 +148,7 @@ class ProfilePhotoList(AsyncList[File]):
self, self,
client: Client, client: Client,
peer: PeerRef, peer: PeerRef,
): ) -> None:
super().__init__() super().__init__()
self._client = client self._client = client
self._peer = peer self._peer = peer

View File

@ -20,7 +20,7 @@ if TYPE_CHECKING:
class DialogList(AsyncList[Dialog]): class DialogList(AsyncList[Dialog]):
def __init__(self, client: Client): def __init__(self, client: Client) -> None:
super().__init__() super().__init__()
self._client = client self._client = client
self._offset = 0 self._offset = 0
@ -93,7 +93,7 @@ async def delete_dialog(self: Client, dialog: Peer | PeerRef, /) -> None:
class DraftList(AsyncList[Draft]): class DraftList(AsyncList[Draft]):
def __init__(self, client: Client): def __init__(self, client: Client) -> None:
super().__init__() super().__init__()
self._client = client self._client = client
self._offset = 0 self._offset = 0

View File

@ -425,7 +425,7 @@ class FileBytesList(AsyncList[bytes]):
self, self,
client: Client, client: Client,
file: File, file: File,
): ) -> None:
super().__init__() super().__init__()
self._client = client self._client = client
self._loc = file._input_location() self._loc = file._input_location()

View File

@ -253,7 +253,7 @@ class HistoryList(MessageList):
*, *,
offset_id: int, offset_id: int,
offset_date: int, offset_date: int,
): ) -> None:
super().__init__() super().__init__()
self._client = client self._client = client
self._peer = peer self._peer = peer
@ -323,7 +323,7 @@ class CherryPickedList(MessageList):
client: Client, client: Client,
peer: PeerRef, peer: PeerRef,
ids: list[int], ids: list[int],
): ) -> None:
super().__init__() super().__init__()
self._client = client self._client = client
self._peer = peer self._peer = peer
@ -367,7 +367,7 @@ class SearchList(MessageList):
query: str, query: str,
offset_id: int, offset_id: int,
offset_date: int, offset_date: int,
): ) -> None:
super().__init__() super().__init__()
self._client = client self._client = client
self._peer = peer self._peer = peer
@ -434,7 +434,7 @@ class GlobalSearchList(MessageList):
query: str, query: str,
offset_id: int, offset_id: int,
offset_date: int, offset_date: int,
): ) -> None:
super().__init__() super().__init__()
self._client = client self._client = client
self._limit = limit self._limit = limit

View File

@ -9,7 +9,7 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, TypeVar from typing import TYPE_CHECKING, Optional, TypeVar
from ....version import __version__ from ....version import __version__
from ...mtproto import BadStatus, Full, RpcError from ...mtproto import BadStatusError, Full, RpcError
from ...mtsender import Connector, Sender from ...mtsender import Connector, Sender
from ...mtsender import connect as do_connect_sender from ...mtsender import connect as do_connect_sender
from ...session import DataCenter from ...session import DataCenter
@ -120,7 +120,7 @@ async def connect_sender(
), ),
) )
) )
except BadStatus as e: except BadStatusError as e:
if e.status == 404 and auth: if e.status == 404 and auth:
dc = DataCenter( dc = DataCenter(
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None

View File

@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable
from inspect import isawaitable from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Optional, Sequence, Type from typing import TYPE_CHECKING, Any, Optional, Sequence, Type
from ...session import Gap from ...session import GapError
from ...tl import abcs from ...tl import abcs
from ..events import Continue, Event from ..events import Continue, Event
from ..events.filters import FilterType from ..events.filters import FilterType
@ -80,14 +80,14 @@ def process_socket_updates(client: Client, all_updates: list[abcs.Updates]) -> N
for updates in all_updates: for updates in all_updates:
try: try:
client._message_box.ensure_known_peer_hashes(updates, client._chat_hashes) client._message_box.ensure_known_peer_hashes(updates, client._chat_hashes)
except Gap: except GapError:
return return
try: try:
result, users, chats = client._message_box.process_updates( result, users, chats = client._message_box.process_updates(
updates, client._chat_hashes updates, client._chat_hashes
) )
except Gap: except GapError:
return return
extend_update_queue(client, result, users, chats) extend_update_queue(client, result, users, chats)

View File

@ -25,7 +25,7 @@ async def get_me(self: Client) -> Optional[User]:
class ContactList(AsyncList[User]): class ContactList(AsyncList[User]):
def __init__(self, client: Client): def __init__(self, client: Client) -> None:
super().__init__() super().__init__()
self._client = client self._client = client

View File

@ -46,7 +46,7 @@ class Raw(Event):
client: Client, client: Client,
update: abcs.Update, update: abcs.Update,
chat_map: dict[int, Peer], chat_map: dict[int, Peer],
): ) -> None:
self._client = client self._client = client
self._raw = update self._raw = update
self._chat_map = chat_map self._chat_map = chat_map

View File

@ -25,7 +25,7 @@ class ButtonCallback(Event):
client: Client, client: Client,
update: types.UpdateBotCallbackQuery, update: types.UpdateBotCallbackQuery,
chat_map: dict[int, Peer], chat_map: dict[int, Peer],
): ) -> None:
self._client = client self._client = client
self._raw = update self._raw = update
self._chat_map = chat_map self._chat_map = chat_map
@ -101,7 +101,7 @@ class InlineQuery(Event):
Only bot accounts can receive this event. Only bot accounts can receive this event.
""" """
def __init__(self, update: types.UpdateBotInlineQuery): def __init__(self, update: types.UpdateBotInlineQuery) -> None:
self._raw = update self._raw = update
@classmethod @classmethod

View File

@ -36,20 +36,20 @@ class HTMLToTelegramParser(HTMLParser):
self._open_tags_meta.appendleft(None) self._open_tags_meta.appendleft(None)
attributes = dict(attrs) attributes = dict(attrs)
EntityType: Optional[Type[MessageEntity]] = None entity_type: Optional[Type[MessageEntity]] = None
args = {} args = {}
if tag == "strong" or tag == "b": if tag == "strong" or tag == "b":
EntityType = MessageEntityBold entity_type = MessageEntityBold
elif tag == "em" or tag == "i": elif tag == "em" or tag == "i":
EntityType = MessageEntityItalic entity_type = MessageEntityItalic
elif tag == "u": elif tag == "u":
EntityType = MessageEntityUnderline entity_type = MessageEntityUnderline
elif tag == "del" or tag == "s": elif tag == "del" or tag == "s":
EntityType = MessageEntityStrike entity_type = MessageEntityStrike
elif tag == "blockquote": elif tag == "blockquote":
EntityType = MessageEntityBlockquote entity_type = MessageEntityBlockquote
elif tag == "details": elif tag == "details":
EntityType = MessageEntitySpoiler entity_type = MessageEntitySpoiler
elif tag == "code": elif tag == "code":
try: try:
# If we're in the middle of a <pre> tag, this <code> tag is # If we're in the middle of a <pre> tag, this <code> tag is
@ -63,9 +63,9 @@ class HTMLToTelegramParser(HTMLParser):
if cls := attributes.get("class"): if cls := attributes.get("class"):
pre.language = cls[len("language-") :] pre.language = cls[len("language-") :]
except KeyError: except KeyError:
EntityType = MessageEntityCode entity_type = MessageEntityCode
elif tag == "pre": elif tag == "pre":
EntityType = MessageEntityPre entity_type = MessageEntityPre
args["language"] = "" args["language"] = ""
elif tag == "a": elif tag == "a":
url = attributes.get("href") url = attributes.get("href")
@ -73,20 +73,20 @@ class HTMLToTelegramParser(HTMLParser):
return return
if url.startswith("mailto:"): if url.startswith("mailto:"):
url = url[len("mailto:") :] url = url[len("mailto:") :]
EntityType = MessageEntityEmail entity_type = MessageEntityEmail
else: else:
if self.get_starttag_text() == url: if self.get_starttag_text() == url:
EntityType = MessageEntityUrl entity_type = MessageEntityUrl
else: else:
EntityType = MessageEntityTextUrl entity_type = MessageEntityTextUrl
args["url"] = del_surrogate(url) args["url"] = del_surrogate(url)
url = None url = None
self._open_tags_meta.popleft() self._open_tags_meta.popleft()
self._open_tags_meta.appendleft(url) self._open_tags_meta.appendleft(url)
if EntityType and tag not in self._building_entities: if entity_type and tag not in self._building_entities:
Et = cast(Any, EntityType) any_entity_type = cast(Any, entity_type)
self._building_entities[tag] = Et( self._building_entities[tag] = any_entity_type(
offset=len(self.text), offset=len(self.text),
# The length will be determined when closing the tag. # The length will be determined when closing the tag.
length=0, length=0,

View File

@ -22,7 +22,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
This class is constructed by calling :meth:`telethon.Client.prepare_album`. This class is constructed by calling :meth:`telethon.Client.prepare_album`.
""" """
def __init__(self, *, client: Client): def __init__(self, *, client: Client) -> None:
self._client = client self._client = client
self._medias: list[types.InputSingleMedia] = [] self._medias: list[types.InputSingleMedia] = []

View File

@ -122,7 +122,7 @@ class OutWrapper:
_fd: OutFileLike | BufferedWriter _fd: OutFileLike | BufferedWriter
_owned_fd: Optional[BufferedWriter] _owned_fd: Optional[BufferedWriter]
def __init__(self, file: str | Path | OutFileLike): def __init__(self, file: str | Path | OutFileLike) -> None:
if isinstance(file, str): if isinstance(file, str):
file = Path(file) file = Path(file)
@ -166,7 +166,7 @@ class File(metaclass=NoPublicConstructor):
thumbs: Optional[Sequence[abcs.PhotoSize]], thumbs: Optional[Sequence[abcs.PhotoSize]],
raw: Optional[abcs.MessageMedia | abcs.Photo | abcs.Document], raw: Optional[abcs.MessageMedia | abcs.Photo | abcs.Document],
client: Optional[Client], client: Optional[Client],
): ) -> None:
self._attributes = attributes self._attributes = attributes
self._size = size self._size = size
self._name = name self._name = name

View File

@ -25,7 +25,7 @@ class InlineResult(metaclass=NoPublicConstructor):
results: types.messages.BotResults, results: types.messages.BotResults,
result: types.BotInlineMediaResult | types.BotInlineResult, result: types.BotInlineMediaResult | types.BotInlineResult,
default_peer: Optional[PeerRef], default_peer: Optional[PeerRef],
): ) -> None:
self._client = client self._client = client
self._raw_results = results self._raw_results = results
self._raw = result self._raw = result

View File

@ -1,5 +1,5 @@
try: try:
import cryptg import cryptg # type: ignore [import-untyped]
def ige_encrypt( def ige_encrypt(
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
@ -18,7 +18,7 @@ try:
) )
except ImportError: except ImportError:
import pyaes import pyaes # type: ignore [import-untyped]
def ige_encrypt( def ige_encrypt(
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes

View File

@ -3,7 +3,7 @@ from .authentication import step1 as auth_step1
from .authentication import step2 as auth_step2 from .authentication import step2 as auth_step2
from .authentication import step3 as auth_step3 from .authentication import step3 as auth_step3
from .mtp import ( from .mtp import (
BadMessage, BadMessageError,
Deserialization, Deserialization,
Encrypted, Encrypted,
MsgId, MsgId,
@ -13,7 +13,14 @@ from .mtp import (
RpcResult, RpcResult,
Update, Update,
) )
from .transport import Abridged, BadStatus, Full, Intermediate, MissingBytes, Transport from .transport import (
Abridged,
BadStatusError,
Full,
Intermediate,
MissingBytesError,
Transport,
)
from .utils import DEFAULT_COMPRESSION_THRESHOLD from .utils import DEFAULT_COMPRESSION_THRESHOLD
__all__ = [ __all__ = [
@ -25,7 +32,7 @@ __all__ = [
"auth_step1", "auth_step1",
"auth_step2", "auth_step2",
"auth_step3", "auth_step3",
"BadMessage", "BadMessageError",
"Deserialization", "Deserialization",
"Encrypted", "Encrypted",
"MsgId", "MsgId",
@ -35,10 +42,10 @@ __all__ = [
"RpcResult", "RpcResult",
"Update", "Update",
"Abridged", "Abridged",
"BadStatus", "BadStatusError",
"Full", "Full",
"Intermediate", "Intermediate",
"MissingBytes", "MissingBytesError",
"Transport", "Transport",
"DEFAULT_COMPRESSION_THRESHOLD", "DEFAULT_COMPRESSION_THRESHOLD",
] ]

View File

@ -310,4 +310,4 @@ def check_new_nonce_hash(got: int, expected: int) -> None:
def check_g_in_range(value: int, low: int, high: int) -> None: def check_g_in_range(value: int, low: int, high: int) -> None:
if not (low < value < high): if not (low < value < high):
raise ValueError(f"g parameter {value} not in range({low+1}, {high})") raise ValueError(f"g parameter {value} not in range({low + 1}, {high})")

View File

@ -1,11 +1,19 @@
from .encrypted import Encrypted from .encrypted import Encrypted
from .plain import Plain from .plain import Plain
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult, Update from .types import (
BadMessageError,
Deserialization,
MsgId,
Mtp,
RpcError,
RpcResult,
Update,
)
__all__ = [ __all__ = [
"Encrypted", "Encrypted",
"Plain", "Plain",
"BadMessage", "BadMessageError",
"Deserialization", "Deserialization",
"MsgId", "MsgId",
"Mtp", "Mtp",

View File

@ -60,7 +60,15 @@ from ..utils import (
gzip_decompress, gzip_decompress,
message_requires_ack, message_requires_ack,
) )
from .types import BadMessage, Deserialization, MsgId, Mtp, RpcError, RpcResult, Update from .types import (
BadMessageError,
Deserialization,
MsgId,
Mtp,
RpcError,
RpcResult,
Update,
)
NUM_FUTURE_SALTS = 64 NUM_FUTURE_SALTS = 64
@ -269,7 +277,7 @@ class Encrypted(Mtp):
bad_msg = AbcBadMsgNotification.from_bytes(message.body) bad_msg = AbcBadMsgNotification.from_bytes(message.body)
assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification)) assert isinstance(bad_msg, (BadServerSalt, BadMsgNotification))
exc = BadMessage(msg_id=MsgId(bad_msg.bad_msg_id), code=bad_msg.error_code) exc = BadMessageError(msg_id=MsgId(bad_msg.bad_msg_id), code=bad_msg.error_code)
if bad_msg.bad_msg_id == self._salt_request_msg_id: if bad_msg.bad_msg_id == self._salt_request_msg_id:
# Response to internal request, do not propagate. # Response to internal request, do not propagate.

View File

@ -15,7 +15,7 @@ class Update:
__slots__ = ("body",) __slots__ = ("body",)
def __init__(self, body: bytes | bytearray | memoryview): def __init__(self, body: bytes | bytearray | memoryview) -> None:
self.body = body self.body = body
@ -26,7 +26,7 @@ class RpcResult:
__slots__ = ("msg_id", "body") __slots__ = ("msg_id", "body")
def __init__(self, msg_id: MsgId, body: bytes | bytearray | memoryview): def __init__(self, msg_id: MsgId, body: bytes | bytearray | memoryview) -> None:
self.msg_id = msg_id self.msg_id = msg_id
self.body = body self.body = body
@ -142,7 +142,7 @@ RETRYABLE_MSG_IDS = {16, 17, 48}
NON_FATAL_MSG_IDS = RETRYABLE_MSG_IDS & {32, 33} NON_FATAL_MSG_IDS = RETRYABLE_MSG_IDS & {32, 33}
class BadMessage(ValueError): class BadMessageError(ValueError):
def __init__( def __init__(
self, self,
*args: object, *args: object,
@ -178,7 +178,7 @@ class BadMessage(ValueError):
return self._code == other._code return self._code == other._code
Deserialization = Update | RpcResult | RpcError | BadMessage Deserialization = Update | RpcResult | RpcError | BadMessageError
# https://core.telegram.org/mtproto/description # https://core.telegram.org/mtproto/description

View File

@ -1,6 +1,13 @@
from .abcs import BadStatus, MissingBytes, Transport from .abcs import BadStatusError, MissingBytesError, Transport
from .abridged import Abridged from .abridged import Abridged
from .full import Full from .full import Full
from .intermediate import Intermediate from .intermediate import Intermediate
__all__ = ["BadStatus", "MissingBytes", "Transport", "Abridged", "Full", "Intermediate"] __all__ = [
"BadStatusError",
"MissingBytesError",
"Transport",
"Abridged",
"Full",
"Intermediate",
]

View File

@ -16,12 +16,12 @@ class Transport(ABC):
pass pass
class MissingBytes(ValueError): class MissingBytesError(ValueError):
def __init__(self, *, expected: int, got: int) -> None: def __init__(self, *, expected: int, got: int) -> None:
super().__init__(f"missing bytes, expected: {expected}, got: {got}") super().__init__(f"missing bytes, expected: {expected}, got: {got}")
class BadStatus(ValueError): class BadStatusError(ValueError):
def __init__(self, *, status: int) -> None: def __init__(self, *, status: int) -> None:
super().__init__(f"transport reported bad status: {status}") super().__init__(f"transport reported bad status: {status}")
self.status = status self.status = status

View File

@ -1,6 +1,6 @@
import struct import struct
from .abcs import BadStatus, MissingBytes, OutFn, Transport from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
class Abridged(Transport): class Abridged(Transport):
@ -38,25 +38,25 @@ class Abridged(Transport):
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int: def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if not input: if not input:
raise MissingBytes(expected=1, got=0) raise MissingBytesError(expected=1, got=0)
length = input[0] length = input[0]
if 1 < length < 127: if 1 < length < 127:
header_len = 1 header_len = 1
elif len(input) < 4: elif len(input) < 4:
raise MissingBytes(expected=4, got=len(input)) raise MissingBytesError(expected=4, got=len(input))
else: else:
header_len = 4 header_len = 4
length = struct.unpack_from("<i", input)[0] >> 8 length = struct.unpack_from("<i", input)[0] >> 8
if length <= 0: if length <= 0:
if length < 0: if length < 0:
raise BadStatus(status=-length) raise BadStatusError(status=-length)
raise ValueError(f"bad length, expected > 0, got: {length}") raise ValueError(f"bad length, expected > 0, got: {length}")
length *= 4 length *= 4
if len(input) < header_len + length: if len(input) < header_len + length:
raise MissingBytes(expected=header_len + length, got=len(input)) raise MissingBytesError(expected=header_len + length, got=len(input))
output += memoryview(input)[header_len : header_len + length] output += memoryview(input)[header_len : header_len + length]
return header_len + length return header_len + length

View File

@ -1,7 +1,7 @@
import struct import struct
from zlib import crc32 from zlib import crc32
from .abcs import BadStatus, MissingBytes, OutFn, Transport from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
class Full(Transport): class Full(Transport):
@ -37,17 +37,17 @@ class Full(Transport):
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int: def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if len(input) < 4: if len(input) < 4:
raise MissingBytes(expected=4, got=len(input)) raise MissingBytesError(expected=4, got=len(input))
length = struct.unpack_from("<i", input)[0] length = struct.unpack_from("<i", input)[0]
assert isinstance(length, int) assert isinstance(length, int)
if length < 12: if length < 12:
if length < 0: if length < 0:
raise BadStatus(status=-length) raise BadStatusError(status=-length)
raise ValueError(f"bad length, expected > 12, got: {length}") raise ValueError(f"bad length, expected > 12, got: {length}")
if len(input) < length: if len(input) < length:
raise MissingBytes(expected=length, got=len(input)) raise MissingBytesError(expected=length, got=len(input))
seq = struct.unpack_from("<i", input, 4)[0] seq = struct.unpack_from("<i", input, 4)[0]
if seq != self._recv_seq: if seq != self._recv_seq:

View File

@ -1,6 +1,6 @@
import struct import struct
from .abcs import BadStatus, MissingBytes, OutFn, Transport from .abcs import BadStatusError, MissingBytesError, OutFn, Transport
class Intermediate(Transport): class Intermediate(Transport):
@ -34,19 +34,19 @@ class Intermediate(Transport):
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int: def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if len(input) < 4: if len(input) < 4:
raise MissingBytes(expected=4, got=len(input)) raise MissingBytesError(expected=4, got=len(input))
length = struct.unpack_from("<i", input)[0] length = struct.unpack_from("<i", input)[0]
assert isinstance(length, int) assert isinstance(length, int)
if len(input) < length: if len(input) < length:
raise MissingBytes(expected=length, got=len(input)) raise MissingBytesError(expected=length, got=len(input))
if length <= 4: if length <= 4:
if ( if (
length >= 4 length >= 4
and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0 and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0
): ):
raise BadStatus(status=-status) raise BadStatusError(status=-status)
raise ValueError(f"bad length, expected > 0, got: {length}") raise ValueError(f"bad length, expected > 0, got: {length}")

View File

@ -10,9 +10,9 @@ from typing import Generic, Optional, Protocol, Self, Type, TypeVar
from ..crypto import AuthKey from ..crypto import AuthKey
from ..mtproto import ( from ..mtproto import (
BadMessage, BadMessageError,
Encrypted, Encrypted,
MissingBytes, MissingBytesError,
MsgId, MsgId,
Mtp, Mtp,
Plain, Plain,
@ -133,7 +133,7 @@ class NotSerialized(RequestState):
class Serialized(RequestState): class Serialized(RequestState):
__slots__ = ("msg_id", "container_msg_id") __slots__ = ("msg_id", "container_msg_id")
def __init__(self, msg_id: MsgId): def __init__(self, msg_id: MsgId) -> None:
self.msg_id = msg_id self.msg_id = msg_id
self.container_msg_id = msg_id self.container_msg_id = msg_id
@ -141,7 +141,7 @@ class Serialized(RequestState):
class Sent(RequestState): class Sent(RequestState):
__slots__ = ("msg_id", "container_msg_id") __slots__ = ("msg_id", "container_msg_id")
def __init__(self, msg_id: MsgId, container_msg_id: MsgId): def __init__(self, msg_id: MsgId, container_msg_id: MsgId) -> None:
self.msg_id = msg_id self.msg_id = msg_id
self.container_msg_id = container_msg_id self.container_msg_id = container_msg_id
@ -298,7 +298,7 @@ class Sender:
self._mtp_buffer.clear() self._mtp_buffer.clear()
try: try:
n = self._transport.unpack(self._read_buffer, self._mtp_buffer) n = self._transport.unpack(self._read_buffer, self._mtp_buffer)
except MissingBytes: except MissingBytesError:
break break
else: else:
del self._read_buffer[:n] del self._read_buffer[:n]
@ -403,7 +403,7 @@ class Sender:
result, result,
) )
def _process_bad_message(self, result: BadMessage) -> None: def _process_bad_message(self, result: BadMessageError) -> None:
for req in self._drain_requests(result.msg_id): for req in self._drain_requests(result.msg_id):
if result.retryable: if result.retryable:
self._logger.log( self._logger.log(

View File

@ -11,7 +11,7 @@ from .message_box import (
BOT_CHANNEL_DIFF_LIMIT, BOT_CHANNEL_DIFF_LIMIT,
NO_UPDATES_TIMEOUT, NO_UPDATES_TIMEOUT,
USER_CHANNEL_DIFF_LIMIT, USER_CHANNEL_DIFF_LIMIT,
Gap, GapError,
MessageBox, MessageBox,
PossibleGap, PossibleGap,
PrematureEndReason, PrematureEndReason,
@ -32,7 +32,7 @@ __all__ = [
"BOT_CHANNEL_DIFF_LIMIT", "BOT_CHANNEL_DIFF_LIMIT",
"NO_UPDATES_TIMEOUT", "NO_UPDATES_TIMEOUT",
"USER_CHANNEL_DIFF_LIMIT", "USER_CHANNEL_DIFF_LIMIT",
"Gap", "GapError",
"MessageBox", "MessageBox",
"PossibleGap", "PossibleGap",
"PrematureEndReason", "PrematureEndReason",

View File

@ -9,7 +9,7 @@ PeerRefType: TypeAlias = Type[UserRef] | Type[ChannelRef] | Type[GroupRef]
class ChatHashCache: class ChatHashCache:
__slots__ = ("_hash_map", "_self_id", "_self_bot") __slots__ = ("_hash_map", "_self_id", "_self_bot")
def __init__(self, self_user: Optional[tuple[int, bool]]): def __init__(self, self_user: Optional[tuple[int, bool]]) -> None:
self._hash_map: dict[int, tuple[PeerRefType, int]] = {} self._hash_map: dict[int, tuple[PeerRefType, int]] = {}
self._self_id = self_user[0] if self_user else None self._self_id = self_user[0] if self_user else None
self._self_bot = self_user[1] if self_user else False self._self_bot = self_user[1] if self_user else False

View File

@ -2,7 +2,7 @@ from .defs import (
BOT_CHANNEL_DIFF_LIMIT, BOT_CHANNEL_DIFF_LIMIT,
NO_UPDATES_TIMEOUT, NO_UPDATES_TIMEOUT,
USER_CHANNEL_DIFF_LIMIT, USER_CHANNEL_DIFF_LIMIT,
Gap, GapError,
PossibleGap, PossibleGap,
PrematureEndReason, PrematureEndReason,
PtsInfo, PtsInfo,
@ -14,7 +14,7 @@ __all__ = [
"BOT_CHANNEL_DIFF_LIMIT", "BOT_CHANNEL_DIFF_LIMIT",
"NO_UPDATES_TIMEOUT", "NO_UPDATES_TIMEOUT",
"USER_CHANNEL_DIFF_LIMIT", "USER_CHANNEL_DIFF_LIMIT",
"Gap", "GapError",
"PossibleGap", "PossibleGap",
"PrematureEndReason", "PrematureEndReason",
"PtsInfo", "PtsInfo",

View File

@ -2,7 +2,7 @@ from typing import Optional
from ...tl import abcs, types from ...tl import abcs, types
from ..chat import ChatHashCache from ..chat import ChatHashCache
from .defs import ENTRY_ACCOUNT, ENTRY_SECRET, NO_SEQ, Gap, PtsInfo from .defs import ENTRY_ACCOUNT, ENTRY_SECRET, NO_SEQ, GapError, PtsInfo
def updates_(updates: types.Updates) -> types.UpdatesCombined: def updates_(updates: types.Updates) -> types.UpdatesCombined:
@ -147,7 +147,7 @@ def update_short_sent_message(
def adapt(updates: abcs.Updates, chat_hashes: ChatHashCache) -> types.UpdatesCombined: def adapt(updates: abcs.Updates, chat_hashes: ChatHashCache) -> types.UpdatesCombined:
if isinstance(updates, types.UpdatesTooLong): if isinstance(updates, types.UpdatesTooLong):
raise Gap raise GapError
elif isinstance(updates, types.UpdateShortMessage): elif isinstance(updates, types.UpdateShortMessage):
return update_short_message(updates, chat_hashes.self_id) return update_short_message(updates, chat_hashes.self_id)
elif isinstance(updates, types.UpdateShortChatMessage): elif isinstance(updates, types.UpdateShortChatMessage):

View File

@ -80,6 +80,6 @@ class PrematureEndReason(Enum):
BANNED = "ban" BANNED = "ban"
class Gap(ValueError): class GapError(ValueError):
def __repr__(self) -> str: def __repr__(self) -> str:
return "Gap()" return "Gap()"

View File

@ -20,7 +20,7 @@ from .defs import (
POSSIBLE_GAP_TIMEOUT, POSSIBLE_GAP_TIMEOUT,
USER_CHANNEL_DIFF_LIMIT, USER_CHANNEL_DIFF_LIMIT,
Entry, Entry,
Gap, GapError,
PossibleGap, PossibleGap,
PrematureEndReason, PrematureEndReason,
State, State,
@ -252,7 +252,7 @@ class MessageBox:
) )
if can_recover: if can_recover:
self.try_begin_get_diff(ENTRY_ACCOUNT, "missing hash") self.try_begin_get_diff(ENTRY_ACCOUNT, "missing hash")
raise Gap raise GapError
# https://core.telegram.org/api/updates # https://core.telegram.org/api/updates
def process_updates( def process_updates(
@ -281,7 +281,7 @@ class MessageBox:
return result, combined.users, combined.chats return result, combined.users, combined.chats
elif self.seq + 1 < combined.seq_start: elif self.seq + 1 < combined.seq_start:
self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap") self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap")
raise Gap raise GapError
def update_sort_key(update: abcs.Update) -> int: def update_sort_key(update: abcs.Update) -> int:
pts = pts_info_from_update(update) pts = pts_info_from_update(update)

View File

@ -168,7 +168,7 @@ class Session:
dcs: Optional[list[DataCenter]] = None, dcs: Optional[list[DataCenter]] = None,
user: Optional[User] = None, user: Optional[User] = None,
state: Optional[UpdateState] = None, state: Optional[UpdateState] = None,
): ) -> None:
self.dcs = dcs or [] self.dcs = dcs or []
"List of known data-centers." "List of known data-centers."
self.user = user self.user = user

View File

@ -13,7 +13,7 @@ class MemorySession(Storage):
__slots__ = ("session",) __slots__ = ("session",)
def __init__(self, session: Optional[Session] = None): def __init__(self, session: Optional[Session] = None) -> None:
self.session = session self.session = session
async def load(self) -> Optional[Session]: async def load(self) -> Optional[Session]:

View File

@ -20,7 +20,7 @@ class SqliteSession(Storage):
an VCS by accident (adding ``*.session`` to ``.gitignore`` will catch them). an VCS by accident (adding ``*.session`` to ``.gitignore`` will catch them).
""" """
def __init__(self, file: str | Path): def __init__(self, file: str | Path) -> None:
path = Path(file) path = Path(file)
if not path.suffix: if not path.suffix:
path = path.with_suffix(EXTENSION) path = path.with_suffix(EXTENSION)

View File

@ -26,11 +26,11 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
raise RuntimeError( raise RuntimeError(
"generated api and mtproto schemas cannot have colliding constructor identifiers" "generated api and mtproto schemas cannot have colliding constructor identifiers"
) )
ALL_TYPES = API_TYPES | MTPROTO_TYPES all_types = API_TYPES | MTPROTO_TYPES
# Signatures don't fully match, but this is a private method # Signatures don't fully match, but this is a private method
# and all previous uses are compatible with `dict.get`. # and all previous uses are compatible with `dict.get`.
Reader._get_ty = ALL_TYPES.get # type: ignore [assignment] Reader._get_ty = all_types.get # type: ignore [assignment]
return Reader._get_ty(constructor_id) return Reader._get_ty(constructor_id)

View File

@ -17,9 +17,9 @@ def _bootstrap_get_deserializer(
raise RuntimeError( raise RuntimeError(
"generated api and mtproto schemas cannot have colliding constructor identifiers" "generated api and mtproto schemas cannot have colliding constructor identifiers"
) )
ALL_DESER = API_DESER | MTPROTO_DESER all_deser = API_DESER | MTPROTO_DESER
Request._get_deserializer = ALL_DESER.get # type: ignore [assignment] Request._get_deserializer = all_deser.get # type: ignore [assignment]
return Request._get_deserializer(constructor_id) return Request._get_deserializer(constructor_id)

View File

@ -19,6 +19,7 @@ and those you can define when using :meth:`telethon.Client.send_message`:
buttons.Callback('Demo', b'data') buttons.Callback('Demo', b'data')
]) ])
""" """
from .._impl.client.types.buttons import ( from .._impl.client.types.buttons import (
Callback, Callback,
RequestGeoLocation, RequestGeoLocation,

View File

@ -2,6 +2,7 @@ import struct
from typing import Optional from typing import Optional
from pytest import raises from pytest import raises
from telethon._impl.crypto import AuthKey from telethon._impl.crypto import AuthKey
from telethon._impl.mtproto import Encrypted, Plain, RpcError from telethon._impl.mtproto import Encrypted, Plain, RpcError
from telethon._impl.mtproto.mtp.types import MsgId from telethon._impl.mtproto.mtp.types import MsgId

View File

@ -2,6 +2,7 @@ import asyncio
import logging import logging
from pytest import LogCaptureFixture, mark from pytest import LogCaptureFixture, mark
from telethon._impl.mtproto import Full from telethon._impl.mtproto import Full
from telethon._impl.mtsender import connect from telethon._impl.mtsender import connect
from telethon._impl.tl import LAYER, abcs, functions, types from telethon._impl.tl import LAYER, abcs, functions, types

View File

@ -1,6 +1,7 @@
import inspect import inspect
from pytest import raises from pytest import raises
from telethon._impl.session import ChannelRef, GroupRef, PeerRef, UserRef from telethon._impl.session import ChannelRef, GroupRef, PeerRef, UserRef
USER = UserRef(12, 34) USER = UserRef(12, 34)

View File

@ -1,6 +1,7 @@
import struct import struct
from pytest import mark from pytest import mark
from telethon._impl.tl.core import Reader, Serializable from telethon._impl.tl.core import Reader, Serializable
from telethon._impl.tl.mtproto.types import BadServerSalt from telethon._impl.tl.mtproto.types import BadServerSalt
from telethon._impl.tl.types import GeoPoint from telethon._impl.tl.types import GeoPoint

View File

@ -1,4 +1,5 @@
from rsa import PublicKey from rsa import PublicKey
from telethon._impl.crypto.rsa import ( from telethon._impl.crypto.rsa import (
PRODUCTION_RSA_KEY, PRODUCTION_RSA_KEY,
TESTMODE_RSA_KEY, TESTMODE_RSA_KEY,

View File

@ -1,4 +1,5 @@
from pytest import mark from pytest import mark
from telethon._impl.tl.core import serialize_bytes_to from telethon._impl.tl.core import serialize_bytes_to

View File

@ -1,4 +1,5 @@
from pytest import raises from pytest import raises
from telethon._impl.mtproto import Abridged from telethon._impl.mtproto import Abridged

View File

@ -1,4 +1,5 @@
from pytest import raises from pytest import raises
from telethon._impl.mtproto import Full from telethon._impl.mtproto import Full

View File

@ -1,4 +1,5 @@
from pytest import raises from pytest import raises
from telethon._impl.mtproto import Intermediate from telethon._impl.mtproto import Intermediate

View File

@ -1,4 +1,5 @@
from pytest import mark, raises from pytest import mark, raises
from telethon._impl.crypto.two_factor_auth import ( from telethon._impl.crypto.two_factor_auth import (
calculate_2fa, calculate_2fa,
check_p_prime_and_subgroup, check_p_prime_and_subgroup,

View File

@ -1,4 +1,5 @@
from pytest import mark from pytest import mark
from telethon._impl.client.types import AdminRight from telethon._impl.client.types import AdminRight
from telethon._impl.tl import types from telethon._impl.tl import types

View File

@ -37,7 +37,11 @@ build-backend = "setuptools.build_meta"
[tool.setuptools.dynamic] [tool.setuptools.dynamic]
version = {attr = "telethon_generator.version.__version__"} version = {attr = "telethon_generator.version.__version__"}
[tool.ruff] [tool.ruff.lint]
select = ["F", "E", "W", "I", "N", "ANN"]
ignore = [ ignore = [
"E501", # formatter takes care of lines that are too long besides documentation "E501", # formatter takes care of lines that are too long besides documentation
"ANN101", # Missing type annotation for `self` in method
"ANN102", # Missing type annotation for `cls` in classmethod
"ANN401", # Dynamically typed expressions (typing.Any) are not type checked
] ]

View File

@ -7,7 +7,7 @@ from .tl import (
NormalParameter, NormalParameter,
Parameter, Parameter,
Type, Type,
TypeDefNotImplemented, TypeDefNotImplementedError,
) )
from .tl_iterator import FunctionDef, TypeDef from .tl_iterator import FunctionDef, TypeDef
from .tl_iterator import iterate as parse_tl_file from .tl_iterator import iterate as parse_tl_file
@ -19,7 +19,7 @@ __all__ = [
"Definition", "Definition",
"Flag", "Flag",
"Parameter", "Parameter",
"TypeDefNotImplemented", "TypeDefNotImplementedError",
"BaseParameter", "BaseParameter",
"FlagsParameter", "FlagsParameter",
"NormalParameter", "NormalParameter",

View File

@ -1,6 +1,6 @@
from .definition import Definition from .definition import Definition
from .flag import Flag from .flag import Flag
from .parameter import Parameter, TypeDefNotImplemented from .parameter import Parameter, TypeDefNotImplementedError
from .parameter_type import BaseParameter, FlagsParameter, NormalParameter from .parameter_type import BaseParameter, FlagsParameter, NormalParameter
from .ty import Type from .ty import Type
@ -8,7 +8,7 @@ __all__ = [
"Definition", "Definition",
"Flag", "Flag",
"Parameter", "Parameter",
"TypeDefNotImplemented", "TypeDefNotImplementedError",
"BaseParameter", "BaseParameter",
"FlagsParameter", "FlagsParameter",
"NormalParameter", "NormalParameter",

View File

@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import Self from typing import Self
from ..utils import infer_id from ..utils import infer_id
from .parameter import Parameter, TypeDefNotImplemented from .parameter import Parameter, TypeDefNotImplementedError
from .parameter_type import FlagsParameter, NormalParameter from .parameter_type import FlagsParameter, NormalParameter
from .ty import Type from .ty import Type
@ -65,7 +65,7 @@ class Definition:
for param_str in middle.split(): for param_str in middle.split():
try: try:
param = Parameter.from_str(param_str) param = Parameter.from_str(param_str)
except TypeDefNotImplemented as e: except TypeDefNotImplementedError as e:
type_defs.append(e.name) type_defs.append(e.name)
continue continue

View File

@ -4,8 +4,8 @@ from typing import Self
from .parameter_type import BaseParameter from .parameter_type import BaseParameter
class TypeDefNotImplemented(NotImplementedError): class TypeDefNotImplementedError(NotImplementedError):
def __init__(self, name: str): def __init__(self, name: str) -> None:
super().__init__(f"typedef not implemented: {name}") super().__init__(f"typedef not implemented: {name}")
self.name = name self.name = name
@ -19,7 +19,7 @@ class Parameter:
def from_str(cls, param: str) -> Self: def from_str(cls, param: str) -> Self:
if param.startswith("{"): if param.startswith("{"):
if param.endswith(":Type}"): if param.endswith(":Type}"):
raise TypeDefNotImplemented(param[1 : param.index(":")]) raise TypeDefNotImplementedError(param[1 : param.index(":")])
else: else:
raise ValueError("missing def") raise ValueError("missing def")

View File

@ -9,7 +9,7 @@ from .._impl.tl_parser import (
ParsedTl, ParsedTl,
Type, Type,
TypeDef, TypeDef,
TypeDefNotImplemented, TypeDefNotImplementedError,
load_tl_file, load_tl_file,
parse_tl_file, parse_tl_file,
) )
@ -19,7 +19,7 @@ __all__ = [
"Flag", "Flag",
"Parameter", "Parameter",
"ParsedTl", "ParsedTl",
"TypeDefNotImplemented", "TypeDefNotImplementedError",
"BaseParameter", "BaseParameter",
"FlagsParameter", "FlagsParameter",
"NormalParameter", "NormalParameter",

View File

@ -1,4 +1,5 @@
from pytest import mark from pytest import mark
from telethon_generator._impl.codegen.serde.common import ( from telethon_generator._impl.codegen.serde.common import (
split_words, split_words,
to_class_name, to_class_name,

View File

@ -1,4 +1,5 @@
from pytest import mark, raises from pytest import mark, raises
from telethon_generator.tl_parser import ( from telethon_generator.tl_parser import (
Definition, Definition,
Flag, Flag,

View File

@ -1,4 +1,5 @@
from pytest import mark, raises from pytest import mark, raises
from telethon_generator.tl_parser import ( from telethon_generator.tl_parser import (
Flag, Flag,
FlagsParameter, FlagsParameter,

View File

@ -1,4 +1,5 @@
from pytest import raises from pytest import raises
from telethon_generator.tl_parser import FunctionDef, TypeDef, parse_tl_file from telethon_generator.tl_parser import FunctionDef, TypeDef, parse_tl_file

View File

@ -1,6 +1,7 @@
from typing import Optional from typing import Optional
from pytest import mark, raises from pytest import mark, raises
from telethon_generator.tl_parser import Type from telethon_generator.tl_parser import Type

View File

@ -1,6 +1,7 @@
""" """
Check formatting, type-check and run offline tests. Check formatting, type-check and run offline tests.
""" """
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
@ -15,9 +16,7 @@ def run(*args: str) -> int:
def main() -> None: def main() -> None:
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
exit( exit(
run("isort", ".", "-c", "--profile", "black", "--gitignore") run("mypy", "--strict", ".")
or run("black", ".", "--check", "--extend-exclude", BLACK_IGNORE)
or run("mypy", "--strict", ".")
or run("ruff", "check", ".") or run("ruff", "check", ".")
or run("sphinx", "-M", "dummy", "client/doc", tmp_dir, "-n", "-W") or run("sphinx", "-M", "dummy", "client/doc", tmp_dir, "-n", "-W")
or run("pytest", ".", "-m", "not net") or run("pytest", ".", "-m", "not net")

View File

@ -2,6 +2,7 @@
Run `telethon_generator.codegen` on both `api.tl` and `mtproto.tl` to output Run `telethon_generator.codegen` on both `api.tl` and `mtproto.tl` to output
corresponding Python code in the default directories under the `client/`. corresponding Python code in the default directories under the `client/`.
""" """
import subprocess import subprocess
import sys import sys

View File

@ -110,13 +110,13 @@ def main() -> None:
function.args.args[0].annotation = None function.args.args[0].annotation = None
if isinstance(function, ast.AsyncFunctionDef): if isinstance(function, ast.AsyncFunctionDef):
call = ast.Await(value=call) call = ast.Await(value=call) # type: ignore [arg-type]
match function.returns: match function.returns:
case ast.Constant(value=None): case ast.Constant(value=None):
call = ast.Expr(value=call) call = ast.Expr(value=call) # type: ignore [arg-type]
case _: case _:
call = ast.Return(value=call) call = ast.Return(value=call) # type: ignore [arg-type]
function.body.append(call) function.body.append(call)
class_body.append(function) class_body.append(function)

View File

@ -1,6 +1,7 @@
""" """
Run `sphinx-build` to create HTML documentation and detect errors. Run `sphinx-build` to create HTML documentation and detect errors.
""" """
import subprocess import subprocess
import sys import sys

View File

@ -1,6 +1,7 @@
""" """
Sort imports and format code. Sort imports and format code.
""" """
import subprocess import subprocess
import sys import sys