From 053bd9c02bc852c58c7975d6221fff88d6297bde Mon Sep 17 00:00:00 2001 From: Jahongir Qurbonov Date: Sun, 6 Oct 2024 14:02:23 +0500 Subject: [PATCH] Rollback line length --- client/doc/roles/tl.py | 4 +- client/pyproject.toml | 3 - .../src/telethon/_impl/client/client/auth.py | 20 ++- .../src/telethon/_impl/client/client/bots.py | 14 ++- .../src/telethon/_impl/client/client/chats.py | 35 ++++-- .../telethon/_impl/client/client/client.py | 64 +++++++--- .../telethon/_impl/client/client/dialogs.py | 8 +- .../src/telethon/_impl/client/client/files.py | 40 ++++-- .../telethon/_impl/client/client/messages.py | 88 +++++++++++--- .../src/telethon/_impl/client/client/net.py | 35 ++++-- .../telethon/_impl/client/client/updates.py | 15 ++- .../src/telethon/_impl/client/client/users.py | 14 ++- client/src/telethon/_impl/client/errors.py | 12 +- .../src/telethon/_impl/client/events/event.py | 8 +- .../client/events/filters/combinators.py | 8 +- .../telethon/_impl/client/events/messages.py | 32 +++-- .../telethon/_impl/client/events/queries.py | 14 ++- .../src/telethon/_impl/client/parsers/html.py | 11 +- .../telethon/_impl/client/parsers/markdown.py | 24 +++- .../telethon/_impl/client/parsers/strings.py | 6 +- .../_impl/client/types/album_builder.py | 42 +++++-- .../_impl/client/types/buttons/button.py | 4 +- .../client/types/buttons/inline_button.py | 4 +- .../client/types/buttons/switch_inline.py | 4 +- .../_impl/client/types/chat_restriction.py | 4 +- .../src/telethon/_impl/client/types/dialog.py | 5 +- .../src/telethon/_impl/client/types/draft.py | 38 ++++-- .../src/telethon/_impl/client/types/file.py | 42 +++++-- .../telethon/_impl/client/types/keyboard.py | 13 +- .../telethon/_impl/client/types/message.py | 62 ++++++++-- .../src/telethon/_impl/client/types/meta.py | 13 +- .../_impl/client/types/participant.py | 24 +++- .../_impl/client/types/peer/__init__.py | 14 ++- .../telethon/_impl/client/types/peer/group.py | 12 +- client/src/telethon/_impl/crypto/aes.py | 30 +++-- client/src/telethon/_impl/crypto/auth_key.py | 6 +- client/src/telethon/_impl/crypto/crypto.py | 16 ++- client/src/telethon/_impl/crypto/rsa.py | 12 +- .../telethon/_impl/crypto/two_factor_auth.py | 4 +- .../telethon/_impl/mtproto/authentication.py | 33 +++-- .../telethon/_impl/mtproto/mtp/encrypted.py | 38 ++++-- .../src/telethon/_impl/mtproto/mtp/plain.py | 8 +- .../src/telethon/_impl/mtproto/mtp/types.py | 14 ++- .../_impl/mtproto/transport/intermediate.py | 5 +- client/src/telethon/_impl/mtproto/utils.py | 4 +- client/src/telethon/_impl/mtsender/sender.py | 16 ++- .../telethon/_impl/session/chat/hash_cache.py | 8 +- .../telethon/_impl/session/chat/peer_ref.py | 39 ++++-- .../_impl/session/message_box/adaptor.py | 8 +- .../_impl/session/message_box/defs.py | 8 +- .../_impl/session/message_box/messagebox.py | 115 ++++++++++++++---- .../telethon/_impl/session/storage/sqlite.py | 14 ++- client/src/telethon/_impl/tl/core/reader.py | 8 +- client/src/telethon/_impl/tl/core/request.py | 4 +- .../telethon/_impl/tl/core/serializable.py | 4 +- client/tests/auth_key_test.py | 15 ++- client/tests/crypto_test.py | 14 ++- client/tests/parsers_test.py | 10 +- generator/pyproject.toml | 3 - .../_impl/codegen/generator.py | 36 ++++-- .../_impl/codegen/serde/common.py | 4 +- .../_impl/codegen/serde/deserialization.py | 32 +++-- .../_impl/codegen/serde/serialization.py | 32 +++-- .../_impl/tl_parser/loader.py | 4 +- generator/tests/generator_test.py | 8 +- generator/tests/parameter_test.py | 8 +- generator/tests/ty_test.py | 4 +- 67 files changed, 995 insertions(+), 305 deletions(-) diff --git a/client/doc/roles/tl.py b/client/doc/roles/tl.py index ac6cd515..f48c0a18 100644 --- a/client/doc/roles/tl.py +++ b/client/doc/roles/tl.py @@ -15,7 +15,9 @@ def make_link_node(rawtext, app, name, options): base += "/" set_classes(options) - node = nodes.reference(rawtext, utils.unescape(name), refuri="{}?q={}".format(base, name), **options) + node = nodes.reference( + rawtext, utils.unescape(name), refuri="{}?q={}".format(base, name), **options + ) return node diff --git a/client/pyproject.toml b/client/pyproject.toml index 917e54bf..1f841bd4 100644 --- a/client/pyproject.toml +++ b/client/pyproject.toml @@ -52,9 +52,6 @@ backend-path = ["build_backend"] [tool.setuptools.dynamic] version = {attr = "telethon.version.__version__"} -[tool.ruff] -line-length = 120 - [tool.ruff.lint] select = ["F", "E", "W", "I"] ignore = [ diff --git a/client/src/telethon/_impl/client/client/auth.py b/client/src/telethon/_impl/client/client/auth.py index 2dfd05ba..7220a4ec 100644 --- a/client/src/telethon/_impl/client/client/auth.py +++ b/client/src/telethon/_impl/client/client/auth.py @@ -31,7 +31,9 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User: assert isinstance(auth, types.auth.Authorization) assert isinstance(auth.user, types.User) user = User._from_raw(auth.user) - client._session.user = SessionUser(id=user.id, dc=client._sender.dc_id, bot=user.bot, username=user.username) + client._session.user = SessionUser( + id=user.id, dc=client._sender.dc_id, bot=user.bot, username=user.username + ) client._chat_hashes.set_self_user(user.id, user.bot) @@ -54,7 +56,9 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User: async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: assert dc_id is not None - sender, client._session.dcs = await connect_sender(client._config, client._session.dcs, DataCenter(id=dc_id)) + sender, client._session.dcs = await connect_sender( + client._config, client._session.dcs, DataCenter(id=dc_id) + ) async with client._sender_lock: client._sender = sender @@ -173,7 +177,9 @@ async def interactive_login( user = await self.check_password(user_or_token, password) else: while True: - print("Please enter your password (prompt is hidden; type and press enter)") + print( + "Please enter your password (prompt is hidden; type and press enter)" + ) password = getpass.getpass(": ") try: user = await self.check_password(user_or_token, password) @@ -202,9 +208,13 @@ async def get_password_information(client: Client) -> PasswordToken: return PasswordToken._new(result) -async def check_password(self: Client, token: PasswordToken, password: str | bytes) -> User: +async def check_password( + self: Client, token: PasswordToken, password: str | bytes +) -> User: algo = token._password.current_algo - if not isinstance(algo, types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow): + if not isinstance( + algo, types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow + ): raise RuntimeError("unrecognised 2FA algorithm") if not two_factor_auth.check_p_and_g(algo.p, algo.g): diff --git a/client/src/telethon/_impl/client/client/bots.py b/client/src/telethon/_impl/client/client/bots.py index 5464e836..10aa79d8 100644 --- a/client/src/telethon/_impl/client/client/bots.py +++ b/client/src/telethon/_impl/client/client/bots.py @@ -38,7 +38,11 @@ class InlineResults(metaclass=NoPublicConstructor): result = await self._client( functions.messages.get_inline_bot_results( bot=self._bot, - peer=(self._peer._to_input_peer() if self._peer else types.InputPeerEmpty()), + peer=( + self._peer._to_input_peer() + if self._peer + else types.InputPeerEmpty() + ), geo_point=None, query=self._query, offset=self._offset, @@ -47,8 +51,12 @@ class InlineResults(metaclass=NoPublicConstructor): assert isinstance(result, types.messages.BotResults) self._offset = result.next_offset for r in reversed(result.results): - assert isinstance(r, (types.BotInlineMediaResult, types.BotInlineResult)) - self._buffer.append(InlineResult._create(self._client, result, r, self._peer)) + assert isinstance( + r, (types.BotInlineMediaResult, types.BotInlineResult) + ) + self._buffer.append( + InlineResult._create(self._client, result, r, self._peer) + ) if not self._buffer: self._offset = None diff --git a/client/src/telethon/_impl/client/client/chats.py b/client/src/telethon/_impl/client/client/chats.py index 900d3b3a..9c507709 100644 --- a/client/src/telethon/_impl/client/client/chats.py +++ b/client/src/telethon/_impl/client/client/chats.py @@ -53,7 +53,9 @@ class ParticipantList(AsyncList[Participant]): seen_count = len(self._seen) for p in chanp.participants: - part = Participant._from_raw_channel(self._client, self._peer, p, chat_map) + part = Participant._from_raw_channel( + self._client, self._peer, p, chat_map + ) pid = part._peer_id() if pid not in self._seen: self._seen.add(pid) @@ -64,7 +66,9 @@ class ParticipantList(AsyncList[Participant]): self._done = len(self._seen) == seen_count else: - chatp = await self._client(functions.messages.get_full_chat(chat_id=self._peer._to_input_chat())) + chatp = await self._client( + functions.messages.get_full_chat(chat_id=self._peer._to_input_chat()) + ) assert isinstance(chatp, types.messages.ChatFull) assert isinstance(chatp.full_chat, types.ChatFull) @@ -83,14 +87,17 @@ class ParticipantList(AsyncList[Participant]): ) elif isinstance(participants, types.ChatParticipants): self._buffer.extend( - Participant._from_raw_chat(self._client, self._peer, p, chat_map) for p in participants.participants + Participant._from_raw_chat(self._client, self._peer, p, chat_map) + for p in participants.participants ) self._total = len(self._buffer) self._done = True -def get_participants(self: Client, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[Participant]: +def get_participants( + self: Client, chat: Group | Channel | GroupRef | ChannelRef, / +) -> AsyncList[Participant]: return ParticipantList(self, chat._ref) @@ -130,7 +137,9 @@ class RecentActionList(AsyncList[RecentAction]): self._offset = min(e.id for e in self._buffer) -def get_admin_log(self: Client, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[RecentAction]: +def get_admin_log( + self: Client, chat: Group | Channel | GroupRef | ChannelRef, / +) -> AsyncList[RecentAction]: return RecentActionList(self, chat._ref) @@ -165,7 +174,11 @@ class ProfilePhotoList(AsyncList[File]): else: raise RuntimeError("unexpected case") - self._buffer.extend(filter(None, (File._try_from_raw_photo(self._client, p) for p in photos))) + self._buffer.extend( + filter( + None, (File._try_from_raw_photo(self._client, p) for p in photos) + ) + ) def get_profile_photos(self: Client, peer: Peer | PeerRef, /) -> AsyncList[File]: @@ -243,7 +256,11 @@ async def set_chat_default_restrictions( *, until: Optional[datetime.datetime] = None, ) -> None: - banned_rights = ChatRestriction._set_to_raw(set(restrictions), int(until.timestamp()) if until else 0x7FFFFFFF) - await self( - functions.messages.edit_chat_default_banned_rights(peer=chat._ref._to_input_peer(), banned_rights=banned_rights) + banned_rights = ChatRestriction._set_to_raw( + set(restrictions), int(until.timestamp()) if until else 0x7FFFFFFF + ) + await self( + functions.messages.edit_chat_default_banned_rights( + peer=chat._ref._to_input_peer(), banned_rights=banned_rights + ) ) diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index 90a3f835..fd9cd1f9 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -237,7 +237,9 @@ class Client: lang_code=lang_code or "en", catch_up=catch_up or False, datacenter=datacenter, - flood_sleep_threshold=(60 if flood_sleep_threshold is None else flood_sleep_threshold), + flood_sleep_threshold=( + 60 if flood_sleep_threshold is None else flood_sleep_threshold + ), update_queue_limit=update_queue_limit, base_logger=base_logger, connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), @@ -248,8 +250,8 @@ class Client: self._message_box = MessageBox(base_logger=base_logger) self._chat_hashes = ChatHashCache(None) self._last_update_limit_warn: Optional[float] = None - self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = asyncio.Queue( - maxsize=self._config.update_queue_limit or 0 + self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = ( + asyncio.Queue(maxsize=self._config.update_queue_limit or 0) ) self._dispatcher: Optional[asyncio.Task[None]] = None self._handlers: dict[ @@ -409,7 +411,9 @@ class Client: """ await delete_dialog(self, dialog) - async def delete_messages(self, chat: Peer | PeerRef, /, message_ids: list[int], *, revoke: bool = True) -> int: + async def delete_messages( + self, chat: Peer | PeerRef, /, message_ids: list[int], *, revoke: bool = True + ) -> int: """ Delete messages. @@ -649,7 +653,9 @@ class Client: """ return await forward_messages(self, target, message_ids, source) - def get_admin_log(self, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[RecentAction]: + def get_admin_log( + self, chat: Group | Channel | GroupRef | ChannelRef, / + ) -> AsyncList[RecentAction]: """ Get the recent actions from the administrator's log. @@ -751,7 +757,9 @@ class Client: """ return get_file_bytes(self, media) - def get_handler_filter(self, handler: Callable[[Event], Awaitable[Any]], /) -> Optional[FilterType]: + def get_handler_filter( + self, handler: Callable[[Event], Awaitable[Any]], / + ) -> Optional[FilterType]: """ Get the filter associated to the given event handler. @@ -853,9 +861,13 @@ class Client: async for message in reversed(client.get_messages(chat)): print(message.sender.name, ':', message.markdown_text) """ - return get_messages(self, chat, limit, offset_id=offset_id, offset_date=offset_date) + return get_messages( + self, chat, limit, offset_id=offset_id, offset_date=offset_date + ) - def get_messages_with_ids(self, chat: Peer | PeerRef, /, message_ids: list[int]) -> AsyncList[Message]: + def get_messages_with_ids( + self, chat: Peer | PeerRef, /, message_ids: list[int] + ) -> AsyncList[Message]: """ Get the full message objects from the corresponding message identifiers. @@ -879,7 +891,9 @@ class Client: """ return get_messages_with_ids(self, chat, message_ids) - def get_participants(self, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[Participant]: + def get_participants( + self, chat: Group | Channel | GroupRef | ChannelRef, / + ) -> AsyncList[Participant]: """ Get the participants in a group or channel, along with their permissions. @@ -967,7 +981,9 @@ class Client: """ return await inline_query(self, bot, query, peer=peer) - async def interactive_login(self, phone_or_token: Optional[str] = None, *, password: Optional[str] = None) -> User: + async def interactive_login( + self, phone_or_token: Optional[str] = None, *, password: Optional[str] = None + ) -> User: """ Begin an interactive login if needed. If the account was already logged-in, this method simply returns :term:`yourself`. @@ -1019,7 +1035,9 @@ class Client: def on( self, event_cls: Type[Event], /, filter: Optional[FilterType] = None - ) -> Callable[[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]]: + ) -> Callable[ + [Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]] + ]: """ Register the decorated function to be invoked when the provided event type occurs. @@ -1107,7 +1125,9 @@ class Client: """ return prepare_album(self) - async def read_message(self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None: + async def read_message( + self, chat: Peer | PeerRef, /, message_id: int | Literal["all"] + ) -> None: """ Mark messages as read. @@ -1138,7 +1158,9 @@ class Client: """ await read_message(self, chat, message_id) - def remove_event_handler(self, handler: Callable[[Event], Awaitable[Any]], /) -> None: + def remove_event_handler( + self, handler: Callable[[Event], Awaitable[Any]], / + ) -> None: """ Remove the handler as a function to be called when events occur. This is simply the opposite of :meth:`add_event_handler`. @@ -1302,7 +1324,9 @@ class Client: async for message in client.search_all_messages(query='hello'): print(message.text) """ - return search_all_messages(self, limit, query=query, offset_id=offset_id, offset_date=offset_date) + return search_all_messages( + self, limit, query=query, offset_id=offset_id, offset_date=offset_date + ) def search_messages( self, @@ -1348,7 +1372,9 @@ class Client: async for message in client.search_messages(chat, query='hello'): print(message.text) """ - return search_messages(self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date) + return search_messages( + self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date + ) async def send_audio( self, @@ -1949,7 +1975,9 @@ class Client: :meth:`telethon.types.Participant.set_restrictions` """ - await set_participant_restrictions(self, chat, participant, restrictions, until=until) + await set_participant_restrictions( + self, chat, participant, restrictions, until=until + ) async def sign_in(self, token: LoginToken, code: str) -> User | PasswordToken: """ @@ -1996,7 +2024,9 @@ class Client: """ await sign_out(self) - async def unpin_message(self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None: + async def unpin_message( + self, chat: Peer | PeerRef, /, message_id: int | Literal["all"] + ) -> None: """ Unpin one or all messages from the top. diff --git a/client/src/telethon/_impl/client/client/dialogs.py b/client/src/telethon/_impl/client/client/dialogs.py index 1dba8752..2b3cf9f3 100644 --- a/client/src/telethon/_impl/client/client/dialogs.py +++ b/client/src/telethon/_impl/client/client/dialogs.py @@ -51,7 +51,9 @@ class DialogList(AsyncList[Dialog]): chat_map = build_chat_map(self._client, result.users, result.chats) msg_map = build_msg_map(self._client, result.messages, chat_map) - self._buffer.extend(Dialog._from_raw(self._client, d, chat_map, msg_map) for d in result.dialogs) + self._buffer.extend( + Dialog._from_raw(self._client, d, chat_map, msg_map) for d in result.dialogs + ) def get_dialogs(self: Client) -> AsyncList[Dialog]: @@ -128,7 +130,9 @@ async def edit_draft( reply_to: Optional[int] = None, ) -> Draft: peer = peer._ref - message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False) + message, entities = parse_message( + text=text, markdown=markdown, html=html, allow_empty=False + ) result = await self( functions.messages.save_draft( diff --git a/client/src/telethon/_impl/client/client/files.py b/client/src/telethon/_impl/client/client/files.py index a52c9c23..a0ec66ed 100644 --- a/client/src/telethon/_impl/client/client/files.py +++ b/client/src/telethon/_impl/client/client/files.py @@ -189,11 +189,15 @@ async def send_file( reply_to: Optional[int] = None, keyboard: Optional[KeyboardType] = None, ) -> Message: - message, entities = parse_message(text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True) + message, entities = parse_message( + text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True + ) # Re-send existing file. if isinstance(file, File): - return await do_send_file(self, chat, file._input_media, message, entities, reply_to, keyboard) + return await do_send_file( + self, chat, file._input_media, message, entities, reply_to, keyboard + ) # URLs are handled early as they can't use any other attributes either. input_media: abcs.InputMedia @@ -208,10 +212,16 @@ async def send_file( else: as_photo = False if as_photo: - input_media = types.InputMediaPhotoExternal(spoiler=False, url=file, ttl_seconds=None) + input_media = types.InputMediaPhotoExternal( + spoiler=False, url=file, ttl_seconds=None + ) else: - input_media = types.InputMediaDocumentExternal(spoiler=False, url=file, ttl_seconds=None) - return await do_send_file(self, chat, input_media, message, entities, reply_to, keyboard) + input_media = types.InputMediaDocumentExternal( + spoiler=False, url=file, ttl_seconds=None + ) + return await do_send_file( + self, chat, input_media, message, entities, reply_to, keyboard + ) input_file, name = await upload(self, file, size, name) @@ -275,7 +285,9 @@ async def send_file( ttl_seconds=None, ) - return await do_send_file(self, chat, input_media, message, entities, reply_to, keyboard) + return await do_send_file( + self, chat, input_media, message, entities, reply_to, keyboard + ) async def do_send_file( @@ -297,7 +309,11 @@ async def do_send_file( noforwards=False, update_stickersets_order=False, peer=chat._ref._to_input_peer(), - reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None), + reply_to=( + types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) + if reply_to + else None + ), media=input_media, message=message, random_id=random_id, @@ -379,7 +395,11 @@ async def do_upload( ) ) else: - await client(functions.upload.save_file_part(file_id=file_id, file_part=part, bytes=to_store)) + await client( + functions.upload.save_file_part( + file_id=file_id, file_part=part, bytes=to_store + ) + ) hash_md5.update(to_store) buffer.clear() @@ -438,7 +458,9 @@ def get_file_bytes(self: Client, media: File, /) -> AsyncList[bytes]: return FileBytesList(self, media) -async def download(self: Client, media: File, /, file: str | Path | OutFileLike) -> None: +async def download( + self: Client, media: File, /, file: str | Path | OutFileLike +) -> None: fd = OutWrapper(file) try: async for chunk in get_file_bytes(self, media): diff --git a/client/src/telethon/_impl/client/client/messages.py b/client/src/telethon/_impl/client/client/messages.py index cc960a66..23f038a9 100644 --- a/client/src/telethon/_impl/client/client/messages.py +++ b/client/src/telethon/_impl/client/client/messages.py @@ -46,7 +46,9 @@ async def send_message( update_stickersets_order=False, peer=chat._ref._to_input_peer(), reply_to=( - types.InputReplyToMessage(reply_to_msg_id=text.replied_message_id, top_msg_id=None) + types.InputReplyToMessage( + reply_to_msg_id=text.replied_message_id, top_msg_id=None + ) if text.replied_message_id else None ), @@ -58,7 +60,9 @@ async def send_message( send_as=None, ) else: - message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False) + message, entities = parse_message( + text=text, markdown=markdown, html=html, allow_empty=False + ) request = functions.messages.send_message( no_webpage=not link_preview, silent=False, @@ -67,7 +71,11 @@ async def send_message( noforwards=False, update_stickersets_order=False, peer=chat._ref._to_input_peer(), - reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None), + reply_to=( + types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) + if reply_to + else None + ), message=message, random_id=random_id, reply_markup=keyboard._raw if keyboard else None, @@ -83,7 +91,11 @@ async def send_message( {}, out=result.out, id=result.id, - from_id=(types.PeerUser(user_id=self._session.user.id) if self._session.user else None), + from_id=( + types.PeerUser(user_id=self._session.user.id) + if self._session.user + else None + ), peer_id=chat._ref._to_peer(), reply_to=( types.MessageReplyHeader( @@ -118,7 +130,9 @@ async def edit_message( link_preview: bool = False, keyboard: Optional[KeyboardType] = None, ) -> Message: - message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False) + message, entities = parse_message( + text=text, markdown=markdown, html=html, allow_empty=False + ) return self._build_message_map( await self( functions.messages.edit_message( @@ -146,9 +160,15 @@ async def delete_messages( ) -> int: peer = chat._ref if isinstance(peer, ChannelRef): - affected = await self(functions.channels.delete_messages(channel=peer._to_input_channel(), id=message_ids)) + affected = await self( + functions.channels.delete_messages( + channel=peer._to_input_channel(), id=message_ids + ) + ) else: - affected = await self(functions.messages.delete_messages(revoke=revoke, id=message_ids)) + affected = await self( + functions.messages.delete_messages(revoke=revoke, id=message_ids) + ) assert isinstance(affected, types.messages.AffectedMessages) return affected.pts_count @@ -185,7 +205,9 @@ class MessageList(AsyncList[Message]): super().__init__() self._reversed = False - def _extend_buffer(self, client: Client, messages: abcs.messages.Messages) -> dict[int, Peer]: + def _extend_buffer( + self, client: Client, messages: abcs.messages.Messages + ) -> dict[int, Peer]: if isinstance(messages, types.messages.MessagesNotModified): self._total = messages.count return {} @@ -193,7 +215,9 @@ class MessageList(AsyncList[Message]): if isinstance(messages, types.messages.Messages): self._total = len(messages.messages) self._done = True - elif isinstance(messages, (types.messages.MessagesSlice, types.messages.ChannelMessages)): + elif isinstance( + messages, (types.messages.MessagesSlice, types.messages.ChannelMessages) + ): self._total = messages.count else: raise RuntimeError("unexpected case") @@ -201,7 +225,9 @@ class MessageList(AsyncList[Message]): chat_map = build_chat_map(client, messages.users, messages.chats) self._buffer.extend( Message._from_raw(client, m, chat_map) - for m in (reversed(messages.messages) if self._reversed else messages.messages) + for m in ( + reversed(messages.messages) if self._reversed else messages.messages + ) ) return chat_map @@ -209,7 +235,11 @@ class MessageList(AsyncList[Message]): self, ) -> types.Message | types.MessageService | types.MessageEmpty: return next( - (m._raw for m in reversed(self._buffer) if not isinstance(m._raw, types.MessageEmpty)), + ( + m._raw + for m in reversed(self._buffer) + if not isinstance(m._raw, types.MessageEmpty) + ), types.MessageEmpty(id=0, peer_id=None), ) @@ -305,10 +335,14 @@ class CherryPickedList(MessageList): if isinstance(self._peer, ChannelRef): result = await self._client( - functions.channels.get_messages(channel=self._peer._to_input_channel(), id=self._ids[:100]) + functions.channels.get_messages( + channel=self._peer._to_input_channel(), id=self._ids[:100] + ) ) else: - result = await self._client(functions.messages.get_messages(id=self._ids[:100])) + result = await self._client( + functions.messages.get_messages(id=self._ids[:100]) + ) self._extend_buffer(self._client, result) self._ids = self._ids[100:] @@ -458,7 +492,9 @@ def search_all_messages( ) -async def pin_message(self: Client, chat: Peer | PeerRef, /, message_id: int) -> Message: +async def pin_message( + self: Client, chat: Peer | PeerRef, /, message_id: int +) -> Message: return self._build_message_map( await self( functions.messages.update_pinned_message( @@ -473,7 +509,9 @@ async def pin_message(self: Client, chat: Peer | PeerRef, /, message_id: int) -> ).get_single() -async def unpin_message(self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None: +async def unpin_message( + self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"] +) -> None: if message_id == "all": await self( functions.messages.unpin_all_messages( @@ -493,15 +531,25 @@ async def unpin_message(self: Client, chat: Peer | PeerRef, /, message_id: int | ) -async def read_message(self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None: +async def read_message( + self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"] +) -> None: if message_id == "all": message_id = 0 peer = chat._ref if isinstance(peer, ChannelRef): - await self(functions.channels.read_history(channel=peer._to_input_channel(), max_id=message_id)) + await self( + functions.channels.read_history( + channel=peer._to_input_channel(), max_id=message_id + ) + ) else: - await self(functions.messages.read_history(peer=peer._ref._to_input_peer(), max_id=message_id)) + await self( + functions.messages.read_history( + peer=peer._ref._to_input_peer(), max_id=message_id + ) + ) class MessageMap: @@ -536,7 +584,9 @@ class MessageMap: def _empty(self, id: int = 0) -> Message: return Message._from_raw( self._client, - types.MessageEmpty(id=id, peer_id=self._peer._to_peer() if self._peer else None), + types.MessageEmpty( + id=id, peer_id=self._peer._to_peer() if self._peer else None + ), {}, ) diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index c3452737..23634cdb 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -82,9 +82,16 @@ async def connect_sender( # Only the ID of the input DC may be known. # Find the corresponding address and authentication key if needed. addr = dc.ipv4_addr or next( - d.ipv4_addr for d in itertools.chain(known_dcs, KNOWN_DCS) if d.id == dc.id and d.ipv4_addr + d.ipv4_addr + for d in itertools.chain(known_dcs, KNOWN_DCS) + if d.id == dc.id and d.ipv4_addr + ) + auth = ( + None + if force_auth_gen + else dc.auth + or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None)) ) - auth = None if force_auth_gen else dc.auth or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None)) sender = await do_connect_sender( Full(), @@ -115,7 +122,9 @@ async def connect_sender( ) except BadStatus as e: if e.status == 404 and auth: - dc = DataCenter(id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None) + dc = DataCenter( + id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None + ) config.base_logger.warning( "datacenter could not find stored auth; will retry generating a new one: %s", dc, @@ -165,8 +174,12 @@ async def connect(self: Client) -> None: if session := await self._storage.load(): self._session = session - datacenter = self._config.datacenter or DataCenter(id=self._session.user.dc if self._session.user else DEFAULT_DC) - self._sender, self._session.dcs = await connect_sender(self._config, self._session.dcs, datacenter) + datacenter = self._config.datacenter or DataCenter( + id=self._session.user.dc if self._session.user else DEFAULT_DC + ) + self._sender, self._session.dcs = await connect_sender( + self._config, self._session.dcs, datacenter + ) if self._message_box.is_empty() and self._session.user: try: @@ -180,7 +193,9 @@ async def connect(self: Client) -> None: me = await self.get_me() assert me is not None self._chat_hashes.set_self_user(me.id, me.bot) - self._session.user = SessionUser(id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username) + self._session.user = SessionUser( + id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username + ) self._dispatcher = asyncio.create_task(dispatcher(self)) @@ -199,14 +214,18 @@ async def disconnect(self: Client) -> None: except asyncio.CancelledError: pass except Exception: - self._config.base_logger.exception("unhandled exception when cancelling dispatcher; this is a bug") + self._config.base_logger.exception( + "unhandled exception when cancelling dispatcher; this is a bug" + ) finally: self._dispatcher = None try: await sender.disconnect() except Exception: - self._config.base_logger.exception("unhandled exception during disconnect; this is a bug") + self._config.base_logger.exception( + "unhandled exception during disconnect; this is a bug" + ) try: if self._session.user: diff --git a/client/src/telethon/_impl/client/client/updates.py b/client/src/telethon/_impl/client/client/updates.py index 8852c028..2ef036a4 100644 --- a/client/src/telethon/_impl/client/client/updates.py +++ b/client/src/telethon/_impl/client/client/updates.py @@ -40,7 +40,9 @@ def add_event_handler( self._handlers.setdefault(event_cls, []).append((handler, filter)) -def remove_event_handler(self: Client, handler: Callable[[Event], Awaitable[Any]], /) -> None: +def remove_event_handler( + self: Client, handler: Callable[[Event], Awaitable[Any]], / +) -> None: for event_cls, handlers in tuple(self._handlers.items()): for i in reversed(range(len(handlers))): if handlers[i][0] == handler: @@ -49,7 +51,9 @@ def remove_event_handler(self: Client, handler: Callable[[Event], Awaitable[Any] del self._handlers[event_cls] -def get_handler_filter(self: Client, handler: Callable[[Event], Awaitable[Any]], /) -> Optional[FilterType]: +def get_handler_filter( + self: Client, handler: Callable[[Event], Awaitable[Any]], / +) -> Optional[FilterType]: for handlers in self._handlers.values(): for h, f in handlers: if h == handler: @@ -80,7 +84,9 @@ def process_socket_updates(client: Client, all_updates: list[abcs.Updates]) -> N return try: - result, users, chats = client._message_box.process_updates(updates, client._chat_hashes) + result, users, chats = client._message_box.process_updates( + updates, client._chat_hashes + ) except Gap: return @@ -101,7 +107,8 @@ def extend_update_queue( except asyncio.QueueFull: now = asyncio.get_running_loop().time() if client._last_update_limit_warn is None or ( - now - client._last_update_limit_warn > UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN + now - client._last_update_limit_warn + > UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN ): client._config.base_logger.warning( "updates are being dropped because limit=%d has been reached", diff --git a/client/src/telethon/_impl/client/client/users.py b/client/src/telethon/_impl/client/client/users.py index 63fe3f3b..345f05d7 100644 --- a/client/src/telethon/_impl/client/client/users.py +++ b/client/src/telethon/_impl/client/client/users.py @@ -53,11 +53,15 @@ def resolved_peer_to_chat(client: Client, resolved: abcs.contacts.ResolvedPeer) async def resolve_phone(self: Client, phone: str, /) -> Peer: - return resolved_peer_to_chat(self, await self(functions.contacts.resolve_phone(phone=phone))) + return resolved_peer_to_chat( + self, await self(functions.contacts.resolve_phone(phone=phone)) + ) async def resolve_username(self: Client, username: str, /) -> Peer: - return resolved_peer_to_chat(self, await self(functions.contacts.resolve_username(username=username))) + return resolved_peer_to_chat( + self, await self(functions.contacts.resolve_username(username=username)) + ) async def resolve_peers(self: Client, peers: Sequence[Peer | PeerRef], /) -> list[Peer]: @@ -95,4 +99,8 @@ async def resolve_peers(self: Client, peers: Sequence[Peer | PeerRef], /) -> lis chats.extend(ret_chats.chats) chat_map = build_chat_map(self, users, chats) - return [chat_map.get(ref.identifier) or expand_peer(self, ref._to_peer(), broadcast=None) for ref in refs] + return [ + chat_map.get(ref.identifier) + or expand_peer(self, ref._to_peer(), broadcast=None) + for ref in refs + ] diff --git a/client/src/telethon/_impl/client/errors.py b/client/src/telethon/_impl/client/errors.py index 43040121..746c4bbd 100644 --- a/client/src/telethon/_impl/client/errors.py +++ b/client/src/telethon/_impl/client/errors.py @@ -34,13 +34,17 @@ def from_name(name: str, *, _cache: dict[str, Type[RpcError]] = {}) -> Type[RpcE return _cache[name] -def adapt_rpc(error: RpcError, *, _cache: dict[tuple[int, str], Type[RpcError]] = {}) -> RpcError: +def adapt_rpc( + error: RpcError, *, _cache: dict[tuple[int, str], Type[RpcError]] = {} +) -> RpcError: code = canonicalize_code(error.code) name = canonicalize_name(error.name) tup = code, name if tup not in _cache: _cache[tup] = type(pretty_name(name), (from_code(code), from_name(name)), {}) - return _cache[tup](code=error.code, name=error.name, value=error.value, caused_by=error._caused_by) + return _cache[tup]( + code=error.code, name=error.name, value=error.value, caused_by=error._caused_by + ) class ErrorFactory: @@ -51,7 +55,9 @@ class ErrorFactory: return from_code(int(m[1])) else: adapted = adapt_user_name(name) - if pretty_name(canonicalize_name(adapted)) != name or re.match(r"[A-Z]{2}", name): + if pretty_name(canonicalize_name(adapted)) != name or re.match( + r"[A-Z]{2}", name + ): raise AttributeError(f"error subclass names must be CamelCase: {name}") return from_name(adapted) diff --git a/client/src/telethon/_impl/client/events/event.py b/client/src/telethon/_impl/client/events/event.py index e44c906d..8fcc2dbf 100644 --- a/client/src/telethon/_impl/client/events/event.py +++ b/client/src/telethon/_impl/client/events/event.py @@ -24,7 +24,9 @@ class Event(abc.ABC, metaclass=NoPublicConstructor): @classmethod @abc.abstractmethod - def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: + def _try_from_update( + cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] + ) -> Optional[Self]: pass @@ -50,7 +52,9 @@ class Raw(Event): self._chat_map = chat_map @classmethod - def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: + def _try_from_update( + cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] + ) -> Optional[Self]: return cls._create(client, update, chat_map) diff --git a/client/src/telethon/_impl/client/events/filters/combinators.py b/client/src/telethon/_impl/client/events/filters/combinators.py index 2e41fded..5bce9950 100644 --- a/client/src/telethon/_impl/client/events/filters/combinators.py +++ b/client/src/telethon/_impl/client/events/filters/combinators.py @@ -65,7 +65,9 @@ class Any(Combinable): __slots__ = ("_filters",) - def __init__(self, filter1: FilterType, filter2: FilterType, *filters: FilterType) -> None: + def __init__( + self, filter1: FilterType, filter2: FilterType, *filters: FilterType + ) -> None: self._filters = (filter1, filter2, *filters) @property @@ -109,7 +111,9 @@ class All(Combinable): __slots__ = ("_filters",) - def __init__(self, filter1: FilterType, filter2: FilterType, *filters: FilterType) -> None: + def __init__( + self, filter1: FilterType, filter2: FilterType, *filters: FilterType + ) -> None: self._filters = (filter1, filter2, *filters) @property diff --git a/client/src/telethon/_impl/client/events/messages.py b/client/src/telethon/_impl/client/events/messages.py index 335db355..65cf42de 100644 --- a/client/src/telethon/_impl/client/events/messages.py +++ b/client/src/telethon/_impl/client/events/messages.py @@ -24,11 +24,15 @@ class NewMessage(Event, Message): """ @classmethod - def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: + def _try_from_update( + cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] + ) -> Optional[Self]: if isinstance(update, (types.UpdateNewMessage, types.UpdateNewChannelMessage)): if isinstance(update.message, types.Message): return cls._from_raw(client, update.message, chat_map) - elif isinstance(update, (types.UpdateShortMessage, types.UpdateShortChatMessage)): + elif isinstance( + update, (types.UpdateShortMessage, types.UpdateShortChatMessage) + ): raise RuntimeError("should have been handled by adaptor") return None @@ -42,8 +46,12 @@ class MessageEdited(Event, Message): """ @classmethod - def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: - if isinstance(update, (types.UpdateEditMessage, types.UpdateEditChannelMessage)): + def _try_from_update( + cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] + ) -> Optional[Self]: + if isinstance( + update, (types.UpdateEditMessage, types.UpdateEditChannelMessage) + ): return cls._from_raw(client, update.message, chat_map) else: return None @@ -66,7 +74,9 @@ class MessageDeleted(Event): self._channel_id = channel_id @classmethod - def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: + def _try_from_update( + cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] + ) -> Optional[Self]: if isinstance(update, types.UpdateDeleteMessages): return cls._create(update.messages, None) elif isinstance(update, types.UpdateDeleteChannelMessages): @@ -112,7 +122,9 @@ class MessageRead(Event): self._chat_map = chat_map @classmethod - def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: + def _try_from_update( + cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] + ) -> Optional[Self]: if isinstance( update, ( @@ -127,7 +139,9 @@ class MessageRead(Event): return None def _peer(self) -> abcs.Peer: - if isinstance(self._raw, (types.UpdateReadHistoryInbox, types.UpdateReadHistoryOutbox)): + if isinstance( + self._raw, (types.UpdateReadHistoryInbox, types.UpdateReadHistoryOutbox) + ): return self._raw.peer else: return types.PeerChannel(channel_id=self._raw.channel_id) @@ -140,7 +154,9 @@ class MessageRead(Event): peer = self._peer() pid = peer_id(peer) if pid not in self._chat_map: - self._chat_map[pid] = expand_peer(self._client, peer, broadcast=getattr(self._raw, "post", None)) + self._chat_map[pid] = expand_peer( + self._client, peer, broadcast=getattr(self._raw, "post", None) + ) return self._chat_map[pid] @property diff --git a/client/src/telethon/_impl/client/events/queries.py b/client/src/telethon/_impl/client/events/queries.py index bc9d6a9c..276c35c6 100644 --- a/client/src/telethon/_impl/client/events/queries.py +++ b/client/src/telethon/_impl/client/events/queries.py @@ -31,7 +31,9 @@ class ButtonCallback(Event): self._chat_map = chat_map @classmethod - def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: + def _try_from_update( + cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] + ) -> Optional[Self]: if isinstance(update, types.UpdateBotCallbackQuery) and update.data is not None: return cls._create(client, update, chat_map) else: @@ -81,7 +83,11 @@ class ButtonCallback(Event): chat = self._chat_map.get(pid) or PeerRef._empty_from_peer(self._raw.peer) lst = CherryPickedList(self._client, chat._ref, []) - lst._ids.append(types.InputMessageCallbackQuery(id=self._raw.msg_id, query_id=self._raw.query_id)) + lst._ids.append( + types.InputMessageCallbackQuery( + id=self._raw.msg_id, query_id=self._raw.query_id + ) + ) message = (await lst)[0] @@ -99,7 +105,9 @@ class InlineQuery(Event): self._raw = update @classmethod - def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: + def _try_from_update( + cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] + ) -> Optional[Self]: if isinstance(update, types.UpdateBotInlineQuery): return cls._create(update) else: diff --git a/client/src/telethon/_impl/client/parsers/html.py b/client/src/telethon/_impl/client/parsers/html.py index 05239414..cafa3d3d 100644 --- a/client/src/telethon/_impl/client/parsers/html.py +++ b/client/src/telethon/_impl/client/parsers/html.py @@ -133,7 +133,9 @@ def parse(html: str) -> tuple[str, list[MessageEntity]]: return del_surrogate(parser.text), parser.entities -ENTITY_TO_FORMATTER: dict[Type[MessageEntity], tuple[str, str] | Callable[[Any, str], tuple[str, str]]] = { +ENTITY_TO_FORMATTER: dict[ + Type[MessageEntity], tuple[str, str] | Callable[[Any, str], tuple[str, str]] +] = { MessageEntityBold: ("", ""), MessageEntityItalic: ("", ""), MessageEntityCode: ("", ""), @@ -194,7 +196,12 @@ def unparse(text: str, entities: Iterable[MessageEntity]) -> str: while within_surrogate(text, at): at += 1 - text = text[:at] + what + escape(text[at:next_escape_bound]) + text[next_escape_bound:] + text = ( + text[:at] + + what + + escape(text[at:next_escape_bound]) + + text[next_escape_bound:] + ) next_escape_bound = at text = escape(text[:next_escape_bound]) + text[next_escape_bound:] diff --git a/client/src/telethon/_impl/client/parsers/markdown.py b/client/src/telethon/_impl/client/parsers/markdown.py index 0531f026..1704d4fb 100644 --- a/client/src/telethon/_impl/client/parsers/markdown.py +++ b/client/src/telethon/_impl/client/parsers/markdown.py @@ -82,7 +82,9 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]: else: for entity in reversed(entities): if isinstance(entity, ty): - setattr(entity, "length", len(message) - getattr(entity, "offset", 0)) + setattr( + entity, "length", len(message) - getattr(entity, "offset", 0) + ) break parsed = MARKDOWN.parse(add_surrogate(message.strip())) @@ -101,15 +103,25 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]: if token.type in ("blockquote_close", "blockquote_open"): push(MessageEntityBlockquote) elif token.type == "code_block": - entities.append(MessageEntityPre(offset=len(message), length=len(token.content), language="")) + entities.append( + MessageEntityPre( + offset=len(message), length=len(token.content), language="" + ) + ) message += token.content elif token.type == "code_inline": - entities.append(MessageEntityCode(offset=len(message), length=len(token.content))) + entities.append( + MessageEntityCode(offset=len(message), length=len(token.content)) + ) message += token.content elif token.type in ("em_close", "em_open"): push(MessageEntityItalic) elif token.type == "fence": - entities.append(MessageEntityPre(offset=len(message), length=len(token.content), language=token.info)) + entities.append( + MessageEntityPre( + offset=len(message), length=len(token.content), language=token.info + ) + ) message += token.content[:-1] # remove a single trailing newline elif token.type == "hardbreak": message += "\n" @@ -118,7 +130,9 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]: elif token.type == "hr": message += "\u2015\n\n" elif token.type in ("link_close", "link_open"): - if token.markup != "autolink": # telegram already picks up on these automatically + if ( + token.markup != "autolink" + ): # telegram already picks up on these automatically push(MessageEntityTextUrl, url=token.attrs.get("href")) elif token.type in ("s_close", "s_open"): push(MessageEntityStrike) diff --git a/client/src/telethon/_impl/client/parsers/strings.py b/client/src/telethon/_impl/client/parsers/strings.py index 1fa9a454..0fd69059 100644 --- a/client/src/telethon/_impl/client/parsers/strings.py +++ b/client/src/telethon/_impl/client/parsers/strings.py @@ -6,7 +6,11 @@ def add_surrogate(text: str) -> str: return "".join( # SMP -> Surrogate Pairs (Telegram offsets are calculated with these). # See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more. - ("".join(chr(y) for y in struct.unpack(" list[Message]: + async def send( + self, peer: Peer | PeerRef, *, reply_to: Optional[int] = None + ) -> list[Message]: """ Send the album. @@ -205,7 +225,11 @@ class AlbumBuilder(metaclass=NoPublicConstructor): update_stickersets_order=False, peer=peer._ref._to_input_peer(), reply_to=( - types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None + types.InputReplyToMessage( + reply_to_msg_id=reply_to, top_msg_id=None + ) + if reply_to + else None ), multi_media=self._medias, schedule_date=None, diff --git a/client/src/telethon/_impl/client/types/buttons/button.py b/client/src/telethon/_impl/client/types/buttons/button.py index 5632d293..e28b4582 100644 --- a/client/src/telethon/_impl/client/types/buttons/button.py +++ b/client/src/telethon/_impl/client/types/buttons/button.py @@ -50,7 +50,9 @@ class Button(abc.ABC): def __init__(self, text: str) -> None: if self.__class__ == Button: - raise TypeError(f"Can't instantiate abstract class {self.__class__.__name__}") + raise TypeError( + f"Can't instantiate abstract class {self.__class__.__name__}" + ) self._raw: RawButtonType = types.KeyboardButton(text=text) self._msg: Optional[weakref.ReferenceType[Message]] = None diff --git a/client/src/telethon/_impl/client/types/buttons/inline_button.py b/client/src/telethon/_impl/client/types/buttons/inline_button.py index 1d4cfd1e..803269f1 100644 --- a/client/src/telethon/_impl/client/types/buttons/inline_button.py +++ b/client/src/telethon/_impl/client/types/buttons/inline_button.py @@ -29,6 +29,8 @@ class InlineButton(Button, abc.ABC): def __init__(self, text: str) -> None: if self.__class__ == InlineButton: - raise TypeError(f"Can't instantiate abstract class {self.__class__.__name__}") + raise TypeError( + f"Can't instantiate abstract class {self.__class__.__name__}" + ) else: super().__init__(text) diff --git a/client/src/telethon/_impl/client/types/buttons/switch_inline.py b/client/src/telethon/_impl/client/types/buttons/switch_inline.py index 0e0842a9..977b35c8 100644 --- a/client/src/telethon/_impl/client/types/buttons/switch_inline.py +++ b/client/src/telethon/_impl/client/types/buttons/switch_inline.py @@ -14,7 +14,9 @@ class SwitchInline(InlineButton): def __init__(self, text: str, query: Optional[str] = None) -> None: super().__init__(text) - self._raw = types.KeyboardButtonSwitchInline(same_peer=False, text=text, query=query or "", peer_types=None) + self._raw = types.KeyboardButtonSwitchInline( + same_peer=False, text=text, query=query or "", peer_types=None + ) @property def query(self) -> str: diff --git a/client/src/telethon/_impl/client/types/chat_restriction.py b/client/src/telethon/_impl/client/types/chat_restriction.py index c8f4e8fa..3c330f91 100644 --- a/client/src/telethon/_impl/client/types/chat_restriction.py +++ b/client/src/telethon/_impl/client/types/chat_restriction.py @@ -111,7 +111,9 @@ class ChatRestriction(Enum): return set(filter(None, iter(restrictions))) @classmethod - def _set_to_raw(cls, restrictions: set[ChatRestriction], until_date: int) -> types.ChatBannedRights: + def _set_to_raw( + cls, restrictions: set[ChatRestriction], until_date: int + ) -> types.ChatBannedRights: return types.ChatBannedRights( view_messages=cls.VIEW_MESSAGES in restrictions, send_messages=cls.SEND_MESSAGES in restrictions, diff --git a/client/src/telethon/_impl/client/types/dialog.py b/client/src/telethon/_impl/client/types/dialog.py index 52192a29..673fbf78 100644 --- a/client/src/telethon/_impl/client/types/dialog.py +++ b/client/src/telethon/_impl/client/types/dialog.py @@ -90,6 +90,9 @@ class Dialog(metaclass=NoPublicConstructor): if isinstance(self._raw, types.Dialog): return self._raw.unread_count elif isinstance(self._raw, types.DialogPeerFolder): - return self._raw.unread_unmuted_messages_count + self._raw.unread_muted_messages_count + return ( + self._raw.unread_unmuted_messages_count + + self._raw.unread_muted_messages_count + ) else: raise RuntimeError("unexpected case") diff --git a/client/src/telethon/_impl/client/types/draft.py b/client/src/telethon/_impl/client/types/draft.py index 3086497b..9368044d 100644 --- a/client/src/telethon/_impl/client/types/draft.py +++ b/client/src/telethon/_impl/client/types/draft.py @@ -37,7 +37,9 @@ class Draft(metaclass=NoPublicConstructor): self._chat_map = chat_map @classmethod - def _from_raw_update(cls, client: Client, draft: types.UpdateDraftMessage, chat_map: dict[int, Peer]) -> Self: + def _from_raw_update( + cls, client: Client, draft: types.UpdateDraftMessage, chat_map: dict[int, Peer] + ) -> Self: return cls._create(client, draft.peer, draft.top_msg_id, draft.draft, chat_map) @classmethod @@ -58,7 +60,9 @@ class Draft(metaclass=NoPublicConstructor): This is also the chat where the message will be sent to by :meth:`send`. """ - return self._chat_map.get(peer_id(self._peer)) or expand_peer(self._client, self._peer, broadcast=None) + return self._chat_map.get(peer_id(self._peer)) or expand_peer( + self._client, self._peer, broadcast=None + ) @property def link_preview(self) -> bool: @@ -87,7 +91,9 @@ class Draft(metaclass=NoPublicConstructor): The :attr:`~Message.text_html` of the message that will be sent. """ if text := getattr(self._raw, "message", None): - return generate_html_message(text, getattr(self._raw, "entities", None) or []) + return generate_html_message( + text, getattr(self._raw, "entities", None) or [] + ) else: return None @@ -97,7 +103,9 @@ class Draft(metaclass=NoPublicConstructor): The :attr:`~Message.text_markdown` of the message that will be sent. """ if text := getattr(self._raw, "message", None): - return generate_markdown_message(text, getattr(self._raw, "entities", None) or []) + return generate_markdown_message( + text, getattr(self._raw, "entities", None) or [] + ) else: return None @@ -107,7 +115,11 @@ class Draft(metaclass=NoPublicConstructor): The date when the draft was last updated. """ date = getattr(self._raw, "date", None) - return datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) if date is not None else None + return ( + datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) + if date is not None + else None + ) async def edit( self, @@ -180,7 +192,11 @@ class Draft(metaclass=NoPublicConstructor): noforwards=False, update_stickersets_order=False, peer=self._peer_ref()._to_input_peer(), - reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None), + reply_to=( + types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) + if reply_to + else None + ), message=message, random_id=random_id, reply_markup=None, @@ -195,7 +211,11 @@ class Draft(metaclass=NoPublicConstructor): {}, out=result.out, id=result.id, - from_id=(types.PeerUser(user_id=self._client._session.user.id) if self._client._session.user else None), + from_id=( + types.PeerUser(user_id=self._client._session.user.id) + if self._client._session.user + else None + ), peer_id=self._peer_ref()._to_peer(), reply_to=( types.MessageReplyHeader( @@ -215,7 +235,9 @@ class Draft(metaclass=NoPublicConstructor): ttl_period=result.ttl_period, ) else: - return self._client._build_message_map(result, self._peer_ref()).with_random_id(random_id) + return self._client._build_message_map( + result, self._peer_ref() + ).with_random_id(random_id) async def delete(self) -> None: """ diff --git a/client/src/telethon/_impl/client/types/file.py b/client/src/telethon/_impl/client/types/file.py index 76c9ce83..1e5e018f 100644 --- a/client/src/telethon/_impl/client/types/file.py +++ b/client/src/telethon/_impl/client/types/file.py @@ -29,7 +29,11 @@ def photo_size_byte_count(size: abcs.PhotoSize) -> int: elif isinstance(size, types.PhotoSizeProgressive): return max(size.sizes) elif isinstance(size, types.PhotoStrippedSize): - return len(stripped_size_header) + (len(size.bytes) - 3) + len(stripped_size_footer) + return ( + len(stripped_size_header) + + (len(size.bytes) - 3) + + len(stripped_size_footer) + ) else: raise RuntimeError("unexpected case") @@ -176,7 +180,9 @@ class File(metaclass=NoPublicConstructor): self._client = client @classmethod - def _try_from_raw_message_media(cls, client: Client, raw: abcs.MessageMedia) -> Optional[Self]: + def _try_from_raw_message_media( + cls, client: Client, raw: abcs.MessageMedia + ) -> Optional[Self]: if isinstance(raw, types.MessageMediaDocument): if raw.document: return cls._try_from_raw_document( @@ -198,9 +204,13 @@ class File(metaclass=NoPublicConstructor): elif isinstance(raw, types.MessageMediaWebPage): if isinstance(raw.webpage, types.WebPage): if raw.webpage.document: - return cls._try_from_raw_document(client, raw.webpage.document, orig_raw=raw) + return cls._try_from_raw_document( + client, raw.webpage.document, orig_raw=raw + ) if raw.webpage.photo: - return cls._try_from_raw_photo(client, raw.webpage.photo, orig_raw=raw) + return cls._try_from_raw_photo( + client, raw.webpage.photo, orig_raw=raw + ) return None @@ -219,13 +229,21 @@ class File(metaclass=NoPublicConstructor): attributes=raw.attributes, size=raw.size, name=next( - (a.file_name for a in raw.attributes if isinstance(a, types.DocumentAttributeFilename)), + ( + a.file_name + for a in raw.attributes + if isinstance(a, types.DocumentAttributeFilename) + ), "", ), mime=raw.mime_type, photo=False, muted=next( - (a.nosound for a in raw.attributes if isinstance(a, types.DocumentAttributeVideo)), + ( + a.nosound + for a in raw.attributes + if isinstance(a, types.DocumentAttributeVideo) + ), False, ), input_media=types.InputMediaDocument( @@ -343,7 +361,9 @@ class File(metaclass=NoPublicConstructor): return dim.w for attr in self._attributes: - if isinstance(attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)): + if isinstance( + attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo) + ): return attr.w return None @@ -357,7 +377,9 @@ class File(metaclass=NoPublicConstructor): return dim.h for attr in self._attributes: - if isinstance(attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)): + if isinstance( + attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo) + ): return attr.h return None @@ -388,7 +410,9 @@ class File(metaclass=NoPublicConstructor): id=self._input_media.id.id, access_hash=self._input_media.id.access_hash, file_reference=self._input_media.id.file_reference, - thumb_size=(self._thumb.type if isinstance(self._thumb, thumb_types) else ""), + thumb_size=( + self._thumb.type if isinstance(self._thumb, thumb_types) else "" + ), ) elif isinstance(self._input_media, types.InputMediaPhoto): assert isinstance(self._input_media.id, types.InputPhoto) diff --git a/client/src/telethon/_impl/client/types/keyboard.py b/client/src/telethon/_impl/client/types/keyboard.py index b9fc1b7a..d560463c 100644 --- a/client/src/telethon/_impl/client/types/keyboard.py +++ b/client/src/telethon/_impl/client/types/keyboard.py @@ -12,11 +12,16 @@ def _build_keyboard_rows( ) -> list[abcs.KeyboardButtonRow]: # list[button] -> list[list[button]] # This does allow for "invalid" inputs (mixing lists and non-lists), but that's acceptable. - buttons_lists_iter = [button if isinstance(button, list) else [button] for button in (btns or [])] + buttons_lists_iter = [ + button if isinstance(button, list) else [button] for button in (btns or []) + ] # Remove empty rows (also making it easy to check if all-empty). buttons_lists = [bs for bs in buttons_lists_iter if bs] - return [types.KeyboardButtonRow(buttons=[btn._raw for btn in btns]) for btns in buttons_lists] + return [ + types.KeyboardButtonRow(buttons=[btn._raw for btn in btns]) + for btns in buttons_lists + ] class Keyboard: @@ -44,7 +49,9 @@ class Keyboard: class InlineKeyboard: __slots__ = ("_raw",) - def __init__(self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]]) -> None: + def __init__( + self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]] + ) -> None: self._raw = types.ReplyInlineMarkup(rows=_build_keyboard_rows(buttons)) diff --git a/client/src/telethon/_impl/client/types/message.py b/client/src/telethon/_impl/client/types/message.py index 08101ec9..4d37263d 100644 --- a/client/src/telethon/_impl/client/types/message.py +++ b/client/src/telethon/_impl/client/types/message.py @@ -34,7 +34,11 @@ def generate_random_id() -> int: def adapt_date(date: Optional[int]) -> Optional[datetime.datetime]: - return datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) if date is not None else None + return ( + datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) + if date is not None + else None + ) class Message(metaclass=NoPublicConstructor): @@ -55,14 +59,20 @@ class Message(metaclass=NoPublicConstructor): print('Found empty message with ID', message.id) """ - def __init__(self, client: Client, message: abcs.Message, chat_map: dict[int, Peer]) -> None: - assert isinstance(message, (types.Message, types.MessageService, types.MessageEmpty)) + def __init__( + self, client: Client, message: abcs.Message, chat_map: dict[int, Peer] + ) -> None: + assert isinstance( + message, (types.Message, types.MessageService, types.MessageEmpty) + ) self._client = client self._raw = message self._chat_map = chat_map @classmethod - def _from_raw(cls, client: Client, message: abcs.Message, chat_map: dict[int, Peer]) -> Self: + def _from_raw( + cls, client: Client, message: abcs.Message, chat_map: dict[int, Peer] + ) -> Self: return cls._create(client, message, chat_map) @classmethod @@ -148,7 +158,9 @@ class Message(metaclass=NoPublicConstructor): See :ref:`formatting` to learn the HTML elements used. """ if text := getattr(self._raw, "message", None): - return generate_html_message(text, getattr(self._raw, "entities", None) or []) + return generate_html_message( + text, getattr(self._raw, "entities", None) or [] + ) else: return None @@ -160,7 +172,9 @@ class Message(metaclass=NoPublicConstructor): See :ref:`formatting` to learn the formatting characters used. """ if text := getattr(self._raw, "message", None): - return generate_markdown_message(text, getattr(self._raw, "entities", None) or []) + return generate_markdown_message( + text, getattr(self._raw, "entities", None) or [] + ) else: return None @@ -179,7 +193,9 @@ class Message(metaclass=NoPublicConstructor): peer = self._raw.peer_id or types.PeerUser(user_id=0) pid = peer_id(peer) if pid not in self._chat_map: - self._chat_map[pid] = expand_peer(self._client, peer, broadcast=getattr(self._raw, "post", None)) + self._chat_map[pid] = expand_peer( + self._client, peer, broadcast=getattr(self._raw, "post", None) + ) return self._chat_map[pid] @property @@ -223,7 +239,14 @@ class Message(metaclass=NoPublicConstructor): This can also be used as a way to check that the message media is an audio. """ audio = self._file() - return audio if audio and any(isinstance(a, types.DocumentAttributeAudio) for a in audio._attributes) else None + return ( + audio + if audio + and any( + isinstance(a, types.DocumentAttributeAudio) for a in audio._attributes + ) + else None + ) @property def video(self) -> Optional[File]: @@ -233,7 +256,14 @@ class Message(metaclass=NoPublicConstructor): This can also be used as a way to check that the message media is a video. """ audio = self._file() - return audio if audio and any(isinstance(a, types.DocumentAttributeVideo) for a in audio._attributes) else None + return ( + audio + if audio + and any( + isinstance(a, types.DocumentAttributeVideo) for a in audio._attributes + ) + else None + ) @property def file(self) -> Optional[File]: @@ -447,7 +477,10 @@ class Message(metaclass=NoPublicConstructor): return None return [ - [create_button(self, button) for button in cast(types.KeyboardButtonRow, row).buttons] + [ + create_button(self, button) + for button in cast(types.KeyboardButtonRow, row).buttons + ] for row in markup.rows ] @@ -473,8 +506,13 @@ class Message(metaclass=NoPublicConstructor): return not isinstance(self._raw, types.MessageEmpty) -def build_msg_map(client: Client, messages: Sequence[abcs.Message], chat_map: dict[int, Peer]) -> dict[int, Message]: - return {msg.id: msg for msg in (Message._from_raw(client, m, chat_map) for m in messages)} +def build_msg_map( + client: Client, messages: Sequence[abcs.Message], chat_map: dict[int, Peer] +) -> dict[int, Message]: + return { + msg.id: msg + for msg in (Message._from_raw(client, m, chat_map) for m in messages) + } def parse_message( diff --git a/client/src/telethon/_impl/client/types/meta.py b/client/src/telethon/_impl/client/types/meta.py index b812de6c..b7de4f15 100644 --- a/client/src/telethon/_impl/client/types/meta.py +++ b/client/src/telethon/_impl/client/types/meta.py @@ -16,16 +16,23 @@ class Final(abc.ABCMeta): cls_namespace: dict[str, object], ) -> "Final": # Allow subclassing while within telethon._impl (or other package names). - allowed_base = Final.__module__[: Final.__module__.find(".", Final.__module__.find(".") + 1)] + allowed_base = Final.__module__[ + : Final.__module__.find(".", Final.__module__.find(".") + 1) + ] for base in bases: if isinstance(base, Final) and not base.__module__.startswith(allowed_base): - raise TypeError(f"{base.__module__}.{base.__qualname__} does not support" " subclassing") + raise TypeError( + f"{base.__module__}.{base.__qualname__} does not support" + " subclassing" + ) return super().__new__(cls, name, bases, cls_namespace) class NoPublicConstructor(Final): def __call__(cls, *args: Any, **kwds: Any) -> Any: - raise TypeError(f"{cls.__module__}.{cls.__qualname__} has no public constructor") + raise TypeError( + f"{cls.__module__}.{cls.__qualname__} has no public constructor" + ) @property def _create(cls: Type[T]) -> Type[T]: diff --git a/client/src/telethon/_impl/client/types/participant.py b/client/src/telethon/_impl/client/types/participant.py index 62689a5d..ce7937b5 100644 --- a/client/src/telethon/_impl/client/types/participant.py +++ b/client/src/telethon/_impl/client/types/participant.py @@ -157,16 +157,22 @@ class Participant(metaclass=NoPublicConstructor): """ :data:`True` if the participant is the creator of the chat. """ - return isinstance(self._raw, (types.ChannelParticipantCreator, types.ChatParticipantCreator)) + return isinstance( + self._raw, (types.ChannelParticipantCreator, types.ChatParticipantCreator) + ) @property def admin_rights(self) -> Optional[set[AdminRight]]: """ The set of administrator rights this participant has been granted, if they are an administrator. """ - if isinstance(self._raw, (types.ChannelParticipantCreator, types.ChannelParticipantAdmin)): + if isinstance( + self._raw, (types.ChannelParticipantCreator, types.ChannelParticipantAdmin) + ): return AdminRight._from_raw(self._raw.admin_rights) - elif isinstance(self._raw, (types.ChatParticipantCreator, types.ChatParticipantAdmin)): + elif isinstance( + self._raw, (types.ChatParticipantCreator, types.ChatParticipantAdmin) + ): return AdminRight._chat_rights() else: return None @@ -188,9 +194,13 @@ class Participant(metaclass=NoPublicConstructor): participant = self.user or self.banned or self.left assert participant if isinstance(participant, User): - await self._client.set_participant_admin_rights(self._chat, participant, rights) + await self._client.set_participant_admin_rights( + self._chat, participant, rights + ) else: - raise TypeError(f"participant of type {participant.__class__.__name__} cannot be made admin") + raise TypeError( + f"participant of type {participant.__class__.__name__} cannot be made admin" + ) async def set_restrictions( self, @@ -203,4 +213,6 @@ class Participant(metaclass=NoPublicConstructor): """ participant = self.user or self.banned or self.left assert participant - await self._client.set_participant_restrictions(self._chat, participant, restrictions, until=until) + await self._client.set_participant_restrictions( + self._chat, participant, restrictions, until=until + ) diff --git a/client/src/telethon/_impl/client/types/peer/__init__.py b/client/src/telethon/_impl/client/types/peer/__init__.py index 582a8304..91c5d6ba 100644 --- a/client/src/telethon/_impl/client/types/peer/__init__.py +++ b/client/src/telethon/_impl/client/types/peer/__init__.py @@ -15,7 +15,9 @@ if TYPE_CHECKING: from ...client.client import Client -def build_chat_map(client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]) -> dict[int, Peer]: +def build_chat_map( + client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat] +) -> dict[int, Peer]: users_iter = (User._from_raw(u) for u in users) chats_iter = ( ( @@ -43,7 +45,9 @@ def build_chat_map(client: Client, users: Sequence[abcs.User], chats: Sequence[a for x in v: print(x, file=sys.stderr) - raise RuntimeError(f"chat identifier collision: {k}; please report this") + raise RuntimeError( + f"chat identifier collision: {k}; please report this" + ) return result @@ -77,7 +81,11 @@ def expand_peer(client: Client, peer: abcs.Peer, *, broadcast: Optional[bool]) - until_date=None, ) - return Channel._from_raw(channel) if broadcast else Group._from_raw(client, channel) + return ( + Channel._from_raw(channel) + if broadcast + else Group._from_raw(client, channel) + ) else: raise RuntimeError("unexpected case") diff --git a/client/src/telethon/_impl/client/types/peer/group.py b/client/src/telethon/_impl/client/types/peer/group.py index 6fd8b04d..5dce7ef3 100644 --- a/client/src/telethon/_impl/client/types/peer/group.py +++ b/client/src/telethon/_impl/client/types/peer/group.py @@ -24,7 +24,13 @@ class Group(Peer, metaclass=NoPublicConstructor): def __init__( self, client: Client, - chat: (types.ChatEmpty | types.Chat | types.ChatForbidden | types.Channel | types.ChannelForbidden), + chat: ( + types.ChatEmpty + | types.Chat + | types.ChatForbidden + | types.Channel + | types.ChannelForbidden + ), ) -> None: self._client = client self._raw = chat @@ -90,4 +96,6 @@ class Group(Peer, metaclass=NoPublicConstructor): """ Alias for :meth:`telethon.Client.set_chat_default_restrictions`. """ - await self._client.set_chat_default_restrictions(self, restrictions, until=until) + await self._client.set_chat_default_restrictions( + self, restrictions, until=until + ) diff --git a/client/src/telethon/_impl/crypto/aes.py b/client/src/telethon/_impl/crypto/aes.py index 41355f81..bc248801 100644 --- a/client/src/telethon/_impl/crypto/aes.py +++ b/client/src/telethon/_impl/crypto/aes.py @@ -1,10 +1,16 @@ try: import cryptg # type: ignore [import-untyped] - def ige_encrypt(plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: # noqa: F811 - return cryptg.encrypt_ige(bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv) + def ige_encrypt( + plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes + ) -> bytes: # noqa: F811 + return cryptg.encrypt_ige( + bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv + ) - def ige_decrypt(ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: # noqa: F811 + def ige_decrypt( + ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes + ) -> bytes: # noqa: F811 return cryptg.decrypt_ige( bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext, key, @@ -14,7 +20,9 @@ try: except ImportError: import pyaes # type: ignore [import-untyped] - def ige_encrypt(plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: + def ige_encrypt( + plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes + ) -> bytes: assert len(plaintext) % 16 == 0 assert len(iv) == 32 @@ -27,7 +35,10 @@ except ImportError: for block_offset in range(0, len(plaintext), 16): plaintext_block = plaintext[block_offset : block_offset + 16] ciphertext_block = bytes( - a ^ b for a, b in zip(aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2) + a ^ b + for a, b in zip( + aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2 + ) ) iv1 = ciphertext_block iv2 = plaintext_block @@ -36,7 +47,9 @@ except ImportError: return bytes(ciphertext) - def ige_decrypt(ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: + def ige_decrypt( + ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes + ) -> bytes: assert len(ciphertext) % 16 == 0 assert len(iv) == 32 @@ -49,7 +62,10 @@ except ImportError: for block_offset in range(0, len(ciphertext), 16): ciphertext_block = ciphertext[block_offset : block_offset + 16] plaintext_block = bytes( - a ^ b for a, b in zip(aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1) + a ^ b + for a, b in zip( + aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1 + ) ) iv1 = ciphertext_block iv2 = plaintext_block diff --git a/client/src/telethon/_impl/crypto/auth_key.py b/client/src/telethon/_impl/crypto/auth_key.py index 11a4ab84..f2c2cdc7 100644 --- a/client/src/telethon/_impl/crypto/auth_key.py +++ b/client/src/telethon/_impl/crypto/auth_key.py @@ -20,4 +20,8 @@ class AuthKey: return self.data def calc_new_nonce_hash(self, new_nonce: int, number: int) -> int: - return int.from_bytes(sha1(new_nonce.to_bytes(32) + number.to_bytes(1) + self.aux_hash).digest()[4:]) + return int.from_bytes( + sha1(new_nonce.to_bytes(32) + number.to_bytes(1) + self.aux_hash).digest()[ + 4: + ] + ) diff --git a/client/src/telethon/_impl/crypto/crypto.py b/client/src/telethon/_impl/crypto/crypto.py index d08308b7..17fd788e 100644 --- a/client/src/telethon/_impl/crypto/crypto.py +++ b/client/src/telethon/_impl/crypto/crypto.py @@ -19,7 +19,9 @@ class CalcKey(NamedTuple): # https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector -def calc_key(auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side) -> CalcKey: +def calc_key( + auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side +) -> CalcKey: x = int(side) # sha256_a = SHA256 (msg_key + substr (auth_key, x, 36)) @@ -41,8 +43,12 @@ def determine_padding_v2_length(length: int) -> int: return 16 + (16 - (length % 16)) -def _do_encrypt_data_v2(plaintext: bytes, auth_key: AuthKey, random_padding: bytes) -> bytes: - padded_plaintext = plaintext + random_padding[: determine_padding_v2_length(len(plaintext))] +def _do_encrypt_data_v2( + plaintext: bytes, auth_key: AuthKey, random_padding: bytes +) -> bytes: + padded_plaintext = ( + plaintext + random_padding[: determine_padding_v2_length(len(plaintext))] + ) side = Side.CLIENT x = int(side) @@ -64,7 +70,9 @@ def encrypt_data_v2(plaintext: bytes, auth_key: AuthKey) -> bytes: return _do_encrypt_data_v2(plaintext, auth_key, random_padding) -def decrypt_data_v2(ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey) -> bytes: +def decrypt_data_v2( + ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey +) -> bytes: side = Side.SERVER x = int(side) diff --git a/client/src/telethon/_impl/crypto/rsa.py b/client/src/telethon/_impl/crypto/rsa.py index c125bc60..72338887 100644 --- a/client/src/telethon/_impl/crypto/rsa.py +++ b/client/src/telethon/_impl/crypto/rsa.py @@ -34,13 +34,17 @@ def encrypt_hashed(data: bytes, key: PublicKey, random_bytes: bytes) -> bytes: temp_key = random_bytes[192 + 32 * attempt : 192 + 32 * attempt + 32] # data_with_hash := data_pad_reversed + SHA256(temp_key + data_with_padding); -- after this assignment, data_with_hash is exactly 224 bytes long. - data_with_hash = data_pad_reversed + sha256(temp_key + data_with_padding).digest() + data_with_hash = ( + data_pad_reversed + sha256(temp_key + data_with_padding).digest() + ) # aes_encrypted := AES256_IGE(data_with_hash, temp_key, 0); -- AES256-IGE encryption with zero IV. aes_encrypted = ige_encrypt(data_with_hash, temp_key, bytes(32)) # temp_key_xor := temp_key XOR SHA256(aes_encrypted); -- adjusted key, 32 bytes - temp_key_xor = bytes(a ^ b for a, b in zip(temp_key, sha256(aes_encrypted).digest())) + temp_key_xor = bytes( + a ^ b for a, b in zip(temp_key, sha256(aes_encrypted).digest()) + ) # key_aes_encrypted := temp_key_xor + aes_encrypted; -- exactly 256 bytes (2048 bits) long key_aes_encrypted = temp_key_xor + aes_encrypted @@ -83,4 +87,6 @@ j4WcDuXc2CTHgH8gFTNhp/Y8/SpDOhvn9QIDAQAB ) -RSA_KEYS = {compute_fingerprint(key): key for key in (PRODUCTION_RSA_KEY, TESTMODE_RSA_KEY)} +RSA_KEYS = { + compute_fingerprint(key): key for key in (PRODUCTION_RSA_KEY, TESTMODE_RSA_KEY) +} diff --git a/client/src/telethon/_impl/crypto/two_factor_auth.py b/client/src/telethon/_impl/crypto/two_factor_auth.py index 4a05d1aa..384ca5a9 100644 --- a/client/src/telethon/_impl/crypto/two_factor_auth.py +++ b/client/src/telethon/_impl/crypto/two_factor_auth.py @@ -20,7 +20,9 @@ def h(*data: bytes | bytearray | memoryview) -> bytes: # SH(data, salt) := H(salt | data | salt) -def sh(data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview) -> bytes: +def sh( + data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview +) -> bytes: return h(salt, data, salt) diff --git a/client/src/telethon/_impl/mtproto/authentication.py b/client/src/telethon/_impl/mtproto/authentication.py index a5902f04..c38b26dc 100644 --- a/client/src/telethon/_impl/mtproto/authentication.py +++ b/client/src/telethon/_impl/mtproto/authentication.py @@ -108,9 +108,13 @@ def _do_step2(data: Step1, response: bytes, random_bytes: bytes) -> tuple[bytes, ) try: - fingerprint = next(fp for fp in res_pq.server_public_key_fingerprints if fp in RSA_KEYS) + fingerprint = next( + fp for fp in res_pq.server_public_key_fingerprints if fp in RSA_KEYS + ) except StopIteration: - raise ValueError(f"unknown fingerprints: {res_pq.server_public_key_fingerprints}") + raise ValueError( + f"unknown fingerprints: {res_pq.server_public_key_fingerprints}" + ) key = RSA_KEYS[fingerprint] ciphertext = encrypt_hashed(pq_inner_data, key, random_bytes) @@ -129,7 +133,9 @@ def step2(data: Step1, response: bytes) -> tuple[bytes, Step2]: return _do_step2(data, response, os.urandom(288)) -def _do_step3(data: Step2, response: bytes, random_bytes: bytes, now: int) -> tuple[bytes, Step3]: +def _do_step3( + data: Step2, response: bytes, random_bytes: bytes, now: int +) -> tuple[bytes, Step3]: assert len(random_bytes) == 272 nonce = data.nonce @@ -152,7 +158,9 @@ def _do_step3(data: Step2, response: bytes, random_bytes: bytes, now: int) -> tu check_server_nonce(server_dh_params.server_nonce, server_nonce) if len(server_dh_params.encrypted_answer) % 16 != 0: - raise ValueError(f"encrypted response not padded with size: {len(server_dh_params.encrypted_answer)}") + raise ValueError( + f"encrypted response not padded with size: {len(server_dh_params.encrypted_answer)}" + ) key, iv = generate_key_data_from_nonce(server_nonce, new_nonce) assert isinstance(server_dh_params.encrypted_answer, bytes) @@ -164,7 +172,9 @@ def _do_step3(data: Step2, response: bytes, random_bytes: bytes, now: int) -> tu server_dh_inner = AbcServerDhInnerData._read_from(plain_text_reader) assert isinstance(server_dh_inner, ServerDhInnerData) - expected_answer_hash = sha1(plain_text_answer[20 : 20 + plain_text_reader._pos]).digest() + expected_answer_hash = sha1( + plain_text_answer[20 : 20 + plain_text_reader._pos] + ).digest() if got_answer_hash != expected_answer_hash: raise ValueError("invalid answer hash") @@ -203,11 +213,15 @@ def _do_step3(data: Step2, response: bytes, random_bytes: bytes, now: int) -> tu ) client_dh_inner_hashed = sha1(client_dh_inner).digest() + client_dh_inner - client_dh_inner_hashed += random_bytes[: (16 - (len(client_dh_inner_hashed) % 16)) % 16] + client_dh_inner_hashed += random_bytes[ + : (16 - (len(client_dh_inner_hashed) % 16)) % 16 + ] client_dh_encrypted = encrypt_ige(client_dh_inner_hashed, key, iv) - return set_client_dh_params(nonce=nonce, server_nonce=server_nonce, encrypted_data=client_dh_encrypted), Step3( + return set_client_dh_params( + nonce=nonce, server_nonce=server_nonce, encrypted_data=client_dh_encrypted + ), Step3( nonce=nonce, server_nonce=server_nonce, new_nonce=new_nonce, @@ -263,7 +277,10 @@ def create_key(data: Step3, response: bytes) -> CreatedKey: first_salt = struct.unpack( " None: self._auth_key = auth_key self._time_offset: int = time_offset or 0 - self._salts: list[FutureSalt] = [FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)] + self._salts: list[FutureSalt] = [ + FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0) + ] self._start_salt_time: Optional[tuple[int, float]] = None self._compression_threshold = compression_threshold self._deserialization: list[Deserialization] = [] @@ -201,7 +203,9 @@ class Encrypted(Mtp): if self._msg_count == 1: del self._buffer[:CONTAINER_HEADER_LEN] - self._buffer[:HEADER_LEN] = struct.pack(" None: new_session = NewSessionCreated.from_bytes(message.body) self._salts.clear() - self._salts.append(FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=new_session.server_salt)) + self._salts.append( + FutureSalt( + valid_since=0, valid_until=0x7FFFFFFF, salt=new_session.server_salt + ) + ) def _handle_container(self, message: Message) -> None: container = MsgContainer.from_bytes(message.body) @@ -362,7 +376,11 @@ class Encrypted(Mtp): self._deserialization.append(Update(message.body)) def _try_request_salts(self) -> None: - if len(self._salts) == 1 and self._salt_request_msg_id is None and self._get_current_salt() != 0: + if ( + len(self._salts) == 1 + and self._salt_request_msg_id is None + and self._get_current_salt() != 0 + ): # If salts are requested in a container leading to bad_msg, # the bad_msg_id will refer to the container, not the salts request. # @@ -370,7 +388,9 @@ class Encrypted(Mtp): # This would break, because we couldn't identify the response. # # So salts are only requested once we have a valid salt to reduce the chances of this happening. - self._salt_request_msg_id = self._serialize_msg(bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True) + self._salt_request_msg_id = self._serialize_msg( + bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True + ) def push(self, request: bytes) -> Optional[MsgId]: if self._start_salt_time and len(self._salts) >= 2: @@ -415,7 +435,9 @@ class Encrypted(Mtp): return MsgId(self._last_msg_id), encrypt_data_v2(result, self._auth_key) - def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]: + def deserialize( + self, payload: bytes | bytearray | memoryview + ) -> list[Deserialization]: check_message_buffer(payload) plaintext = decrypt_data_v2(payload, self._auth_key) diff --git a/client/src/telethon/_impl/mtproto/mtp/plain.py b/client/src/telethon/_impl/mtproto/mtp/plain.py index 43b93518..b1fe03d2 100644 --- a/client/src/telethon/_impl/mtproto/mtp/plain.py +++ b/client/src/telethon/_impl/mtproto/mtp/plain.py @@ -31,7 +31,9 @@ class Plain(Mtp): self._buffer.clear() return MsgId(0), result - def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]: + def deserialize( + self, payload: bytes | bytearray | memoryview + ) -> list[Deserialization]: check_message_buffer(payload) auth_key_id, msg_id, length = struct.unpack_from("= 0, got: {length}") if 20 + length > len(payload): - raise ValueError(f"message too short, expected: {20 + length}, got {len(payload)}") + raise ValueError( + f"message too short, expected: {20 + length}, got {len(payload)}" + ) return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))] diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index 58f176cd..3798d859 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -116,7 +116,11 @@ class RpcError(ValueError): def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented - return self._code == other._code and self._name == other._name and self._value == other._value + return ( + self._code == other._code + and self._name == other._name + and self._value == other._value + ) # https://core.telegram.org/mtproto/service_messages_about_messages @@ -152,7 +156,9 @@ class BadMessage(ValueError): self.msg_id = msg_id self._code = code self._caused_by = caused_by - self.severity = logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR + self.severity = ( + logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR + ) @property def code(self) -> int: @@ -195,7 +201,9 @@ class Mtp(ABC): """ @abstractmethod - def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]: + def deserialize( + self, payload: bytes | bytearray | memoryview + ) -> list[Deserialization]: """ Deserialize incoming buffer payload. """ diff --git a/client/src/telethon/_impl/mtproto/transport/intermediate.py b/client/src/telethon/_impl/mtproto/transport/intermediate.py index 7d4c0053..3a9aebbb 100644 --- a/client/src/telethon/_impl/mtproto/transport/intermediate.py +++ b/client/src/telethon/_impl/mtproto/transport/intermediate.py @@ -42,7 +42,10 @@ class Intermediate(Transport): raise MissingBytes(expected=length, got=len(input)) if length <= 4: - if length >= 4 and (status := struct.unpack("= 4 + and (status := struct.unpack(" 0, got: {length}") diff --git a/client/src/telethon/_impl/mtproto/utils.py b/client/src/telethon/_impl/mtproto/utils.py index 68df596c..7b56fd17 100644 --- a/client/src/telethon/_impl/mtproto/utils.py +++ b/client/src/telethon/_impl/mtproto/utils.py @@ -12,7 +12,9 @@ MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes def check_message_buffer(message: bytes | bytearray | memoryview) -> None: if len(message) < 20: - raise ValueError(f"server payload is too small to be a valid message: {message.hex()}") + raise ValueError( + f"server payload is too small to be a valid message: {message.hex()}" + ) # https://core.telegram.org/mtproto/description#content-related-message diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 86d04fa1..84533451 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -313,7 +313,13 @@ class Sender: def _on_ping_timeout(self) -> None: ping_id = generate_random_id() - self._enqueue_body(bytes(ping_delay_disconnect(ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT))) + self._enqueue_body( + bytes( + ping_delay_disconnect( + ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT + ) + ) + ) self._next_ping = asyncio.get_running_loop().time() + PING_DELAY def _process_mtp_buffer(self, updates: list[Updates]) -> None: @@ -329,7 +335,9 @@ class Sender: else: self._process_bad_message(result) - def _process_update(self, updates: list[Updates], update: bytes | bytearray | memoryview) -> None: + def _process_update( + self, updates: list[Updates], update: bytes | bytearray | memoryview + ) -> None: try: updates.append(Updates.from_bytes(update)) except ValueError: @@ -433,7 +441,9 @@ class Sender: req.state.msg_id == msg_id or req.state.container_msg_id == msg_id ): raise RuntimeError("got response for unsent request") - elif isinstance(req.state, Sent) and (req.state.msg_id == msg_id or req.state.container_msg_id == msg_id): + elif isinstance(req.state, Sent) and ( + req.state.msg_id == msg_id or req.state.container_msg_id == msg_id + ): yield self._requests.pop(i) @property diff --git a/client/src/telethon/_impl/session/chat/hash_cache.py b/client/src/telethon/_impl/session/chat/hash_cache.py index 2c5f7fe8..60117029 100644 --- a/client/src/telethon/_impl/session/chat/hash_cache.py +++ b/client/src/telethon/_impl/session/chat/hash_cache.py @@ -65,7 +65,9 @@ class ChatHashCache: return self._has_peer(peer.peer) elif isinstance(peer, types.NotifyForumTopic): return self._has_peer(peer.peer) - elif isinstance(peer, (types.NotifyUsers, types.NotifyChats, types.NotifyBroadcasts)): + elif isinstance( + peer, (types.NotifyUsers, types.NotifyChats, types.NotifyBroadcasts) + ): return True else: raise RuntimeError("unexpected case") @@ -118,7 +120,9 @@ class ChatHashCache: elif isinstance(participant, types.ChannelParticipantAdmin): return ( self._has(participant.user_id) - and (participant.inviter_id is None or self._has(participant.inviter_id)) + and ( + participant.inviter_id is None or self._has(participant.inviter_id) + ) and self._has(participant.promoted_by) ) elif isinstance(participant, types.ChannelParticipantBanned): diff --git a/client/src/telethon/_impl/session/chat/peer_ref.py b/client/src/telethon/_impl/session/chat/peer_ref.py index c1bd0070..9752e1de 100644 --- a/client/src/telethon/_impl/session/chat/peer_ref.py +++ b/client/src/telethon/_impl/session/chat/peer_ref.py @@ -43,8 +43,12 @@ class PeerRef(abc.ABC): __slots__ = ("identifier", "authorization") - def __init__(self, identifier: PeerIdentifier, authorization: PeerAuth = None) -> None: - assert identifier >= 0, "PeerRef identifiers must be positive; see the documentation for Peers" + def __init__( + self, identifier: PeerIdentifier, authorization: PeerAuth = None + ) -> None: + assert ( + identifier >= 0 + ), "PeerRef identifiers must be positive; see the documentation for Peers" self.identifier = identifier self.authorization = authorization @@ -79,7 +83,9 @@ class PeerRef(abc.ABC): authorization: Optional[int] = None else: try: - (authorization,) = struct.unpack("!q", base64.urlsafe_b64decode(auth.encode("ascii") + b"=")) + (authorization,) = struct.unpack( + "!q", base64.urlsafe_b64decode(auth.encode("ascii") + b"=") + ) except Exception: raise ValueError(f"invalid PeerRef string: {string!r}") @@ -131,14 +137,21 @@ class PeerRef(abc.ABC): if self.authorization is None: auth = "0" else: - auth = base64.urlsafe_b64encode(struct.pack("!q", self.authorization)).decode("ascii").rstrip("=") + auth = ( + base64.urlsafe_b64encode(struct.pack("!q", self.authorization)) + .decode("ascii") + .rstrip("=") + ) return f"{self.identifier}.{auth}" def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented - return self.identifier == other.identifier and self.authorization == other.authorization + return ( + self.identifier == other.identifier + and self.authorization == other.authorization + ) @property def _ref(self) -> UserRef | GroupRef | ChannelRef: @@ -169,12 +182,16 @@ class UserRef(PeerRef): def _to_input_peer(self) -> abcs.InputPeer: if self.identifier == SELF_USER_SENTINEL_ID: return types.InputPeerSelf() - return types.InputPeerUser(user_id=self.identifier, access_hash=self.authorization or 0) + return types.InputPeerUser( + user_id=self.identifier, access_hash=self.authorization or 0 + ) def _to_input_user(self) -> abcs.InputUser: if self.identifier == SELF_USER_SENTINEL_ID: return types.InputUserSelf() - return types.InputUser(user_id=self.identifier, access_hash=self.authorization or 0) + return types.InputUser( + user_id=self.identifier, access_hash=self.authorization or 0 + ) def __str__(self) -> str: return f"{USER_PREFIX}{self._encode_str()}" @@ -235,10 +252,14 @@ class ChannelRef(PeerRef): return types.PeerChannel(channel_id=self.identifier) def _to_input_peer(self) -> abcs.InputPeer: - return types.InputPeerChannel(channel_id=self.identifier, access_hash=self.authorization or 0) + return types.InputPeerChannel( + channel_id=self.identifier, access_hash=self.authorization or 0 + ) def _to_input_channel(self) -> types.InputChannel: - return types.InputChannel(channel_id=self.identifier, access_hash=self.authorization or 0) + return types.InputChannel( + channel_id=self.identifier, access_hash=self.authorization or 0 + ) def __str__(self) -> str: return f"{CHANNEL_PREFIX}{self._encode_str()}" diff --git a/client/src/telethon/_impl/session/message_box/adaptor.py b/client/src/telethon/_impl/session/message_box/adaptor.py index a2560757..2eafb5bd 100644 --- a/client/src/telethon/_impl/session/message_box/adaptor.py +++ b/client/src/telethon/_impl/session/message_box/adaptor.py @@ -27,7 +27,9 @@ def update_short(short: types.UpdateShort) -> types.UpdatesCombined: ) -def update_short_message(short: types.UpdateShortMessage, self_id: int) -> types.UpdatesCombined: +def update_short_message( + short: types.UpdateShortMessage, self_id: int +) -> types.UpdatesCombined: return update_short( types.UpdateShort( update=types.UpdateNewMessage( @@ -44,7 +46,9 @@ def update_short_message(short: types.UpdateShortMessage, self_id: int) -> types noforwards=False, reactions=None, id=short.id, - from_id=types.PeerUser(user_id=self_id if short.out else short.user_id), + from_id=types.PeerUser( + user_id=self_id if short.out else short.user_id + ), peer_id=types.PeerChat( chat_id=short.user_id, ), diff --git a/client/src/telethon/_impl/session/message_box/defs.py b/client/src/telethon/_impl/session/message_box/defs.py index dcfe4b71..3033c921 100644 --- a/client/src/telethon/_impl/session/message_box/defs.py +++ b/client/src/telethon/_impl/session/message_box/defs.py @@ -38,7 +38,9 @@ class PtsInfo: self.entry = entry def __repr__(self) -> str: - return f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})" + return ( + f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})" + ) class State: @@ -68,7 +70,9 @@ class PossibleGap: self.updates = updates def __repr__(self) -> str: - return f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})" + return ( + f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})" + ) class PrematureEndReason(Enum): diff --git a/client/src/telethon/_impl/session/message_box/messagebox.py b/client/src/telethon/_impl/session/message_box/messagebox.py index 68fb9c37..a93a01f9 100644 --- a/client/src/telethon/_impl/session/message_box/messagebox.py +++ b/client/src/telethon/_impl/session/message_box/messagebox.py @@ -91,9 +91,13 @@ class MessageBox: self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline) if state.qts != NO_SEQ: self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline) - self.map.update((s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels) + self.map.update( + (s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels + ) - self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc) + self.date = datetime.datetime.fromtimestamp( + state.date, tz=datetime.timezone.utc + ) self.seq = state.seq self.possible_gaps.clear() self.getting_diff_for.clear() @@ -132,18 +136,28 @@ class MessageBox: default_deadline = next_updates_deadline() if self.possible_gaps: - deadline = min(default_deadline, *(gap.deadline for gap in self.possible_gaps.values())) + deadline = min( + default_deadline, *(gap.deadline for gap in self.possible_gaps.values()) + ) elif self.next_deadline in self.map: deadline = min(default_deadline, self.map[self.next_deadline].deadline) else: deadline = default_deadline if now >= deadline: - self.getting_diff_for.update(entry for entry, gap in self.possible_gaps.items() if now >= gap.deadline) - self.getting_diff_for.update(entry for entry, state in self.map.items() if now >= state.deadline) + self.getting_diff_for.update( + entry + for entry, gap in self.possible_gaps.items() + if now >= gap.deadline + ) + self.getting_diff_for.update( + entry for entry, state in self.map.items() if now >= state.deadline + ) if __debug__: - self._trace("deadlines met, now getting diff for: %r", self.getting_diff_for) + self._trace( + "deadlines met, now getting diff for: %r", self.getting_diff_for + ) for entry in self.getting_diff_for: self.possible_gaps.pop(entry, None) @@ -157,12 +171,19 @@ class MessageBox: entry: Entry = ENTRY_ACCOUNT # for pyright to know it's not unbound for entry in entries: if entry not in self.map: - raise RuntimeError("Called reset_deadline on an entry for which we do not have state") + raise RuntimeError( + "Called reset_deadline on an entry for which we do not have state" + ) self.map[entry].deadline = deadline if self.next_deadline in entries: - self.next_deadline = min(self.map.items(), key=lambda entry_state: entry_state[1].deadline)[0] - elif self.next_deadline in self.map and deadline < self.map[self.next_deadline].deadline: + self.next_deadline = min( + self.map.items(), key=lambda entry_state: entry_state[1].deadline + )[0] + elif ( + self.next_deadline in self.map + and deadline < self.map[self.next_deadline].deadline + ): self.next_deadline = entry def reset_channel_deadline(self, channel_id: int, timeout: Optional[float]) -> None: @@ -179,7 +200,9 @@ class MessageBox: assert isinstance(state, types.updates.State) self.map[ENTRY_ACCOUNT] = State(state.pts, deadline) self.map[ENTRY_SECRET] = State(state.qts, deadline) - self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc) + self.date = datetime.datetime.fromtimestamp( + state.date, tz=datetime.timezone.utc + ) self.seq = state.seq def try_set_channel_state(self, id: int, pts: int) -> None: @@ -192,11 +215,15 @@ class MessageBox: def try_begin_get_diff(self, entry: Entry, reason: str) -> None: if entry not in self.map: if entry in self.possible_gaps: - raise RuntimeError("Should not have a possible_gap for an entry not in the state map") + raise RuntimeError( + "Should not have a possible_gap for an entry not in the state map" + ) return if __debug__: - self._trace("marking entry=%r as needing difference because: %s", entry, reason) + self._trace( + "marking entry=%r as needing difference because: %s", entry, reason + ) self.getting_diff_for.add(entry) self.possible_gaps.pop(entry, None) @@ -204,10 +231,14 @@ class MessageBox: try: self.getting_diff_for.remove(entry) except KeyError: - raise RuntimeError("Called end_get_diff on an entry which was not getting diff for") + raise RuntimeError( + "Called end_get_diff on an entry which was not getting diff for" + ) self.reset_deadlines({entry}, next_updates_deadline()) - assert entry not in self.possible_gaps, "gaps shouldn't be created while getting difference" + assert ( + entry not in self.possible_gaps + ), "gaps shouldn't be created while getting difference" def ensure_known_peer_hashes( self, @@ -215,7 +246,10 @@ class MessageBox: chat_hashes: ChatHashCache, ) -> None: if not chat_hashes.extend_from_updates(updates): - can_recover = not isinstance(updates, types.UpdateShort) or pts_info_from_update(updates.update) is not None + can_recover = ( + not isinstance(updates, types.UpdateShort) + or pts_info_from_update(updates.update) is not None + ) if can_recover: self.try_begin_get_diff(ENTRY_ACCOUNT, "missing hash") raise Gap @@ -241,7 +275,9 @@ class MessageBox: if combined.seq_start != NO_SEQ: if self.seq + 1 > combined.seq_start: if __debug__: - self._trace("skipping updates as they should have already been handled") + self._trace( + "skipping updates as they should have already been handled" + ) return result, combined.users, combined.chats elif self.seq + 1 < combined.seq_start: self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap") @@ -269,13 +305,17 @@ class MessageBox: if __debug__: self._trace("updating seq as local pts was updated too") if combined.date != NO_DATE: - self.date = datetime.datetime.fromtimestamp(combined.date, tz=datetime.timezone.utc) + self.date = datetime.datetime.fromtimestamp( + combined.date, tz=datetime.timezone.utc + ) if combined.seq != NO_SEQ: self.seq = combined.seq if self.possible_gaps: if __debug__: - self._trace("trying to re-apply count=%r possible gaps", len(self.possible_gaps)) + self._trace( + "trying to re-apply count=%r possible gaps", len(self.possible_gaps) + ) for key in list(self.possible_gaps.keys()): self.possible_gaps[key].updates.sort(key=update_sort_key) @@ -292,7 +332,9 @@ class MessageBox: applied, ) - self.possible_gaps = {entry: gap for entry, gap in self.possible_gaps.items() if gap.updates} + self.possible_gaps = { + entry: gap for entry, gap in self.possible_gaps.items() if gap.updates + } return result, combined.users, combined.chats @@ -342,7 +384,8 @@ class MessageBox: ) if pts.entry not in self.possible_gaps: self.possible_gaps[pts.entry] = PossibleGap( - deadline=asyncio.get_running_loop().time() + POSSIBLE_GAP_TIMEOUT, + deadline=asyncio.get_running_loop().time() + + POSSIBLE_GAP_TIMEOUT, updates=[], ) @@ -370,14 +413,20 @@ class MessageBox: for entry in (ENTRY_ACCOUNT, ENTRY_SECRET): if entry in self.getting_diff_for: if entry not in self.map: - raise RuntimeError("Should not try to get difference for an entry without known state") + raise RuntimeError( + "Should not try to get difference for an entry without known state" + ) gd = functions.updates.get_difference( pts=self.map[ENTRY_ACCOUNT].pts, pts_limit=None, pts_total_limit=None, date=int(self.date.timestamp()), - qts=(self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else NO_SEQ), + qts=( + self.map[ENTRY_SECRET].pts + if ENTRY_SECRET in self.map + else NO_SEQ + ), qts_limit=None, ) if __debug__: @@ -398,7 +447,9 @@ class MessageBox: result: tuple[list[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]] if isinstance(diff, types.updates.DifferenceEmpty): finish = True - self.date = datetime.datetime.fromtimestamp(diff.date, tz=datetime.timezone.utc) + self.date = datetime.datetime.fromtimestamp( + diff.date, tz=datetime.timezone.utc + ) self.seq = diff.seq result = [], [], [] elif isinstance(diff, types.updates.Difference): @@ -451,7 +502,9 @@ class MessageBox: assert isinstance(state, types.updates.State) self.map[ENTRY_ACCOUNT].pts = state.pts self.map[ENTRY_SECRET].pts = state.qts - self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc) + self.date = datetime.datetime.fromtimestamp( + state.date, tz=datetime.timezone.utc + ) self.seq = state.seq updates, users, chats = self.process_updates( @@ -507,13 +560,19 @@ class MessageBox: channel=channel, filter=types.ChannelMessagesFilterEmpty(), pts=state.pts, - limit=(BOT_CHANNEL_DIFF_LIMIT if chat_hashes.is_self_bot else USER_CHANNEL_DIFF_LIMIT), + limit=( + BOT_CHANNEL_DIFF_LIMIT + if chat_hashes.is_self_bot + else USER_CHANNEL_DIFF_LIMIT + ), ) if __debug__: self._trace("requesting channel difference: %s", gd) return gd else: - raise RuntimeError("should not try to get difference for an entry without known state") + raise RuntimeError( + "should not try to get difference for an entry without known state" + ) else: self.end_get_diff(entry) self.map.pop(entry, None) @@ -579,7 +638,9 @@ class MessageBox: else: raise RuntimeError("unexpected case") - def end_channel_difference(self, channel_id: int, reason: PrematureEndReason) -> None: + def end_channel_difference( + self, channel_id: int, reason: PrematureEndReason + ) -> None: entry: Entry = channel_id if __debug__: self._trace("ending channel=%r difference: %s", entry, reason) diff --git a/client/src/telethon/_impl/session/storage/sqlite.py b/client/src/telethon/_impl/session/storage/sqlite.py index 335ee035..2988e336 100644 --- a/client/src/telethon/_impl/session/storage/sqlite.py +++ b/client/src/telethon/_impl/session/storage/sqlite.py @@ -38,7 +38,9 @@ class SqliteSession(Storage): if version == 7: session = self._load_v7(c) else: - raise ValueError("only migration from sqlite session format 7 supported") + raise ValueError( + "only migration from sqlite session format 7 supported" + ) self._reset(c) self._get_or_init_version(c) @@ -103,7 +105,11 @@ class SqliteSession(Storage): DataCenter(id=id, ipv4_addr=ipv4_addr, ipv6_addr=ipv6_addr, auth=auth) for (id, ipv4_addr, ipv6_addr, auth) in datacenter ], - user=(User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3]) if user else None), + user=( + User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3]) + if user + else None + ), state=( UpdateState( pts=state[0], @@ -160,7 +166,9 @@ class SqliteSession(Storage): @staticmethod def _get_or_init_version(c: sqlite3.Cursor) -> int: - c.execute("select name from sqlite_master where type='table' and name='version'") + c.execute( + "select name from sqlite_master where type='table' and name='version'" + ) if c.fetchone(): c.execute("select version from version") tup = c.fetchone() diff --git a/client/src/telethon/_impl/tl/core/reader.py b/client/src/telethon/_impl/tl/core/reader.py index 36c8622d..3bb32a4c 100644 --- a/client/src/telethon/_impl/tl/core/reader.py +++ b/client/src/telethon/_impl/tl/core/reader.py @@ -23,7 +23,9 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]: from ..mtproto.layer import TYPE_MAPPING as MTPROTO_TYPES if API_TYPES.keys() & MTPROTO_TYPES.keys(): - raise RuntimeError("generated api and mtproto schemas cannot have colliding constructor identifiers") + raise RuntimeError( + "generated api and mtproto schemas cannot have colliding constructor identifiers" + ) ALL_TYPES = API_TYPES | MTPROTO_TYPES # Signatures don't fully match, but this is a private method @@ -37,7 +39,9 @@ class Reader: __slots__ = ("_view", "_pos", "_len") def __init__(self, buffer: "Buffer") -> None: - self._view = memoryview(buffer) if not isinstance(buffer, memoryview) else buffer + self._view = ( + memoryview(buffer) if not isinstance(buffer, memoryview) else buffer + ) self._pos = 0 self._len = len(self._view) diff --git a/client/src/telethon/_impl/tl/core/request.py b/client/src/telethon/_impl/tl/core/request.py index 83acc8d5..b053bdd4 100644 --- a/client/src/telethon/_impl/tl/core/request.py +++ b/client/src/telethon/_impl/tl/core/request.py @@ -14,7 +14,9 @@ def _bootstrap_get_deserializer( from ..mtproto.layer import RESPONSE_MAPPING as MTPROTO_DESER if API_DESER.keys() & MTPROTO_DESER.keys(): - raise RuntimeError("generated api and mtproto schemas cannot have colliding constructor identifiers") + raise RuntimeError( + "generated api and mtproto schemas cannot have colliding constructor identifiers" + ) ALL_DESER = API_DESER | MTPROTO_DESER Request._get_deserializer = ALL_DESER.get # type: ignore [assignment] diff --git a/client/src/telethon/_impl/tl/core/serializable.py b/client/src/telethon/_impl/tl/core/serializable.py index 9d753e4d..8af4e56c 100644 --- a/client/src/telethon/_impl/tl/core/serializable.py +++ b/client/src/telethon/_impl/tl/core/serializable.py @@ -49,7 +49,9 @@ class Serializable(abc.ABC): def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented - return all(getattr(self, attr) == getattr(other, attr) for attr in self.__slots__) + return all( + getattr(self, attr) == getattr(other, attr) for attr in self.__slots__ + ) def serialize_bytes_to(buffer: bytearray, data: bytes | bytearray | memoryview) -> None: diff --git a/client/tests/auth_key_test.py b/client/tests/auth_key_test.py index 2a6095dd..978b9554 100644 --- a/client/tests/auth_key_test.py +++ b/client/tests/auth_key_test.py @@ -26,16 +26,25 @@ def test_auth_key_id() -> None: def test_calc_new_nonce_hash1() -> None: auth_key = get_auth_key() new_nonce = get_new_nonce() - assert auth_key.calc_new_nonce_hash(new_nonce, 1) == 258944117842285651226187582903746985063 + assert ( + auth_key.calc_new_nonce_hash(new_nonce, 1) + == 258944117842285651226187582903746985063 + ) def test_calc_new_nonce_hash2() -> None: auth_key = get_auth_key() new_nonce = get_new_nonce() - assert auth_key.calc_new_nonce_hash(new_nonce, 2) == 324588944215647649895949797213421233055 + assert ( + auth_key.calc_new_nonce_hash(new_nonce, 2) + == 324588944215647649895949797213421233055 + ) def test_calc_new_nonce_hash3() -> None: auth_key = get_auth_key() new_nonce = get_new_nonce() - assert auth_key.calc_new_nonce_hash(new_nonce, 3) == 100989356540453064705070297823778556733 + assert ( + auth_key.calc_new_nonce_hash(new_nonce, 3) + == 100989356540453064705070297823778556733 + ) diff --git a/client/tests/crypto_test.py b/client/tests/crypto_test.py index faf92d2b..e94fa11d 100644 --- a/client/tests/crypto_test.py +++ b/client/tests/crypto_test.py @@ -66,8 +66,14 @@ def test_key_from_nonce() -> None: new_nonce = int.from_bytes(bytes(range(32))) key, iv = generate_key_data_from_nonce(server_nonce, new_nonce) - assert key == b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6' - assert iv == b"Z\x84\x10\x8e\x98\x05el\xe8d\x07\x0e\x16nb\x18\xf6x>\x85\x11G\x1aZ\xb7\x80,\xf2\x00\x01\x02\x03" + assert ( + key + == b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6' + ) + assert ( + iv + == b"Z\x84\x10\x8e\x98\x05el\xe8d\x07\x0e\x16nb\x18\xf6x>\x85\x11G\x1aZ\xb7\x80,\xf2\x00\x01\x02\x03" + ) def test_verify_ige_encryption() -> None: @@ -82,7 +88,5 @@ def test_verify_ige_decryption() -> None: ciphertext = get_test_aes_key_or_iv() key = get_test_aes_key_or_iv() iv = get_test_aes_key_or_iv() - expected = ( - b"\xe5wz\xfa\xcd{,\x16\xf7\xac@\xca\xe6\x1e\xf6\x03\xfe\xe6\t\x8f\xb8\xa8\x86\n\xb9\xeeg,\xd7\xe5\xba\xcc" - ) + expected = b"\xe5wz\xfa\xcd{,\x16\xf7\xac@\xca\xe6\x1e\xf6\x03\xfe\xe6\t\x8f\xb8\xa8\x86\n\xb9\xeeg,\xd7\xe5\xba\xcc" assert decrypt_ige(ciphertext, key, iv) == expected diff --git a/client/tests/parsers_test.py b/client/tests/parsers_test.py index 373ed105..4541306a 100644 --- a/client/tests/parsers_test.py +++ b/client/tests/parsers_test.py @@ -32,7 +32,10 @@ def test_parse_all_entities_markdown() -> None: markdown = "Some **bold** (__strong__), *italics* (_cursive_), inline `code`, a\n```rust\npre\n```\nblock, a [link](https://example.com), and [mentions](tg://user?id=12345678)" text, entities = parse_markdown_message(markdown) - assert text == "Some bold (strong), italics (cursive), inline code, a\npre\nblock, a link, and mentions" + assert ( + text + == "Some bold (strong), italics (cursive), inline code, a\npre\nblock, a link, and mentions" + ) assert entities == [ types.MessageEntityBold(offset=5, length=4), types.MessageEntityBold(offset=11, length=6), @@ -89,7 +92,10 @@ def test_parse_emoji_html() -> None: def test_parse_all_entities_html() -> None: html = 'Some bold (strong), italics (cursive), inline code, a
pre
block, a link,
spoilers
and mentions' text, entities = parse_html_message(html) - assert text == "Some bold (strong), italics (cursive), inline code, a pre block, a link, spoilers and mentions" + assert ( + text + == "Some bold (strong), italics (cursive), inline code, a pre block, a link, spoilers and mentions" + ) assert entities == [ types.MessageEntityBold(offset=5, length=4), types.MessageEntityBold(offset=11, length=6), diff --git a/generator/pyproject.toml b/generator/pyproject.toml index 2228b95c..9d32dd77 100644 --- a/generator/pyproject.toml +++ b/generator/pyproject.toml @@ -37,9 +37,6 @@ build-backend = "setuptools.build_meta" [tool.setuptools.dynamic] version = {attr = "telethon_generator.version.__version__"} -[tool.ruff] -line-length = 120 - [tool.ruff.lint] select = ["F", "E", "W", "I"] ignore = [ diff --git a/generator/src/telethon_generator/_impl/codegen/generator.py b/generator/src/telethon_generator/_impl/codegen/generator.py index c922d90c..d3fbe128 100644 --- a/generator/src/telethon_generator/_impl/codegen/generator.py +++ b/generator/src/telethon_generator/_impl/codegen/generator.py @@ -17,7 +17,9 @@ from .serde.deserialization import ( from .serde.serialization import generate_function, generate_write -def generate_init(writer: SourceWriter, namespaces: set[str], classes: set[str]) -> None: +def generate_init( + writer: SourceWriter, namespaces: set[str], classes: set[str] +) -> None: sorted_cls = list(sorted(classes)) sorted_ns = list(sorted(namespaces)) @@ -91,7 +93,9 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: writer = fs.open(type_path) if type_path not in fs: - writer.write("# pyright: reportUnusedImport=false, reportConstantRedefinition=false") + writer.write( + "# pyright: reportUnusedImport=false, reportConstantRedefinition=false" + ) writer.write("import struct") writer.write("from typing import Optional, Self, Sequence") writer.write("from .. import abcs") @@ -102,7 +106,9 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: generated_type_names.add(f"{ns}{to_class_name(typedef.name)}") # class Type(BaseType) - writer.write(f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):") + writer.write( + f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):" + ) # __slots__ = ('params', ...) slots = " ".join(f"'{p.name}'," for p in property_params) @@ -115,7 +121,9 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: # def __init__() if property_params: - params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params) + params = "".join( + f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params + ) writer.write(f" def __init__(_s, *{params}) -> None:") for p in property_params: writer.write(f" _s.{p.name} = {p.name}") @@ -143,7 +151,9 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: raise ValueError("nested function-namespaces are not supported") elif len(functiondef.namespace) == 1: function_namespaces.add(functiondef.namespace[0]) - function_path = (Path("functions") / functiondef.namespace[0]).with_suffix(".py") + function_path = (Path("functions") / functiondef.namespace[0]).with_suffix( + ".py" + ) else: function_def_names.add(to_method_name(functiondef.name)) function_path = Path("functions/_nons.py") @@ -163,14 +173,18 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params) star = "*" if params else "" return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None)) - writer.write(f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:") + writer.write( + f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:" + ) writer.indent(2) generate_function(writer, functiondef) writer.dedent(2) generate_init(fs.open(Path("abcs/__init__.py")), abc_namespaces, abc_class_names) generate_init(fs.open(Path("types/__init__.py")), type_namespaces, type_class_names) - generate_init(fs.open(Path("functions/__init__.py")), function_namespaces, function_def_names) + generate_init( + fs.open(Path("functions/__init__.py")), function_namespaces, function_def_names + ) writer = fs.open(Path("layer.py")) writer.write("# pyright: reportUnusedImport=false") @@ -180,12 +194,16 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: ) writer.write("from typing import cast, Type") writer.write(f"LAYER = {tl.layer!r}") - writer.write("TYPE_MAPPING = {t.constructor_id(): t for t in cast(tuple[Type[Serializable]], (") + writer.write( + "TYPE_MAPPING = {t.constructor_id(): t for t in cast(tuple[Type[Serializable]], (" + ) for name in sorted(generated_type_names): writer.write(f" types.{name},") writer.write("))}") writer.write("RESPONSE_MAPPING = {") for functiondef in tl.functiondefs: - writer.write(f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)},") + writer.write( + f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)}," + ) writer.write("}") writer.write("__all__ = ['LAYER', 'TYPE_MAPPING', 'RESPONSE_MAPPING']") diff --git a/generator/src/telethon_generator/_impl/codegen/serde/common.py b/generator/src/telethon_generator/_impl/codegen/serde/common.py index affb263d..0269af53 100644 --- a/generator/src/telethon_generator/_impl/codegen/serde/common.py +++ b/generator/src/telethon_generator/_impl/codegen/serde/common.py @@ -50,7 +50,9 @@ _TRIVIAL_STRUCT_MAP = {"int": "i", "long": "q", "double": "d", "Bool": "I"} def trivial_struct_fmt(ty: BaseParameter) -> str: try: - return _TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I" + return ( + _TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I" + ) except KeyError: raise ValueError("input param was not trivial") diff --git a/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py b/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py index 1d132626..a5aa2d8b 100644 --- a/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py +++ b/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py @@ -38,7 +38,9 @@ def reader_read_fmt(ty: Type, constructor_id: int) -> tuple[str, Optional[str]]: return f"reader.read_serializable({inner_type_fmt(ty)})", "type-abstract" -def generate_normal_param_read(writer: SourceWriter, name: str, param: NormalParameter, constructor_id: int) -> None: +def generate_normal_param_read( + writer: SourceWriter, name: str, param: NormalParameter, constructor_id: int +) -> None: flag_check = f"_{param.flag.name} & {1 << param.flag.index}" if param.flag else None if param.ty.name == "true": if not flag_check: @@ -53,7 +55,9 @@ def generate_normal_param_read(writer: SourceWriter, name: str, param: NormalPar if param.ty.generic_arg: if param.ty.name not in ("Vector", "vector"): - raise ValueError("generic_arg deserialization for non-vectors is not supported") + raise ValueError( + "generic_arg deserialization for non-vectors is not supported" + ) if param.ty.bare: writer.write("__len = reader.read_fmt(' str: def function_deserializer_fmt(defn: Definition) -> str: if defn.ty.generic_arg: if defn.ty.name != ("Vector"): - raise ValueError("generic_arg return for non-boxed-vectors is not supported") + raise ValueError( + "generic_arg return for non-boxed-vectors is not supported" + ) elif defn.ty.generic_ref: raise ValueError("return for generic refs inside vector is not supported") elif is_trivial(NormalParameter(ty=defn.ty.generic_arg, flag=None)): @@ -126,9 +138,13 @@ def function_deserializer_fmt(defn: Definition) -> str: elif defn.ty.generic_arg.name == "long": return "deserialize_i64_list" else: - raise ValueError(f"return for trivial arg {defn.ty.generic_arg} is not supported") + raise ValueError( + f"return for trivial arg {defn.ty.generic_arg} is not supported" + ) elif defn.ty.generic_arg.bare: - raise ValueError("return for non-boxed serializables inside a vector is not supported") + raise ValueError( + "return for non-boxed serializables inside a vector is not supported" + ) else: return f"list_deserializer({inner_type_fmt(defn.ty.generic_arg)})" elif defn.ty.generic_ref: diff --git a/generator/src/telethon_generator/_impl/codegen/serde/serialization.py b/generator/src/telethon_generator/_impl/codegen/serde/serialization.py index 1ccf3b11..792893c9 100644 --- a/generator/src/telethon_generator/_impl/codegen/serde/serialization.py +++ b/generator/src/telethon_generator/_impl/codegen/serde/serialization.py @@ -15,11 +15,15 @@ def param_value_expr(param: Parameter) -> str: return f"{pre}{mid}{suf}" -def generate_buffer_append(writer: SourceWriter, buffer: str, name: str, ty: Type) -> None: +def generate_buffer_append( + writer: SourceWriter, buffer: str, name: str, ty: Type +) -> None: if is_trivial(NormalParameter(ty=ty, flag=None)): fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None)) if ty.name == "Bool": - writer.write(f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))") + writer.write( + f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))" + ) else: writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})") elif ty.generic_ref or ty.name == "Object": @@ -54,7 +58,9 @@ def generate_normal_param_write( if param.ty.generic_arg: if param.ty.name not in ("Vector", "vector"): - raise ValueError("generic_arg deserialization for non-vectors is not supported") + raise ValueError( + "generic_arg deserialization for non-vectors is not supported" + ) if param.ty.bare: writer.write(f"{buffer} += struct.pack(' None: else f"(0 if self.{p.name} is None else {1 << p.ty.flag.index})" ) for p in defn.params - if isinstance(p.ty, NormalParameter) and p.ty.flag and p.ty.flag.name == param.name + if isinstance(p.ty, NormalParameter) + and p.ty.flag + and p.ty.flag.name == param.name ) writer.write(f"_{param.name} = {flags or 0}") @@ -113,7 +123,9 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None: for param in iter: if not isinstance(param.ty, NormalParameter): raise RuntimeError("FlagsParameter should be considered trivial") - generate_normal_param_write(writer, tmp_names, "buffer", f"self.{param.name}", param.ty) + generate_normal_param_write( + writer, tmp_names, "buffer", f"self.{param.name}", param.ty + ) def generate_function(writer: SourceWriter, defn: Definition) -> None: @@ -136,7 +148,9 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None: else f"(0 if {p.name} is None else {1 << p.ty.flag.index})" ) for p in defn.params - if isinstance(p.ty, NormalParameter) and p.ty.flag and p.ty.flag.name == param.name + if isinstance(p.ty, NormalParameter) + and p.ty.flag + and p.ty.flag.name == param.name ) writer.write(f"{param.name} = {flags or 0}") @@ -147,5 +161,7 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None: for param in iter: if not isinstance(param.ty, NormalParameter): raise RuntimeError("FlagsParameter should be considered trivial") - generate_normal_param_write(writer, tmp_names, "_buffer", param.name, param.ty) + generate_normal_param_write( + writer, tmp_names, "_buffer", param.name, param.ty + ) writer.write("return Request(b'' + _buffer)") diff --git a/generator/src/telethon_generator/_impl/tl_parser/loader.py b/generator/src/telethon_generator/_impl/tl_parser/loader.py index 80f4e3a3..70185de7 100644 --- a/generator/src/telethon_generator/_impl/tl_parser/loader.py +++ b/generator/src/telethon_generator/_impl/tl_parser/loader.py @@ -35,4 +35,6 @@ def load_tl_file(path: str | Path) -> ParsedTl: else: functiondefs.append(definition) - return ParsedTl(layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)) + return ParsedTl( + layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs) + ) diff --git a/generator/tests/generator_test.py b/generator/tests/generator_test.py index 0fda1f7b..949a3179 100644 --- a/generator/tests/generator_test.py +++ b/generator/tests/generator_test.py @@ -14,7 +14,9 @@ def gen_py_code( functiondefs: Optional[list[Definition]] = None, ) -> str: fs = FakeFs() - generate(fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or [])) + generate( + fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or []) + ) generated = bytearray() for path, data in fs._files.items(): if path.stem not in ("__init__", "layer"): @@ -25,7 +27,9 @@ def gen_py_code( def test_generic_functions_use_bytes_parameters() -> None: - definitions = get_definitions("invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;") + definitions = get_definitions( + "invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;" + ) result = gen_py_code(functiondefs=definitions) assert "invoke_with_layer" in result assert "query: _bytes" in result diff --git a/generator/tests/parameter_test.py b/generator/tests/parameter_test.py index 5d7d1b37..6aeacae2 100644 --- a/generator/tests/parameter_test.py +++ b/generator/tests/parameter_test.py @@ -55,14 +55,18 @@ def test_valid_param() -> None: assert Parameter.from_str("foo:!bar") == Parameter( name="foo", ty=NormalParameter( - ty=Type(namespace=[], name="bar", bare=True, generic_ref=True, generic_arg=None), + ty=Type( + namespace=[], name="bar", bare=True, generic_ref=True, generic_arg=None + ), flag=None, ), ) assert Parameter.from_str("foo:bar.1?baz") == Parameter( name="foo", ty=NormalParameter( - ty=Type(namespace=[], name="baz", bare=True, generic_ref=False, generic_arg=None), + ty=Type( + namespace=[], name="baz", bare=True, generic_ref=False, generic_arg=None + ), flag=Flag( name="bar", index=1, diff --git a/generator/tests/ty_test.py b/generator/tests/ty_test.py index ac1f5879..8a1f4443 100644 --- a/generator/tests/ty_test.py +++ b/generator/tests/ty_test.py @@ -12,7 +12,9 @@ def test_empty_simple() -> None: def test_simple() -> None: - assert Type.from_str("foo") == Type(namespace=[], name="foo", bare=True, generic_ref=False, generic_arg=None) + assert Type.from_str("foo") == Type( + namespace=[], name="foo", bare=True, generic_ref=False, generic_arg=None + ) @mark.parametrize("ty", [".", "..", ".foo", "foo.", "foo..foo", ".foo."])