diff --git a/client/doc/roles/tl.py b/client/doc/roles/tl.py index f48c0a18..ac6cd515 100644 --- a/client/doc/roles/tl.py +++ b/client/doc/roles/tl.py @@ -15,9 +15,7 @@ def make_link_node(rawtext, app, name, options): base += "/" set_classes(options) - node = nodes.reference( - rawtext, utils.unescape(name), refuri="{}?q={}".format(base, name), **options - ) + node = nodes.reference(rawtext, utils.unescape(name), refuri="{}?q={}".format(base, name), **options) return node diff --git a/client/pyproject.toml b/client/pyproject.toml index 613ddf22..917e54bf 100644 --- a/client/pyproject.toml +++ b/client/pyproject.toml @@ -28,10 +28,8 @@ dynamic = ["version"] [project.optional-dependencies] cryptg = ["cryptg~=0.4"] dev = [ - "isort~=5.12", - "black~=23.3.0", - "mypy~=1.3", - "ruff~=0.0.292", + "mypy~=1.11.2", + "ruff~=0.6.8", "pytest~=7.3", "pytest-asyncio~=0.21", ] @@ -55,6 +53,10 @@ backend-path = ["build_backend"] version = {attr = "telethon.version.__version__"} [tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["F", "E", "W", "I"] ignore = [ "E501", # formatter takes care of lines that are too long besides documentation ] diff --git a/client/src/telethon/_impl/client/client/auth.py b/client/src/telethon/_impl/client/client/auth.py index 7220a4ec..2dfd05ba 100644 --- a/client/src/telethon/_impl/client/client/auth.py +++ b/client/src/telethon/_impl/client/client/auth.py @@ -31,9 +31,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User: assert isinstance(auth, types.auth.Authorization) assert isinstance(auth.user, types.User) user = User._from_raw(auth.user) - client._session.user = SessionUser( - id=user.id, dc=client._sender.dc_id, bot=user.bot, username=user.username - ) + client._session.user = SessionUser(id=user.id, dc=client._sender.dc_id, bot=user.bot, username=user.username) client._chat_hashes.set_self_user(user.id, user.bot) @@ -56,9 +54,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User: async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: assert dc_id is not None - sender, client._session.dcs = await connect_sender( - client._config, client._session.dcs, DataCenter(id=dc_id) - ) + sender, client._session.dcs = await connect_sender(client._config, client._session.dcs, DataCenter(id=dc_id)) async with client._sender_lock: client._sender = sender @@ -177,9 +173,7 @@ async def interactive_login( user = await self.check_password(user_or_token, password) else: while True: - print( - "Please enter your password (prompt is hidden; type and press enter)" - ) + print("Please enter your password (prompt is hidden; type and press enter)") password = getpass.getpass(": ") try: user = await self.check_password(user_or_token, password) @@ -208,13 +202,9 @@ async def get_password_information(client: Client) -> PasswordToken: return PasswordToken._new(result) -async def check_password( - self: Client, token: PasswordToken, password: str | bytes -) -> User: +async def check_password(self: Client, token: PasswordToken, password: str | bytes) -> User: algo = token._password.current_algo - if not isinstance( - algo, types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow - ): + if not isinstance(algo, types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow): raise RuntimeError("unrecognised 2FA algorithm") if not two_factor_auth.check_p_and_g(algo.p, algo.g): diff --git a/client/src/telethon/_impl/client/client/bots.py b/client/src/telethon/_impl/client/client/bots.py index 10aa79d8..5464e836 100644 --- a/client/src/telethon/_impl/client/client/bots.py +++ b/client/src/telethon/_impl/client/client/bots.py @@ -38,11 +38,7 @@ class InlineResults(metaclass=NoPublicConstructor): result = await self._client( functions.messages.get_inline_bot_results( bot=self._bot, - peer=( - self._peer._to_input_peer() - if self._peer - else types.InputPeerEmpty() - ), + peer=(self._peer._to_input_peer() if self._peer else types.InputPeerEmpty()), geo_point=None, query=self._query, offset=self._offset, @@ -51,12 +47,8 @@ class InlineResults(metaclass=NoPublicConstructor): assert isinstance(result, types.messages.BotResults) self._offset = result.next_offset for r in reversed(result.results): - assert isinstance( - r, (types.BotInlineMediaResult, types.BotInlineResult) - ) - self._buffer.append( - InlineResult._create(self._client, result, r, self._peer) - ) + assert isinstance(r, (types.BotInlineMediaResult, types.BotInlineResult)) + self._buffer.append(InlineResult._create(self._client, result, r, self._peer)) if not self._buffer: self._offset = None diff --git a/client/src/telethon/_impl/client/client/chats.py b/client/src/telethon/_impl/client/client/chats.py index 9c507709..900d3b3a 100644 --- a/client/src/telethon/_impl/client/client/chats.py +++ b/client/src/telethon/_impl/client/client/chats.py @@ -53,9 +53,7 @@ class ParticipantList(AsyncList[Participant]): seen_count = len(self._seen) for p in chanp.participants: - part = Participant._from_raw_channel( - self._client, self._peer, p, chat_map - ) + part = Participant._from_raw_channel(self._client, self._peer, p, chat_map) pid = part._peer_id() if pid not in self._seen: self._seen.add(pid) @@ -66,9 +64,7 @@ class ParticipantList(AsyncList[Participant]): self._done = len(self._seen) == seen_count else: - chatp = await self._client( - functions.messages.get_full_chat(chat_id=self._peer._to_input_chat()) - ) + chatp = await self._client(functions.messages.get_full_chat(chat_id=self._peer._to_input_chat())) assert isinstance(chatp, types.messages.ChatFull) assert isinstance(chatp.full_chat, types.ChatFull) @@ -87,17 +83,14 @@ class ParticipantList(AsyncList[Participant]): ) elif isinstance(participants, types.ChatParticipants): self._buffer.extend( - Participant._from_raw_chat(self._client, self._peer, p, chat_map) - for p in participants.participants + Participant._from_raw_chat(self._client, self._peer, p, chat_map) for p in participants.participants ) self._total = len(self._buffer) self._done = True -def get_participants( - self: Client, chat: Group | Channel | GroupRef | ChannelRef, / -) -> AsyncList[Participant]: +def get_participants(self: Client, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[Participant]: return ParticipantList(self, chat._ref) @@ -137,9 +130,7 @@ class RecentActionList(AsyncList[RecentAction]): self._offset = min(e.id for e in self._buffer) -def get_admin_log( - self: Client, chat: Group | Channel | GroupRef | ChannelRef, / -) -> AsyncList[RecentAction]: +def get_admin_log(self: Client, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[RecentAction]: return RecentActionList(self, chat._ref) @@ -174,11 +165,7 @@ class ProfilePhotoList(AsyncList[File]): else: raise RuntimeError("unexpected case") - self._buffer.extend( - filter( - None, (File._try_from_raw_photo(self._client, p) for p in photos) - ) - ) + self._buffer.extend(filter(None, (File._try_from_raw_photo(self._client, p) for p in photos))) def get_profile_photos(self: Client, peer: Peer | PeerRef, /) -> AsyncList[File]: @@ -256,11 +243,7 @@ async def set_chat_default_restrictions( *, until: Optional[datetime.datetime] = None, ) -> None: - banned_rights = ChatRestriction._set_to_raw( - set(restrictions), int(until.timestamp()) if until else 0x7FFFFFFF - ) + banned_rights = ChatRestriction._set_to_raw(set(restrictions), int(until.timestamp()) if until else 0x7FFFFFFF) await self( - functions.messages.edit_chat_default_banned_rights( - peer=chat._ref._to_input_peer(), banned_rights=banned_rights - ) + functions.messages.edit_chat_default_banned_rights(peer=chat._ref._to_input_peer(), banned_rights=banned_rights) ) diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index fd9cd1f9..90a3f835 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -237,9 +237,7 @@ class Client: lang_code=lang_code or "en", catch_up=catch_up or False, datacenter=datacenter, - flood_sleep_threshold=( - 60 if flood_sleep_threshold is None else flood_sleep_threshold - ), + flood_sleep_threshold=(60 if flood_sleep_threshold is None else flood_sleep_threshold), update_queue_limit=update_queue_limit, base_logger=base_logger, connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), @@ -250,8 +248,8 @@ class Client: self._message_box = MessageBox(base_logger=base_logger) self._chat_hashes = ChatHashCache(None) self._last_update_limit_warn: Optional[float] = None - self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = ( - asyncio.Queue(maxsize=self._config.update_queue_limit or 0) + self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = asyncio.Queue( + maxsize=self._config.update_queue_limit or 0 ) self._dispatcher: Optional[asyncio.Task[None]] = None self._handlers: dict[ @@ -411,9 +409,7 @@ class Client: """ await delete_dialog(self, dialog) - async def delete_messages( - self, chat: Peer | PeerRef, /, message_ids: list[int], *, revoke: bool = True - ) -> int: + async def delete_messages(self, chat: Peer | PeerRef, /, message_ids: list[int], *, revoke: bool = True) -> int: """ Delete messages. @@ -653,9 +649,7 @@ class Client: """ return await forward_messages(self, target, message_ids, source) - def get_admin_log( - self, chat: Group | Channel | GroupRef | ChannelRef, / - ) -> AsyncList[RecentAction]: + def get_admin_log(self, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[RecentAction]: """ Get the recent actions from the administrator's log. @@ -757,9 +751,7 @@ class Client: """ return get_file_bytes(self, media) - def get_handler_filter( - self, handler: Callable[[Event], Awaitable[Any]], / - ) -> Optional[FilterType]: + def get_handler_filter(self, handler: Callable[[Event], Awaitable[Any]], /) -> Optional[FilterType]: """ Get the filter associated to the given event handler. @@ -861,13 +853,9 @@ class Client: async for message in reversed(client.get_messages(chat)): print(message.sender.name, ':', message.markdown_text) """ - return get_messages( - self, chat, limit, offset_id=offset_id, offset_date=offset_date - ) + return get_messages(self, chat, limit, offset_id=offset_id, offset_date=offset_date) - def get_messages_with_ids( - self, chat: Peer | PeerRef, /, message_ids: list[int] - ) -> AsyncList[Message]: + def get_messages_with_ids(self, chat: Peer | PeerRef, /, message_ids: list[int]) -> AsyncList[Message]: """ Get the full message objects from the corresponding message identifiers. @@ -891,9 +879,7 @@ class Client: """ return get_messages_with_ids(self, chat, message_ids) - def get_participants( - self, chat: Group | Channel | GroupRef | ChannelRef, / - ) -> AsyncList[Participant]: + def get_participants(self, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[Participant]: """ Get the participants in a group or channel, along with their permissions. @@ -981,9 +967,7 @@ class Client: """ return await inline_query(self, bot, query, peer=peer) - async def interactive_login( - self, phone_or_token: Optional[str] = None, *, password: Optional[str] = None - ) -> User: + async def interactive_login(self, phone_or_token: Optional[str] = None, *, password: Optional[str] = None) -> User: """ Begin an interactive login if needed. If the account was already logged-in, this method simply returns :term:`yourself`. @@ -1035,9 +1019,7 @@ class Client: def on( self, event_cls: Type[Event], /, filter: Optional[FilterType] = None - ) -> Callable[ - [Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]] - ]: + ) -> Callable[[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]]: """ Register the decorated function to be invoked when the provided event type occurs. @@ -1125,9 +1107,7 @@ class Client: """ return prepare_album(self) - async def read_message( - self, chat: Peer | PeerRef, /, message_id: int | Literal["all"] - ) -> None: + async def read_message(self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None: """ Mark messages as read. @@ -1158,9 +1138,7 @@ class Client: """ await read_message(self, chat, message_id) - def remove_event_handler( - self, handler: Callable[[Event], Awaitable[Any]], / - ) -> None: + def remove_event_handler(self, handler: Callable[[Event], Awaitable[Any]], /) -> None: """ Remove the handler as a function to be called when events occur. This is simply the opposite of :meth:`add_event_handler`. @@ -1324,9 +1302,7 @@ class Client: async for message in client.search_all_messages(query='hello'): print(message.text) """ - return search_all_messages( - self, limit, query=query, offset_id=offset_id, offset_date=offset_date - ) + return search_all_messages(self, limit, query=query, offset_id=offset_id, offset_date=offset_date) def search_messages( self, @@ -1372,9 +1348,7 @@ class Client: async for message in client.search_messages(chat, query='hello'): print(message.text) """ - return search_messages( - self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date - ) + return search_messages(self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date) async def send_audio( self, @@ -1975,9 +1949,7 @@ class Client: :meth:`telethon.types.Participant.set_restrictions` """ - await set_participant_restrictions( - self, chat, participant, restrictions, until=until - ) + await set_participant_restrictions(self, chat, participant, restrictions, until=until) async def sign_in(self, token: LoginToken, code: str) -> User | PasswordToken: """ @@ -2024,9 +1996,7 @@ class Client: """ await sign_out(self) - async def unpin_message( - self, chat: Peer | PeerRef, /, message_id: int | Literal["all"] - ) -> None: + async def unpin_message(self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None: """ Unpin one or all messages from the top. diff --git a/client/src/telethon/_impl/client/client/dialogs.py b/client/src/telethon/_impl/client/client/dialogs.py index 2b3cf9f3..1dba8752 100644 --- a/client/src/telethon/_impl/client/client/dialogs.py +++ b/client/src/telethon/_impl/client/client/dialogs.py @@ -51,9 +51,7 @@ class DialogList(AsyncList[Dialog]): chat_map = build_chat_map(self._client, result.users, result.chats) msg_map = build_msg_map(self._client, result.messages, chat_map) - self._buffer.extend( - Dialog._from_raw(self._client, d, chat_map, msg_map) for d in result.dialogs - ) + self._buffer.extend(Dialog._from_raw(self._client, d, chat_map, msg_map) for d in result.dialogs) def get_dialogs(self: Client) -> AsyncList[Dialog]: @@ -130,9 +128,7 @@ async def edit_draft( reply_to: Optional[int] = None, ) -> Draft: peer = peer._ref - message, entities = parse_message( - text=text, markdown=markdown, html=html, allow_empty=False - ) + message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False) result = await self( functions.messages.save_draft( diff --git a/client/src/telethon/_impl/client/client/files.py b/client/src/telethon/_impl/client/client/files.py index a0ec66ed..a52c9c23 100644 --- a/client/src/telethon/_impl/client/client/files.py +++ b/client/src/telethon/_impl/client/client/files.py @@ -189,15 +189,11 @@ async def send_file( reply_to: Optional[int] = None, keyboard: Optional[KeyboardType] = None, ) -> Message: - message, entities = parse_message( - text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True - ) + message, entities = parse_message(text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True) # Re-send existing file. if isinstance(file, File): - return await do_send_file( - self, chat, file._input_media, message, entities, reply_to, keyboard - ) + return await do_send_file(self, chat, file._input_media, message, entities, reply_to, keyboard) # URLs are handled early as they can't use any other attributes either. input_media: abcs.InputMedia @@ -212,16 +208,10 @@ async def send_file( else: as_photo = False if as_photo: - input_media = types.InputMediaPhotoExternal( - spoiler=False, url=file, ttl_seconds=None - ) + input_media = types.InputMediaPhotoExternal(spoiler=False, url=file, ttl_seconds=None) else: - input_media = types.InputMediaDocumentExternal( - spoiler=False, url=file, ttl_seconds=None - ) - return await do_send_file( - self, chat, input_media, message, entities, reply_to, keyboard - ) + input_media = types.InputMediaDocumentExternal(spoiler=False, url=file, ttl_seconds=None) + return await do_send_file(self, chat, input_media, message, entities, reply_to, keyboard) input_file, name = await upload(self, file, size, name) @@ -285,9 +275,7 @@ async def send_file( ttl_seconds=None, ) - return await do_send_file( - self, chat, input_media, message, entities, reply_to, keyboard - ) + return await do_send_file(self, chat, input_media, message, entities, reply_to, keyboard) async def do_send_file( @@ -309,11 +297,7 @@ async def do_send_file( noforwards=False, update_stickersets_order=False, peer=chat._ref._to_input_peer(), - reply_to=( - types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) - if reply_to - else None - ), + reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None), media=input_media, message=message, random_id=random_id, @@ -395,11 +379,7 @@ async def do_upload( ) ) else: - await client( - functions.upload.save_file_part( - file_id=file_id, file_part=part, bytes=to_store - ) - ) + await client(functions.upload.save_file_part(file_id=file_id, file_part=part, bytes=to_store)) hash_md5.update(to_store) buffer.clear() @@ -458,9 +438,7 @@ def get_file_bytes(self: Client, media: File, /) -> AsyncList[bytes]: return FileBytesList(self, media) -async def download( - self: Client, media: File, /, file: str | Path | OutFileLike -) -> None: +async def download(self: Client, media: File, /, file: str | Path | OutFileLike) -> None: fd = OutWrapper(file) try: async for chunk in get_file_bytes(self, media): diff --git a/client/src/telethon/_impl/client/client/messages.py b/client/src/telethon/_impl/client/client/messages.py index 23f038a9..cc960a66 100644 --- a/client/src/telethon/_impl/client/client/messages.py +++ b/client/src/telethon/_impl/client/client/messages.py @@ -46,9 +46,7 @@ async def send_message( update_stickersets_order=False, peer=chat._ref._to_input_peer(), reply_to=( - types.InputReplyToMessage( - reply_to_msg_id=text.replied_message_id, top_msg_id=None - ) + types.InputReplyToMessage(reply_to_msg_id=text.replied_message_id, top_msg_id=None) if text.replied_message_id else None ), @@ -60,9 +58,7 @@ async def send_message( send_as=None, ) else: - message, entities = parse_message( - text=text, markdown=markdown, html=html, allow_empty=False - ) + message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False) request = functions.messages.send_message( no_webpage=not link_preview, silent=False, @@ -71,11 +67,7 @@ async def send_message( noforwards=False, update_stickersets_order=False, peer=chat._ref._to_input_peer(), - reply_to=( - types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) - if reply_to - else None - ), + reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None), message=message, random_id=random_id, reply_markup=keyboard._raw if keyboard else None, @@ -91,11 +83,7 @@ async def send_message( {}, out=result.out, id=result.id, - from_id=( - types.PeerUser(user_id=self._session.user.id) - if self._session.user - else None - ), + from_id=(types.PeerUser(user_id=self._session.user.id) if self._session.user else None), peer_id=chat._ref._to_peer(), reply_to=( types.MessageReplyHeader( @@ -130,9 +118,7 @@ async def edit_message( link_preview: bool = False, keyboard: Optional[KeyboardType] = None, ) -> Message: - message, entities = parse_message( - text=text, markdown=markdown, html=html, allow_empty=False - ) + message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False) return self._build_message_map( await self( functions.messages.edit_message( @@ -160,15 +146,9 @@ async def delete_messages( ) -> int: peer = chat._ref if isinstance(peer, ChannelRef): - affected = await self( - functions.channels.delete_messages( - channel=peer._to_input_channel(), id=message_ids - ) - ) + affected = await self(functions.channels.delete_messages(channel=peer._to_input_channel(), id=message_ids)) else: - affected = await self( - functions.messages.delete_messages(revoke=revoke, id=message_ids) - ) + affected = await self(functions.messages.delete_messages(revoke=revoke, id=message_ids)) assert isinstance(affected, types.messages.AffectedMessages) return affected.pts_count @@ -205,9 +185,7 @@ class MessageList(AsyncList[Message]): super().__init__() self._reversed = False - def _extend_buffer( - self, client: Client, messages: abcs.messages.Messages - ) -> dict[int, Peer]: + def _extend_buffer(self, client: Client, messages: abcs.messages.Messages) -> dict[int, Peer]: if isinstance(messages, types.messages.MessagesNotModified): self._total = messages.count return {} @@ -215,9 +193,7 @@ class MessageList(AsyncList[Message]): if isinstance(messages, types.messages.Messages): self._total = len(messages.messages) self._done = True - elif isinstance( - messages, (types.messages.MessagesSlice, types.messages.ChannelMessages) - ): + elif isinstance(messages, (types.messages.MessagesSlice, types.messages.ChannelMessages)): self._total = messages.count else: raise RuntimeError("unexpected case") @@ -225,9 +201,7 @@ class MessageList(AsyncList[Message]): chat_map = build_chat_map(client, messages.users, messages.chats) self._buffer.extend( Message._from_raw(client, m, chat_map) - for m in ( - reversed(messages.messages) if self._reversed else messages.messages - ) + for m in (reversed(messages.messages) if self._reversed else messages.messages) ) return chat_map @@ -235,11 +209,7 @@ class MessageList(AsyncList[Message]): self, ) -> types.Message | types.MessageService | types.MessageEmpty: return next( - ( - m._raw - for m in reversed(self._buffer) - if not isinstance(m._raw, types.MessageEmpty) - ), + (m._raw for m in reversed(self._buffer) if not isinstance(m._raw, types.MessageEmpty)), types.MessageEmpty(id=0, peer_id=None), ) @@ -335,14 +305,10 @@ class CherryPickedList(MessageList): if isinstance(self._peer, ChannelRef): result = await self._client( - functions.channels.get_messages( - channel=self._peer._to_input_channel(), id=self._ids[:100] - ) + functions.channels.get_messages(channel=self._peer._to_input_channel(), id=self._ids[:100]) ) else: - result = await self._client( - functions.messages.get_messages(id=self._ids[:100]) - ) + result = await self._client(functions.messages.get_messages(id=self._ids[:100])) self._extend_buffer(self._client, result) self._ids = self._ids[100:] @@ -492,9 +458,7 @@ def search_all_messages( ) -async def pin_message( - self: Client, chat: Peer | PeerRef, /, message_id: int -) -> Message: +async def pin_message(self: Client, chat: Peer | PeerRef, /, message_id: int) -> Message: return self._build_message_map( await self( functions.messages.update_pinned_message( @@ -509,9 +473,7 @@ async def pin_message( ).get_single() -async def unpin_message( - self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"] -) -> None: +async def unpin_message(self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None: if message_id == "all": await self( functions.messages.unpin_all_messages( @@ -531,25 +493,15 @@ async def unpin_message( ) -async def read_message( - self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"] -) -> None: +async def read_message(self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None: if message_id == "all": message_id = 0 peer = chat._ref if isinstance(peer, ChannelRef): - await self( - functions.channels.read_history( - channel=peer._to_input_channel(), max_id=message_id - ) - ) + await self(functions.channels.read_history(channel=peer._to_input_channel(), max_id=message_id)) else: - await self( - functions.messages.read_history( - peer=peer._ref._to_input_peer(), max_id=message_id - ) - ) + await self(functions.messages.read_history(peer=peer._ref._to_input_peer(), max_id=message_id)) class MessageMap: @@ -584,9 +536,7 @@ class MessageMap: def _empty(self, id: int = 0) -> Message: return Message._from_raw( self._client, - types.MessageEmpty( - id=id, peer_id=self._peer._to_peer() if self._peer else None - ), + types.MessageEmpty(id=id, peer_id=self._peer._to_peer() if self._peer else None), {}, ) diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 23634cdb..c3452737 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -82,16 +82,9 @@ async def connect_sender( # Only the ID of the input DC may be known. # Find the corresponding address and authentication key if needed. addr = dc.ipv4_addr or next( - d.ipv4_addr - for d in itertools.chain(known_dcs, KNOWN_DCS) - if d.id == dc.id and d.ipv4_addr - ) - auth = ( - None - if force_auth_gen - else dc.auth - or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None)) + d.ipv4_addr for d in itertools.chain(known_dcs, KNOWN_DCS) if d.id == dc.id and d.ipv4_addr ) + auth = None if force_auth_gen else dc.auth or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None)) sender = await do_connect_sender( Full(), @@ -122,9 +115,7 @@ async def connect_sender( ) except BadStatus as e: if e.status == 404 and auth: - dc = DataCenter( - id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None - ) + dc = DataCenter(id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None) config.base_logger.warning( "datacenter could not find stored auth; will retry generating a new one: %s", dc, @@ -174,12 +165,8 @@ async def connect(self: Client) -> None: if session := await self._storage.load(): self._session = session - datacenter = self._config.datacenter or DataCenter( - id=self._session.user.dc if self._session.user else DEFAULT_DC - ) - self._sender, self._session.dcs = await connect_sender( - self._config, self._session.dcs, datacenter - ) + datacenter = self._config.datacenter or DataCenter(id=self._session.user.dc if self._session.user else DEFAULT_DC) + self._sender, self._session.dcs = await connect_sender(self._config, self._session.dcs, datacenter) if self._message_box.is_empty() and self._session.user: try: @@ -193,9 +180,7 @@ async def connect(self: Client) -> None: me = await self.get_me() assert me is not None self._chat_hashes.set_self_user(me.id, me.bot) - self._session.user = SessionUser( - id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username - ) + self._session.user = SessionUser(id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username) self._dispatcher = asyncio.create_task(dispatcher(self)) @@ -214,18 +199,14 @@ async def disconnect(self: Client) -> None: except asyncio.CancelledError: pass except Exception: - self._config.base_logger.exception( - "unhandled exception when cancelling dispatcher; this is a bug" - ) + self._config.base_logger.exception("unhandled exception when cancelling dispatcher; this is a bug") finally: self._dispatcher = None try: await sender.disconnect() except Exception: - self._config.base_logger.exception( - "unhandled exception during disconnect; this is a bug" - ) + self._config.base_logger.exception("unhandled exception during disconnect; this is a bug") try: if self._session.user: diff --git a/client/src/telethon/_impl/client/client/updates.py b/client/src/telethon/_impl/client/client/updates.py index 2ef036a4..8852c028 100644 --- a/client/src/telethon/_impl/client/client/updates.py +++ b/client/src/telethon/_impl/client/client/updates.py @@ -40,9 +40,7 @@ def add_event_handler( self._handlers.setdefault(event_cls, []).append((handler, filter)) -def remove_event_handler( - self: Client, handler: Callable[[Event], Awaitable[Any]], / -) -> None: +def remove_event_handler(self: Client, handler: Callable[[Event], Awaitable[Any]], /) -> None: for event_cls, handlers in tuple(self._handlers.items()): for i in reversed(range(len(handlers))): if handlers[i][0] == handler: @@ -51,9 +49,7 @@ def remove_event_handler( del self._handlers[event_cls] -def get_handler_filter( - self: Client, handler: Callable[[Event], Awaitable[Any]], / -) -> Optional[FilterType]: +def get_handler_filter(self: Client, handler: Callable[[Event], Awaitable[Any]], /) -> Optional[FilterType]: for handlers in self._handlers.values(): for h, f in handlers: if h == handler: @@ -84,9 +80,7 @@ def process_socket_updates(client: Client, all_updates: list[abcs.Updates]) -> N return try: - result, users, chats = client._message_box.process_updates( - updates, client._chat_hashes - ) + result, users, chats = client._message_box.process_updates(updates, client._chat_hashes) except Gap: return @@ -107,8 +101,7 @@ def extend_update_queue( except asyncio.QueueFull: now = asyncio.get_running_loop().time() if client._last_update_limit_warn is None or ( - now - client._last_update_limit_warn - > UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN + now - client._last_update_limit_warn > UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN ): client._config.base_logger.warning( "updates are being dropped because limit=%d has been reached", diff --git a/client/src/telethon/_impl/client/client/users.py b/client/src/telethon/_impl/client/client/users.py index 345f05d7..63fe3f3b 100644 --- a/client/src/telethon/_impl/client/client/users.py +++ b/client/src/telethon/_impl/client/client/users.py @@ -53,15 +53,11 @@ def resolved_peer_to_chat(client: Client, resolved: abcs.contacts.ResolvedPeer) async def resolve_phone(self: Client, phone: str, /) -> Peer: - return resolved_peer_to_chat( - self, await self(functions.contacts.resolve_phone(phone=phone)) - ) + return resolved_peer_to_chat(self, await self(functions.contacts.resolve_phone(phone=phone))) async def resolve_username(self: Client, username: str, /) -> Peer: - return resolved_peer_to_chat( - self, await self(functions.contacts.resolve_username(username=username)) - ) + return resolved_peer_to_chat(self, await self(functions.contacts.resolve_username(username=username))) async def resolve_peers(self: Client, peers: Sequence[Peer | PeerRef], /) -> list[Peer]: @@ -99,8 +95,4 @@ async def resolve_peers(self: Client, peers: Sequence[Peer | PeerRef], /) -> lis chats.extend(ret_chats.chats) chat_map = build_chat_map(self, users, chats) - return [ - chat_map.get(ref.identifier) - or expand_peer(self, ref._to_peer(), broadcast=None) - for ref in refs - ] + return [chat_map.get(ref.identifier) or expand_peer(self, ref._to_peer(), broadcast=None) for ref in refs] diff --git a/client/src/telethon/_impl/client/errors.py b/client/src/telethon/_impl/client/errors.py index 746c4bbd..43040121 100644 --- a/client/src/telethon/_impl/client/errors.py +++ b/client/src/telethon/_impl/client/errors.py @@ -34,17 +34,13 @@ def from_name(name: str, *, _cache: dict[str, Type[RpcError]] = {}) -> Type[RpcE return _cache[name] -def adapt_rpc( - error: RpcError, *, _cache: dict[tuple[int, str], Type[RpcError]] = {} -) -> RpcError: +def adapt_rpc(error: RpcError, *, _cache: dict[tuple[int, str], Type[RpcError]] = {}) -> RpcError: code = canonicalize_code(error.code) name = canonicalize_name(error.name) tup = code, name if tup not in _cache: _cache[tup] = type(pretty_name(name), (from_code(code), from_name(name)), {}) - return _cache[tup]( - code=error.code, name=error.name, value=error.value, caused_by=error._caused_by - ) + return _cache[tup](code=error.code, name=error.name, value=error.value, caused_by=error._caused_by) class ErrorFactory: @@ -55,9 +51,7 @@ class ErrorFactory: return from_code(int(m[1])) else: adapted = adapt_user_name(name) - if pretty_name(canonicalize_name(adapted)) != name or re.match( - r"[A-Z]{2}", name - ): + if pretty_name(canonicalize_name(adapted)) != name or re.match(r"[A-Z]{2}", name): raise AttributeError(f"error subclass names must be CamelCase: {name}") return from_name(adapted) diff --git a/client/src/telethon/_impl/client/events/event.py b/client/src/telethon/_impl/client/events/event.py index 8fcc2dbf..e44c906d 100644 --- a/client/src/telethon/_impl/client/events/event.py +++ b/client/src/telethon/_impl/client/events/event.py @@ -24,9 +24,7 @@ class Event(abc.ABC, metaclass=NoPublicConstructor): @classmethod @abc.abstractmethod - def _try_from_update( - cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] - ) -> Optional[Self]: + def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: pass @@ -52,9 +50,7 @@ class Raw(Event): self._chat_map = chat_map @classmethod - def _try_from_update( - cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] - ) -> Optional[Self]: + def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: return cls._create(client, update, chat_map) diff --git a/client/src/telethon/_impl/client/events/filters/combinators.py b/client/src/telethon/_impl/client/events/filters/combinators.py index 5bce9950..2e41fded 100644 --- a/client/src/telethon/_impl/client/events/filters/combinators.py +++ b/client/src/telethon/_impl/client/events/filters/combinators.py @@ -65,9 +65,7 @@ class Any(Combinable): __slots__ = ("_filters",) - def __init__( - self, filter1: FilterType, filter2: FilterType, *filters: FilterType - ) -> None: + def __init__(self, filter1: FilterType, filter2: FilterType, *filters: FilterType) -> None: self._filters = (filter1, filter2, *filters) @property @@ -111,9 +109,7 @@ class All(Combinable): __slots__ = ("_filters",) - def __init__( - self, filter1: FilterType, filter2: FilterType, *filters: FilterType - ) -> None: + def __init__(self, filter1: FilterType, filter2: FilterType, *filters: FilterType) -> None: self._filters = (filter1, filter2, *filters) @property diff --git a/client/src/telethon/_impl/client/events/messages.py b/client/src/telethon/_impl/client/events/messages.py index 65cf42de..335db355 100644 --- a/client/src/telethon/_impl/client/events/messages.py +++ b/client/src/telethon/_impl/client/events/messages.py @@ -24,15 +24,11 @@ class NewMessage(Event, Message): """ @classmethod - def _try_from_update( - cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] - ) -> Optional[Self]: + def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: if isinstance(update, (types.UpdateNewMessage, types.UpdateNewChannelMessage)): if isinstance(update.message, types.Message): return cls._from_raw(client, update.message, chat_map) - elif isinstance( - update, (types.UpdateShortMessage, types.UpdateShortChatMessage) - ): + elif isinstance(update, (types.UpdateShortMessage, types.UpdateShortChatMessage)): raise RuntimeError("should have been handled by adaptor") return None @@ -46,12 +42,8 @@ class MessageEdited(Event, Message): """ @classmethod - def _try_from_update( - cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] - ) -> Optional[Self]: - if isinstance( - update, (types.UpdateEditMessage, types.UpdateEditChannelMessage) - ): + def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: + if isinstance(update, (types.UpdateEditMessage, types.UpdateEditChannelMessage)): return cls._from_raw(client, update.message, chat_map) else: return None @@ -74,9 +66,7 @@ class MessageDeleted(Event): self._channel_id = channel_id @classmethod - def _try_from_update( - cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] - ) -> Optional[Self]: + def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: if isinstance(update, types.UpdateDeleteMessages): return cls._create(update.messages, None) elif isinstance(update, types.UpdateDeleteChannelMessages): @@ -122,9 +112,7 @@ class MessageRead(Event): self._chat_map = chat_map @classmethod - def _try_from_update( - cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] - ) -> Optional[Self]: + def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: if isinstance( update, ( @@ -139,9 +127,7 @@ class MessageRead(Event): return None def _peer(self) -> abcs.Peer: - if isinstance( - self._raw, (types.UpdateReadHistoryInbox, types.UpdateReadHistoryOutbox) - ): + if isinstance(self._raw, (types.UpdateReadHistoryInbox, types.UpdateReadHistoryOutbox)): return self._raw.peer else: return types.PeerChannel(channel_id=self._raw.channel_id) @@ -154,9 +140,7 @@ class MessageRead(Event): peer = self._peer() pid = peer_id(peer) if pid not in self._chat_map: - self._chat_map[pid] = expand_peer( - self._client, peer, broadcast=getattr(self._raw, "post", None) - ) + self._chat_map[pid] = expand_peer(self._client, peer, broadcast=getattr(self._raw, "post", None)) return self._chat_map[pid] @property diff --git a/client/src/telethon/_impl/client/events/queries.py b/client/src/telethon/_impl/client/events/queries.py index 276c35c6..bc9d6a9c 100644 --- a/client/src/telethon/_impl/client/events/queries.py +++ b/client/src/telethon/_impl/client/events/queries.py @@ -31,9 +31,7 @@ class ButtonCallback(Event): self._chat_map = chat_map @classmethod - def _try_from_update( - cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] - ) -> Optional[Self]: + def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: if isinstance(update, types.UpdateBotCallbackQuery) and update.data is not None: return cls._create(client, update, chat_map) else: @@ -83,11 +81,7 @@ class ButtonCallback(Event): chat = self._chat_map.get(pid) or PeerRef._empty_from_peer(self._raw.peer) lst = CherryPickedList(self._client, chat._ref, []) - lst._ids.append( - types.InputMessageCallbackQuery( - id=self._raw.msg_id, query_id=self._raw.query_id - ) - ) + lst._ids.append(types.InputMessageCallbackQuery(id=self._raw.msg_id, query_id=self._raw.query_id)) message = (await lst)[0] @@ -105,9 +99,7 @@ class InlineQuery(Event): self._raw = update @classmethod - def _try_from_update( - cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] - ) -> Optional[Self]: + def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]: if isinstance(update, types.UpdateBotInlineQuery): return cls._create(update) else: diff --git a/client/src/telethon/_impl/client/parsers/html.py b/client/src/telethon/_impl/client/parsers/html.py index cafa3d3d..05239414 100644 --- a/client/src/telethon/_impl/client/parsers/html.py +++ b/client/src/telethon/_impl/client/parsers/html.py @@ -133,9 +133,7 @@ def parse(html: str) -> tuple[str, list[MessageEntity]]: return del_surrogate(parser.text), parser.entities -ENTITY_TO_FORMATTER: dict[ - Type[MessageEntity], tuple[str, str] | Callable[[Any, str], tuple[str, str]] -] = { +ENTITY_TO_FORMATTER: dict[Type[MessageEntity], tuple[str, str] | Callable[[Any, str], tuple[str, str]]] = { MessageEntityBold: ("", ""), MessageEntityItalic: ("", ""), MessageEntityCode: ("", ""), @@ -196,12 +194,7 @@ def unparse(text: str, entities: Iterable[MessageEntity]) -> str: while within_surrogate(text, at): at += 1 - text = ( - text[:at] - + what - + escape(text[at:next_escape_bound]) - + text[next_escape_bound:] - ) + text = text[:at] + what + escape(text[at:next_escape_bound]) + text[next_escape_bound:] next_escape_bound = at text = escape(text[:next_escape_bound]) + text[next_escape_bound:] diff --git a/client/src/telethon/_impl/client/parsers/markdown.py b/client/src/telethon/_impl/client/parsers/markdown.py index 1704d4fb..0531f026 100644 --- a/client/src/telethon/_impl/client/parsers/markdown.py +++ b/client/src/telethon/_impl/client/parsers/markdown.py @@ -82,9 +82,7 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]: else: for entity in reversed(entities): if isinstance(entity, ty): - setattr( - entity, "length", len(message) - getattr(entity, "offset", 0) - ) + setattr(entity, "length", len(message) - getattr(entity, "offset", 0)) break parsed = MARKDOWN.parse(add_surrogate(message.strip())) @@ -103,25 +101,15 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]: if token.type in ("blockquote_close", "blockquote_open"): push(MessageEntityBlockquote) elif token.type == "code_block": - entities.append( - MessageEntityPre( - offset=len(message), length=len(token.content), language="" - ) - ) + entities.append(MessageEntityPre(offset=len(message), length=len(token.content), language="")) message += token.content elif token.type == "code_inline": - entities.append( - MessageEntityCode(offset=len(message), length=len(token.content)) - ) + entities.append(MessageEntityCode(offset=len(message), length=len(token.content))) message += token.content elif token.type in ("em_close", "em_open"): push(MessageEntityItalic) elif token.type == "fence": - entities.append( - MessageEntityPre( - offset=len(message), length=len(token.content), language=token.info - ) - ) + entities.append(MessageEntityPre(offset=len(message), length=len(token.content), language=token.info)) message += token.content[:-1] # remove a single trailing newline elif token.type == "hardbreak": message += "\n" @@ -130,9 +118,7 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]: elif token.type == "hr": message += "\u2015\n\n" elif token.type in ("link_close", "link_open"): - if ( - token.markup != "autolink" - ): # telegram already picks up on these automatically + if token.markup != "autolink": # telegram already picks up on these automatically push(MessageEntityTextUrl, url=token.attrs.get("href")) elif token.type in ("s_close", "s_open"): push(MessageEntityStrike) diff --git a/client/src/telethon/_impl/client/parsers/strings.py b/client/src/telethon/_impl/client/parsers/strings.py index 0fd69059..1fa9a454 100644 --- a/client/src/telethon/_impl/client/parsers/strings.py +++ b/client/src/telethon/_impl/client/parsers/strings.py @@ -6,11 +6,7 @@ def add_surrogate(text: str) -> str: return "".join( # SMP -> Surrogate Pairs (Telegram offsets are calculated with these). # See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more. - ( - "".join(chr(y) for y in struct.unpack(" list[Message]: + async def send(self, peer: Peer | PeerRef, *, reply_to: Optional[int] = None) -> list[Message]: """ Send the album. @@ -225,11 +205,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor): update_stickersets_order=False, peer=peer._ref._to_input_peer(), reply_to=( - types.InputReplyToMessage( - reply_to_msg_id=reply_to, top_msg_id=None - ) - if reply_to - else None + types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None ), multi_media=self._medias, schedule_date=None, diff --git a/client/src/telethon/_impl/client/types/buttons/button.py b/client/src/telethon/_impl/client/types/buttons/button.py index e28b4582..5632d293 100644 --- a/client/src/telethon/_impl/client/types/buttons/button.py +++ b/client/src/telethon/_impl/client/types/buttons/button.py @@ -50,9 +50,7 @@ class Button(abc.ABC): def __init__(self, text: str) -> None: if self.__class__ == Button: - raise TypeError( - f"Can't instantiate abstract class {self.__class__.__name__}" - ) + raise TypeError(f"Can't instantiate abstract class {self.__class__.__name__}") self._raw: RawButtonType = types.KeyboardButton(text=text) self._msg: Optional[weakref.ReferenceType[Message]] = None diff --git a/client/src/telethon/_impl/client/types/buttons/inline_button.py b/client/src/telethon/_impl/client/types/buttons/inline_button.py index 803269f1..1d4cfd1e 100644 --- a/client/src/telethon/_impl/client/types/buttons/inline_button.py +++ b/client/src/telethon/_impl/client/types/buttons/inline_button.py @@ -29,8 +29,6 @@ class InlineButton(Button, abc.ABC): def __init__(self, text: str) -> None: if self.__class__ == InlineButton: - raise TypeError( - f"Can't instantiate abstract class {self.__class__.__name__}" - ) + raise TypeError(f"Can't instantiate abstract class {self.__class__.__name__}") else: super().__init__(text) diff --git a/client/src/telethon/_impl/client/types/buttons/switch_inline.py b/client/src/telethon/_impl/client/types/buttons/switch_inline.py index 977b35c8..0e0842a9 100644 --- a/client/src/telethon/_impl/client/types/buttons/switch_inline.py +++ b/client/src/telethon/_impl/client/types/buttons/switch_inline.py @@ -14,9 +14,7 @@ class SwitchInline(InlineButton): def __init__(self, text: str, query: Optional[str] = None) -> None: super().__init__(text) - self._raw = types.KeyboardButtonSwitchInline( - same_peer=False, text=text, query=query or "", peer_types=None - ) + self._raw = types.KeyboardButtonSwitchInline(same_peer=False, text=text, query=query or "", peer_types=None) @property def query(self) -> str: diff --git a/client/src/telethon/_impl/client/types/chat_restriction.py b/client/src/telethon/_impl/client/types/chat_restriction.py index 3c330f91..c8f4e8fa 100644 --- a/client/src/telethon/_impl/client/types/chat_restriction.py +++ b/client/src/telethon/_impl/client/types/chat_restriction.py @@ -111,9 +111,7 @@ class ChatRestriction(Enum): return set(filter(None, iter(restrictions))) @classmethod - def _set_to_raw( - cls, restrictions: set[ChatRestriction], until_date: int - ) -> types.ChatBannedRights: + def _set_to_raw(cls, restrictions: set[ChatRestriction], until_date: int) -> types.ChatBannedRights: return types.ChatBannedRights( view_messages=cls.VIEW_MESSAGES in restrictions, send_messages=cls.SEND_MESSAGES in restrictions, diff --git a/client/src/telethon/_impl/client/types/dialog.py b/client/src/telethon/_impl/client/types/dialog.py index 673fbf78..52192a29 100644 --- a/client/src/telethon/_impl/client/types/dialog.py +++ b/client/src/telethon/_impl/client/types/dialog.py @@ -90,9 +90,6 @@ class Dialog(metaclass=NoPublicConstructor): if isinstance(self._raw, types.Dialog): return self._raw.unread_count elif isinstance(self._raw, types.DialogPeerFolder): - return ( - self._raw.unread_unmuted_messages_count - + self._raw.unread_muted_messages_count - ) + return self._raw.unread_unmuted_messages_count + self._raw.unread_muted_messages_count else: raise RuntimeError("unexpected case") diff --git a/client/src/telethon/_impl/client/types/draft.py b/client/src/telethon/_impl/client/types/draft.py index 9368044d..3086497b 100644 --- a/client/src/telethon/_impl/client/types/draft.py +++ b/client/src/telethon/_impl/client/types/draft.py @@ -37,9 +37,7 @@ class Draft(metaclass=NoPublicConstructor): self._chat_map = chat_map @classmethod - def _from_raw_update( - cls, client: Client, draft: types.UpdateDraftMessage, chat_map: dict[int, Peer] - ) -> Self: + def _from_raw_update(cls, client: Client, draft: types.UpdateDraftMessage, chat_map: dict[int, Peer]) -> Self: return cls._create(client, draft.peer, draft.top_msg_id, draft.draft, chat_map) @classmethod @@ -60,9 +58,7 @@ class Draft(metaclass=NoPublicConstructor): This is also the chat where the message will be sent to by :meth:`send`. """ - return self._chat_map.get(peer_id(self._peer)) or expand_peer( - self._client, self._peer, broadcast=None - ) + return self._chat_map.get(peer_id(self._peer)) or expand_peer(self._client, self._peer, broadcast=None) @property def link_preview(self) -> bool: @@ -91,9 +87,7 @@ class Draft(metaclass=NoPublicConstructor): The :attr:`~Message.text_html` of the message that will be sent. """ if text := getattr(self._raw, "message", None): - return generate_html_message( - text, getattr(self._raw, "entities", None) or [] - ) + return generate_html_message(text, getattr(self._raw, "entities", None) or []) else: return None @@ -103,9 +97,7 @@ class Draft(metaclass=NoPublicConstructor): The :attr:`~Message.text_markdown` of the message that will be sent. """ if text := getattr(self._raw, "message", None): - return generate_markdown_message( - text, getattr(self._raw, "entities", None) or [] - ) + return generate_markdown_message(text, getattr(self._raw, "entities", None) or []) else: return None @@ -115,11 +107,7 @@ class Draft(metaclass=NoPublicConstructor): The date when the draft was last updated. """ date = getattr(self._raw, "date", None) - return ( - datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) - if date is not None - else None - ) + return datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) if date is not None else None async def edit( self, @@ -192,11 +180,7 @@ class Draft(metaclass=NoPublicConstructor): noforwards=False, update_stickersets_order=False, peer=self._peer_ref()._to_input_peer(), - reply_to=( - types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) - if reply_to - else None - ), + reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None), message=message, random_id=random_id, reply_markup=None, @@ -211,11 +195,7 @@ class Draft(metaclass=NoPublicConstructor): {}, out=result.out, id=result.id, - from_id=( - types.PeerUser(user_id=self._client._session.user.id) - if self._client._session.user - else None - ), + from_id=(types.PeerUser(user_id=self._client._session.user.id) if self._client._session.user else None), peer_id=self._peer_ref()._to_peer(), reply_to=( types.MessageReplyHeader( @@ -235,9 +215,7 @@ class Draft(metaclass=NoPublicConstructor): ttl_period=result.ttl_period, ) else: - return self._client._build_message_map( - result, self._peer_ref() - ).with_random_id(random_id) + return self._client._build_message_map(result, self._peer_ref()).with_random_id(random_id) async def delete(self) -> None: """ diff --git a/client/src/telethon/_impl/client/types/file.py b/client/src/telethon/_impl/client/types/file.py index 1e5e018f..76c9ce83 100644 --- a/client/src/telethon/_impl/client/types/file.py +++ b/client/src/telethon/_impl/client/types/file.py @@ -29,11 +29,7 @@ def photo_size_byte_count(size: abcs.PhotoSize) -> int: elif isinstance(size, types.PhotoSizeProgressive): return max(size.sizes) elif isinstance(size, types.PhotoStrippedSize): - return ( - len(stripped_size_header) - + (len(size.bytes) - 3) - + len(stripped_size_footer) - ) + return len(stripped_size_header) + (len(size.bytes) - 3) + len(stripped_size_footer) else: raise RuntimeError("unexpected case") @@ -180,9 +176,7 @@ class File(metaclass=NoPublicConstructor): self._client = client @classmethod - def _try_from_raw_message_media( - cls, client: Client, raw: abcs.MessageMedia - ) -> Optional[Self]: + def _try_from_raw_message_media(cls, client: Client, raw: abcs.MessageMedia) -> Optional[Self]: if isinstance(raw, types.MessageMediaDocument): if raw.document: return cls._try_from_raw_document( @@ -204,13 +198,9 @@ class File(metaclass=NoPublicConstructor): elif isinstance(raw, types.MessageMediaWebPage): if isinstance(raw.webpage, types.WebPage): if raw.webpage.document: - return cls._try_from_raw_document( - client, raw.webpage.document, orig_raw=raw - ) + return cls._try_from_raw_document(client, raw.webpage.document, orig_raw=raw) if raw.webpage.photo: - return cls._try_from_raw_photo( - client, raw.webpage.photo, orig_raw=raw - ) + return cls._try_from_raw_photo(client, raw.webpage.photo, orig_raw=raw) return None @@ -229,21 +219,13 @@ class File(metaclass=NoPublicConstructor): attributes=raw.attributes, size=raw.size, name=next( - ( - a.file_name - for a in raw.attributes - if isinstance(a, types.DocumentAttributeFilename) - ), + (a.file_name for a in raw.attributes if isinstance(a, types.DocumentAttributeFilename)), "", ), mime=raw.mime_type, photo=False, muted=next( - ( - a.nosound - for a in raw.attributes - if isinstance(a, types.DocumentAttributeVideo) - ), + (a.nosound for a in raw.attributes if isinstance(a, types.DocumentAttributeVideo)), False, ), input_media=types.InputMediaDocument( @@ -361,9 +343,7 @@ class File(metaclass=NoPublicConstructor): return dim.w for attr in self._attributes: - if isinstance( - attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo) - ): + if isinstance(attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)): return attr.w return None @@ -377,9 +357,7 @@ class File(metaclass=NoPublicConstructor): return dim.h for attr in self._attributes: - if isinstance( - attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo) - ): + if isinstance(attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)): return attr.h return None @@ -410,9 +388,7 @@ class File(metaclass=NoPublicConstructor): id=self._input_media.id.id, access_hash=self._input_media.id.access_hash, file_reference=self._input_media.id.file_reference, - thumb_size=( - self._thumb.type if isinstance(self._thumb, thumb_types) else "" - ), + thumb_size=(self._thumb.type if isinstance(self._thumb, thumb_types) else ""), ) elif isinstance(self._input_media, types.InputMediaPhoto): assert isinstance(self._input_media.id, types.InputPhoto) diff --git a/client/src/telethon/_impl/client/types/keyboard.py b/client/src/telethon/_impl/client/types/keyboard.py index d560463c..b9fc1b7a 100644 --- a/client/src/telethon/_impl/client/types/keyboard.py +++ b/client/src/telethon/_impl/client/types/keyboard.py @@ -12,16 +12,11 @@ def _build_keyboard_rows( ) -> list[abcs.KeyboardButtonRow]: # list[button] -> list[list[button]] # This does allow for "invalid" inputs (mixing lists and non-lists), but that's acceptable. - buttons_lists_iter = [ - button if isinstance(button, list) else [button] for button in (btns or []) - ] + buttons_lists_iter = [button if isinstance(button, list) else [button] for button in (btns or [])] # Remove empty rows (also making it easy to check if all-empty). buttons_lists = [bs for bs in buttons_lists_iter if bs] - return [ - types.KeyboardButtonRow(buttons=[btn._raw for btn in btns]) - for btns in buttons_lists - ] + return [types.KeyboardButtonRow(buttons=[btn._raw for btn in btns]) for btns in buttons_lists] class Keyboard: @@ -49,9 +44,7 @@ class Keyboard: class InlineKeyboard: __slots__ = ("_raw",) - def __init__( - self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]] - ) -> None: + def __init__(self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]]) -> None: self._raw = types.ReplyInlineMarkup(rows=_build_keyboard_rows(buttons)) diff --git a/client/src/telethon/_impl/client/types/message.py b/client/src/telethon/_impl/client/types/message.py index 4d37263d..08101ec9 100644 --- a/client/src/telethon/_impl/client/types/message.py +++ b/client/src/telethon/_impl/client/types/message.py @@ -34,11 +34,7 @@ def generate_random_id() -> int: def adapt_date(date: Optional[int]) -> Optional[datetime.datetime]: - return ( - datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) - if date is not None - else None - ) + return datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) if date is not None else None class Message(metaclass=NoPublicConstructor): @@ -59,20 +55,14 @@ class Message(metaclass=NoPublicConstructor): print('Found empty message with ID', message.id) """ - def __init__( - self, client: Client, message: abcs.Message, chat_map: dict[int, Peer] - ) -> None: - assert isinstance( - message, (types.Message, types.MessageService, types.MessageEmpty) - ) + def __init__(self, client: Client, message: abcs.Message, chat_map: dict[int, Peer]) -> None: + assert isinstance(message, (types.Message, types.MessageService, types.MessageEmpty)) self._client = client self._raw = message self._chat_map = chat_map @classmethod - def _from_raw( - cls, client: Client, message: abcs.Message, chat_map: dict[int, Peer] - ) -> Self: + def _from_raw(cls, client: Client, message: abcs.Message, chat_map: dict[int, Peer]) -> Self: return cls._create(client, message, chat_map) @classmethod @@ -158,9 +148,7 @@ class Message(metaclass=NoPublicConstructor): See :ref:`formatting` to learn the HTML elements used. """ if text := getattr(self._raw, "message", None): - return generate_html_message( - text, getattr(self._raw, "entities", None) or [] - ) + return generate_html_message(text, getattr(self._raw, "entities", None) or []) else: return None @@ -172,9 +160,7 @@ class Message(metaclass=NoPublicConstructor): See :ref:`formatting` to learn the formatting characters used. """ if text := getattr(self._raw, "message", None): - return generate_markdown_message( - text, getattr(self._raw, "entities", None) or [] - ) + return generate_markdown_message(text, getattr(self._raw, "entities", None) or []) else: return None @@ -193,9 +179,7 @@ class Message(metaclass=NoPublicConstructor): peer = self._raw.peer_id or types.PeerUser(user_id=0) pid = peer_id(peer) if pid not in self._chat_map: - self._chat_map[pid] = expand_peer( - self._client, peer, broadcast=getattr(self._raw, "post", None) - ) + self._chat_map[pid] = expand_peer(self._client, peer, broadcast=getattr(self._raw, "post", None)) return self._chat_map[pid] @property @@ -239,14 +223,7 @@ class Message(metaclass=NoPublicConstructor): This can also be used as a way to check that the message media is an audio. """ audio = self._file() - return ( - audio - if audio - and any( - isinstance(a, types.DocumentAttributeAudio) for a in audio._attributes - ) - else None - ) + return audio if audio and any(isinstance(a, types.DocumentAttributeAudio) for a in audio._attributes) else None @property def video(self) -> Optional[File]: @@ -256,14 +233,7 @@ class Message(metaclass=NoPublicConstructor): This can also be used as a way to check that the message media is a video. """ audio = self._file() - return ( - audio - if audio - and any( - isinstance(a, types.DocumentAttributeVideo) for a in audio._attributes - ) - else None - ) + return audio if audio and any(isinstance(a, types.DocumentAttributeVideo) for a in audio._attributes) else None @property def file(self) -> Optional[File]: @@ -477,10 +447,7 @@ class Message(metaclass=NoPublicConstructor): return None return [ - [ - create_button(self, button) - for button in cast(types.KeyboardButtonRow, row).buttons - ] + [create_button(self, button) for button in cast(types.KeyboardButtonRow, row).buttons] for row in markup.rows ] @@ -506,13 +473,8 @@ class Message(metaclass=NoPublicConstructor): return not isinstance(self._raw, types.MessageEmpty) -def build_msg_map( - client: Client, messages: Sequence[abcs.Message], chat_map: dict[int, Peer] -) -> dict[int, Message]: - return { - msg.id: msg - for msg in (Message._from_raw(client, m, chat_map) for m in messages) - } +def build_msg_map(client: Client, messages: Sequence[abcs.Message], chat_map: dict[int, Peer]) -> dict[int, Message]: + return {msg.id: msg for msg in (Message._from_raw(client, m, chat_map) for m in messages)} def parse_message( diff --git a/client/src/telethon/_impl/client/types/meta.py b/client/src/telethon/_impl/client/types/meta.py index b7de4f15..b812de6c 100644 --- a/client/src/telethon/_impl/client/types/meta.py +++ b/client/src/telethon/_impl/client/types/meta.py @@ -16,23 +16,16 @@ class Final(abc.ABCMeta): cls_namespace: dict[str, object], ) -> "Final": # Allow subclassing while within telethon._impl (or other package names). - allowed_base = Final.__module__[ - : Final.__module__.find(".", Final.__module__.find(".") + 1) - ] + allowed_base = Final.__module__[: Final.__module__.find(".", Final.__module__.find(".") + 1)] for base in bases: if isinstance(base, Final) and not base.__module__.startswith(allowed_base): - raise TypeError( - f"{base.__module__}.{base.__qualname__} does not support" - " subclassing" - ) + raise TypeError(f"{base.__module__}.{base.__qualname__} does not support" " subclassing") return super().__new__(cls, name, bases, cls_namespace) class NoPublicConstructor(Final): def __call__(cls, *args: Any, **kwds: Any) -> Any: - raise TypeError( - f"{cls.__module__}.{cls.__qualname__} has no public constructor" - ) + raise TypeError(f"{cls.__module__}.{cls.__qualname__} has no public constructor") @property def _create(cls: Type[T]) -> Type[T]: diff --git a/client/src/telethon/_impl/client/types/participant.py b/client/src/telethon/_impl/client/types/participant.py index ce7937b5..62689a5d 100644 --- a/client/src/telethon/_impl/client/types/participant.py +++ b/client/src/telethon/_impl/client/types/participant.py @@ -157,22 +157,16 @@ class Participant(metaclass=NoPublicConstructor): """ :data:`True` if the participant is the creator of the chat. """ - return isinstance( - self._raw, (types.ChannelParticipantCreator, types.ChatParticipantCreator) - ) + return isinstance(self._raw, (types.ChannelParticipantCreator, types.ChatParticipantCreator)) @property def admin_rights(self) -> Optional[set[AdminRight]]: """ The set of administrator rights this participant has been granted, if they are an administrator. """ - if isinstance( - self._raw, (types.ChannelParticipantCreator, types.ChannelParticipantAdmin) - ): + if isinstance(self._raw, (types.ChannelParticipantCreator, types.ChannelParticipantAdmin)): return AdminRight._from_raw(self._raw.admin_rights) - elif isinstance( - self._raw, (types.ChatParticipantCreator, types.ChatParticipantAdmin) - ): + elif isinstance(self._raw, (types.ChatParticipantCreator, types.ChatParticipantAdmin)): return AdminRight._chat_rights() else: return None @@ -194,13 +188,9 @@ class Participant(metaclass=NoPublicConstructor): participant = self.user or self.banned or self.left assert participant if isinstance(participant, User): - await self._client.set_participant_admin_rights( - self._chat, participant, rights - ) + await self._client.set_participant_admin_rights(self._chat, participant, rights) else: - raise TypeError( - f"participant of type {participant.__class__.__name__} cannot be made admin" - ) + raise TypeError(f"participant of type {participant.__class__.__name__} cannot be made admin") async def set_restrictions( self, @@ -213,6 +203,4 @@ class Participant(metaclass=NoPublicConstructor): """ participant = self.user or self.banned or self.left assert participant - await self._client.set_participant_restrictions( - self._chat, participant, restrictions, until=until - ) + await self._client.set_participant_restrictions(self._chat, participant, restrictions, until=until) diff --git a/client/src/telethon/_impl/client/types/peer/__init__.py b/client/src/telethon/_impl/client/types/peer/__init__.py index 91c5d6ba..582a8304 100644 --- a/client/src/telethon/_impl/client/types/peer/__init__.py +++ b/client/src/telethon/_impl/client/types/peer/__init__.py @@ -15,9 +15,7 @@ if TYPE_CHECKING: from ...client.client import Client -def build_chat_map( - client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat] -) -> dict[int, Peer]: +def build_chat_map(client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]) -> dict[int, Peer]: users_iter = (User._from_raw(u) for u in users) chats_iter = ( ( @@ -45,9 +43,7 @@ def build_chat_map( for x in v: print(x, file=sys.stderr) - raise RuntimeError( - f"chat identifier collision: {k}; please report this" - ) + raise RuntimeError(f"chat identifier collision: {k}; please report this") return result @@ -81,11 +77,7 @@ def expand_peer(client: Client, peer: abcs.Peer, *, broadcast: Optional[bool]) - until_date=None, ) - return ( - Channel._from_raw(channel) - if broadcast - else Group._from_raw(client, channel) - ) + return Channel._from_raw(channel) if broadcast else Group._from_raw(client, channel) else: raise RuntimeError("unexpected case") diff --git a/client/src/telethon/_impl/client/types/peer/group.py b/client/src/telethon/_impl/client/types/peer/group.py index 5dce7ef3..6fd8b04d 100644 --- a/client/src/telethon/_impl/client/types/peer/group.py +++ b/client/src/telethon/_impl/client/types/peer/group.py @@ -24,13 +24,7 @@ class Group(Peer, metaclass=NoPublicConstructor): def __init__( self, client: Client, - chat: ( - types.ChatEmpty - | types.Chat - | types.ChatForbidden - | types.Channel - | types.ChannelForbidden - ), + chat: (types.ChatEmpty | types.Chat | types.ChatForbidden | types.Channel | types.ChannelForbidden), ) -> None: self._client = client self._raw = chat @@ -96,6 +90,4 @@ class Group(Peer, metaclass=NoPublicConstructor): """ Alias for :meth:`telethon.Client.set_chat_default_restrictions`. """ - await self._client.set_chat_default_restrictions( - self, restrictions, until=until - ) + await self._client.set_chat_default_restrictions(self, restrictions, until=until) diff --git a/client/src/telethon/_impl/crypto/aes.py b/client/src/telethon/_impl/crypto/aes.py index 60e74964..41355f81 100644 --- a/client/src/telethon/_impl/crypto/aes.py +++ b/client/src/telethon/_impl/crypto/aes.py @@ -1,16 +1,10 @@ try: - import cryptg + import cryptg # type: ignore [import-untyped] - def ige_encrypt( - plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes - ) -> bytes: # noqa: F811 - return cryptg.encrypt_ige( - bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv - ) + def ige_encrypt(plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: # noqa: F811 + return cryptg.encrypt_ige(bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv) - def ige_decrypt( - ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes - ) -> bytes: # noqa: F811 + def ige_decrypt(ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: # noqa: F811 return cryptg.decrypt_ige( bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext, key, @@ -18,11 +12,9 @@ try: ) except ImportError: - import pyaes + import pyaes # type: ignore [import-untyped] - def ige_encrypt( - plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes - ) -> bytes: + def ige_encrypt(plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: assert len(plaintext) % 16 == 0 assert len(iv) == 32 @@ -35,10 +27,7 @@ except ImportError: for block_offset in range(0, len(plaintext), 16): plaintext_block = plaintext[block_offset : block_offset + 16] ciphertext_block = bytes( - a ^ b - for a, b in zip( - aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2 - ) + a ^ b for a, b in zip(aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2) ) iv1 = ciphertext_block iv2 = plaintext_block @@ -47,9 +36,7 @@ except ImportError: return bytes(ciphertext) - def ige_decrypt( - ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes - ) -> bytes: + def ige_decrypt(ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: assert len(ciphertext) % 16 == 0 assert len(iv) == 32 @@ -62,10 +49,7 @@ except ImportError: for block_offset in range(0, len(ciphertext), 16): ciphertext_block = ciphertext[block_offset : block_offset + 16] plaintext_block = bytes( - a ^ b - for a, b in zip( - aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1 - ) + a ^ b for a, b in zip(aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1) ) iv1 = ciphertext_block iv2 = plaintext_block diff --git a/client/src/telethon/_impl/crypto/auth_key.py b/client/src/telethon/_impl/crypto/auth_key.py index f2c2cdc7..11a4ab84 100644 --- a/client/src/telethon/_impl/crypto/auth_key.py +++ b/client/src/telethon/_impl/crypto/auth_key.py @@ -20,8 +20,4 @@ class AuthKey: return self.data def calc_new_nonce_hash(self, new_nonce: int, number: int) -> int: - return int.from_bytes( - sha1(new_nonce.to_bytes(32) + number.to_bytes(1) + self.aux_hash).digest()[ - 4: - ] - ) + return int.from_bytes(sha1(new_nonce.to_bytes(32) + number.to_bytes(1) + self.aux_hash).digest()[4:]) diff --git a/client/src/telethon/_impl/crypto/crypto.py b/client/src/telethon/_impl/crypto/crypto.py index 17fd788e..d08308b7 100644 --- a/client/src/telethon/_impl/crypto/crypto.py +++ b/client/src/telethon/_impl/crypto/crypto.py @@ -19,9 +19,7 @@ class CalcKey(NamedTuple): # https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector -def calc_key( - auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side -) -> CalcKey: +def calc_key(auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side) -> CalcKey: x = int(side) # sha256_a = SHA256 (msg_key + substr (auth_key, x, 36)) @@ -43,12 +41,8 @@ def determine_padding_v2_length(length: int) -> int: return 16 + (16 - (length % 16)) -def _do_encrypt_data_v2( - plaintext: bytes, auth_key: AuthKey, random_padding: bytes -) -> bytes: - padded_plaintext = ( - plaintext + random_padding[: determine_padding_v2_length(len(plaintext))] - ) +def _do_encrypt_data_v2(plaintext: bytes, auth_key: AuthKey, random_padding: bytes) -> bytes: + padded_plaintext = plaintext + random_padding[: determine_padding_v2_length(len(plaintext))] side = Side.CLIENT x = int(side) @@ -70,9 +64,7 @@ def encrypt_data_v2(plaintext: bytes, auth_key: AuthKey) -> bytes: return _do_encrypt_data_v2(plaintext, auth_key, random_padding) -def decrypt_data_v2( - ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey -) -> bytes: +def decrypt_data_v2(ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey) -> bytes: side = Side.SERVER x = int(side) diff --git a/client/src/telethon/_impl/crypto/rsa.py b/client/src/telethon/_impl/crypto/rsa.py index 72338887..c125bc60 100644 --- a/client/src/telethon/_impl/crypto/rsa.py +++ b/client/src/telethon/_impl/crypto/rsa.py @@ -34,17 +34,13 @@ def encrypt_hashed(data: bytes, key: PublicKey, random_bytes: bytes) -> bytes: temp_key = random_bytes[192 + 32 * attempt : 192 + 32 * attempt + 32] # data_with_hash := data_pad_reversed + SHA256(temp_key + data_with_padding); -- after this assignment, data_with_hash is exactly 224 bytes long. - data_with_hash = ( - data_pad_reversed + sha256(temp_key + data_with_padding).digest() - ) + data_with_hash = data_pad_reversed + sha256(temp_key + data_with_padding).digest() # aes_encrypted := AES256_IGE(data_with_hash, temp_key, 0); -- AES256-IGE encryption with zero IV. aes_encrypted = ige_encrypt(data_with_hash, temp_key, bytes(32)) # temp_key_xor := temp_key XOR SHA256(aes_encrypted); -- adjusted key, 32 bytes - temp_key_xor = bytes( - a ^ b for a, b in zip(temp_key, sha256(aes_encrypted).digest()) - ) + temp_key_xor = bytes(a ^ b for a, b in zip(temp_key, sha256(aes_encrypted).digest())) # key_aes_encrypted := temp_key_xor + aes_encrypted; -- exactly 256 bytes (2048 bits) long key_aes_encrypted = temp_key_xor + aes_encrypted @@ -87,6 +83,4 @@ j4WcDuXc2CTHgH8gFTNhp/Y8/SpDOhvn9QIDAQAB ) -RSA_KEYS = { - compute_fingerprint(key): key for key in (PRODUCTION_RSA_KEY, TESTMODE_RSA_KEY) -} +RSA_KEYS = {compute_fingerprint(key): key for key in (PRODUCTION_RSA_KEY, TESTMODE_RSA_KEY)} diff --git a/client/src/telethon/_impl/crypto/two_factor_auth.py b/client/src/telethon/_impl/crypto/two_factor_auth.py index 384ca5a9..4a05d1aa 100644 --- a/client/src/telethon/_impl/crypto/two_factor_auth.py +++ b/client/src/telethon/_impl/crypto/two_factor_auth.py @@ -20,9 +20,7 @@ def h(*data: bytes | bytearray | memoryview) -> bytes: # SH(data, salt) := H(salt | data | salt) -def sh( - data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview -) -> bytes: +def sh(data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview) -> bytes: return h(salt, data, salt) diff --git a/client/src/telethon/_impl/mtproto/authentication.py b/client/src/telethon/_impl/mtproto/authentication.py index 1b20dc47..a5902f04 100644 --- a/client/src/telethon/_impl/mtproto/authentication.py +++ b/client/src/telethon/_impl/mtproto/authentication.py @@ -108,13 +108,9 @@ def _do_step2(data: Step1, response: bytes, random_bytes: bytes) -> tuple[bytes, ) try: - fingerprint = next( - fp for fp in res_pq.server_public_key_fingerprints if fp in RSA_KEYS - ) + fingerprint = next(fp for fp in res_pq.server_public_key_fingerprints if fp in RSA_KEYS) except StopIteration: - raise ValueError( - f"unknown fingerprints: {res_pq.server_public_key_fingerprints}" - ) + raise ValueError(f"unknown fingerprints: {res_pq.server_public_key_fingerprints}") key = RSA_KEYS[fingerprint] ciphertext = encrypt_hashed(pq_inner_data, key, random_bytes) @@ -133,9 +129,7 @@ def step2(data: Step1, response: bytes) -> tuple[bytes, Step2]: return _do_step2(data, response, os.urandom(288)) -def _do_step3( - data: Step2, response: bytes, random_bytes: bytes, now: int -) -> tuple[bytes, Step3]: +def _do_step3(data: Step2, response: bytes, random_bytes: bytes, now: int) -> tuple[bytes, Step3]: assert len(random_bytes) == 272 nonce = data.nonce @@ -158,9 +152,7 @@ def _do_step3( check_server_nonce(server_dh_params.server_nonce, server_nonce) if len(server_dh_params.encrypted_answer) % 16 != 0: - raise ValueError( - f"encrypted response not padded with size: {len(server_dh_params.encrypted_answer)}" - ) + raise ValueError(f"encrypted response not padded with size: {len(server_dh_params.encrypted_answer)}") key, iv = generate_key_data_from_nonce(server_nonce, new_nonce) assert isinstance(server_dh_params.encrypted_answer, bytes) @@ -172,9 +164,7 @@ def _do_step3( server_dh_inner = AbcServerDhInnerData._read_from(plain_text_reader) assert isinstance(server_dh_inner, ServerDhInnerData) - expected_answer_hash = sha1( - plain_text_answer[20 : 20 + plain_text_reader._pos] - ).digest() + expected_answer_hash = sha1(plain_text_answer[20 : 20 + plain_text_reader._pos]).digest() if got_answer_hash != expected_answer_hash: raise ValueError("invalid answer hash") @@ -213,15 +203,11 @@ def _do_step3( ) client_dh_inner_hashed = sha1(client_dh_inner).digest() + client_dh_inner - client_dh_inner_hashed += random_bytes[ - : (16 - (len(client_dh_inner_hashed) % 16)) % 16 - ] + client_dh_inner_hashed += random_bytes[: (16 - (len(client_dh_inner_hashed) % 16)) % 16] client_dh_encrypted = encrypt_ige(client_dh_inner_hashed, key, iv) - return set_client_dh_params( - nonce=nonce, server_nonce=server_nonce, encrypted_data=client_dh_encrypted - ), Step3( + return set_client_dh_params(nonce=nonce, server_nonce=server_nonce, encrypted_data=client_dh_encrypted), Step3( nonce=nonce, server_nonce=server_nonce, new_nonce=new_nonce, @@ -277,10 +263,7 @@ def create_key(data: Step3, response: bytes) -> CreatedKey: first_salt = struct.unpack( " None: def check_g_in_range(value: int, low: int, high: int) -> None: if not (low < value < high): - raise ValueError(f"g parameter {value} not in range({low+1}, {high})") + raise ValueError(f"g parameter {value} not in range({low + 1}, {high})") diff --git a/client/src/telethon/_impl/mtproto/mtp/encrypted.py b/client/src/telethon/_impl/mtproto/mtp/encrypted.py index 5195b950..88927056 100644 --- a/client/src/telethon/_impl/mtproto/mtp/encrypted.py +++ b/client/src/telethon/_impl/mtproto/mtp/encrypted.py @@ -107,9 +107,7 @@ class Encrypted(Mtp): ) -> None: self._auth_key = auth_key self._time_offset: int = time_offset or 0 - self._salts: list[FutureSalt] = [ - FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0) - ] + self._salts: list[FutureSalt] = [FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)] self._start_salt_time: Optional[tuple[int, float]] = None self._compression_threshold = compression_threshold self._deserialization: list[Deserialization] = [] @@ -203,9 +201,7 @@ class Encrypted(Mtp): if self._msg_count == 1: del self._buffer[:CONTAINER_HEADER_LEN] - self._buffer[:HEADER_LEN] = struct.pack( - " None: new_session = NewSessionCreated.from_bytes(message.body) self._salts.clear() - self._salts.append( - FutureSalt( - valid_since=0, valid_until=0x7FFFFFFF, salt=new_session.server_salt - ) - ) + self._salts.append(FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=new_session.server_salt)) def _handle_container(self, message: Message) -> None: container = MsgContainer.from_bytes(message.body) @@ -376,11 +362,7 @@ class Encrypted(Mtp): self._deserialization.append(Update(message.body)) def _try_request_salts(self) -> None: - if ( - len(self._salts) == 1 - and self._salt_request_msg_id is None - and self._get_current_salt() != 0 - ): + if len(self._salts) == 1 and self._salt_request_msg_id is None and self._get_current_salt() != 0: # If salts are requested in a container leading to bad_msg, # the bad_msg_id will refer to the container, not the salts request. # @@ -388,9 +370,7 @@ class Encrypted(Mtp): # This would break, because we couldn't identify the response. # # So salts are only requested once we have a valid salt to reduce the chances of this happening. - self._salt_request_msg_id = self._serialize_msg( - bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True - ) + self._salt_request_msg_id = self._serialize_msg(bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True) def push(self, request: bytes) -> Optional[MsgId]: if self._start_salt_time and len(self._salts) >= 2: @@ -435,9 +415,7 @@ class Encrypted(Mtp): return MsgId(self._last_msg_id), encrypt_data_v2(result, self._auth_key) - def deserialize( - self, payload: bytes | bytearray | memoryview - ) -> list[Deserialization]: + def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]: check_message_buffer(payload) plaintext = decrypt_data_v2(payload, self._auth_key) diff --git a/client/src/telethon/_impl/mtproto/mtp/plain.py b/client/src/telethon/_impl/mtproto/mtp/plain.py index b1fe03d2..43b93518 100644 --- a/client/src/telethon/_impl/mtproto/mtp/plain.py +++ b/client/src/telethon/_impl/mtproto/mtp/plain.py @@ -31,9 +31,7 @@ class Plain(Mtp): self._buffer.clear() return MsgId(0), result - def deserialize( - self, payload: bytes | bytearray | memoryview - ) -> list[Deserialization]: + def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]: check_message_buffer(payload) auth_key_id, msg_id, length = struct.unpack_from("= 0, got: {length}") if 20 + length > len(payload): - raise ValueError( - f"message too short, expected: {20 + length}, got {len(payload)}" - ) + raise ValueError(f"message too short, expected: {20 + length}, got {len(payload)}") return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))] diff --git a/client/src/telethon/_impl/mtproto/mtp/types.py b/client/src/telethon/_impl/mtproto/mtp/types.py index 3798d859..58f176cd 100644 --- a/client/src/telethon/_impl/mtproto/mtp/types.py +++ b/client/src/telethon/_impl/mtproto/mtp/types.py @@ -116,11 +116,7 @@ class RpcError(ValueError): def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented - return ( - self._code == other._code - and self._name == other._name - and self._value == other._value - ) + return self._code == other._code and self._name == other._name and self._value == other._value # https://core.telegram.org/mtproto/service_messages_about_messages @@ -156,9 +152,7 @@ class BadMessage(ValueError): self.msg_id = msg_id self._code = code self._caused_by = caused_by - self.severity = ( - logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR - ) + self.severity = logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR @property def code(self) -> int: @@ -201,9 +195,7 @@ class Mtp(ABC): """ @abstractmethod - def deserialize( - self, payload: bytes | bytearray | memoryview - ) -> list[Deserialization]: + def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]: """ Deserialize incoming buffer payload. """ diff --git a/client/src/telethon/_impl/mtproto/transport/intermediate.py b/client/src/telethon/_impl/mtproto/transport/intermediate.py index 3a9aebbb..7d4c0053 100644 --- a/client/src/telethon/_impl/mtproto/transport/intermediate.py +++ b/client/src/telethon/_impl/mtproto/transport/intermediate.py @@ -42,10 +42,7 @@ class Intermediate(Transport): raise MissingBytes(expected=length, got=len(input)) if length <= 4: - if ( - length >= 4 - and (status := struct.unpack("= 4 and (status := struct.unpack(" 0, got: {length}") diff --git a/client/src/telethon/_impl/mtproto/utils.py b/client/src/telethon/_impl/mtproto/utils.py index 7b56fd17..68df596c 100644 --- a/client/src/telethon/_impl/mtproto/utils.py +++ b/client/src/telethon/_impl/mtproto/utils.py @@ -12,9 +12,7 @@ MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes def check_message_buffer(message: bytes | bytearray | memoryview) -> None: if len(message) < 20: - raise ValueError( - f"server payload is too small to be a valid message: {message.hex()}" - ) + raise ValueError(f"server payload is too small to be a valid message: {message.hex()}") # https://core.telegram.org/mtproto/description#content-related-message diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 84533451..86d04fa1 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -313,13 +313,7 @@ class Sender: def _on_ping_timeout(self) -> None: ping_id = generate_random_id() - self._enqueue_body( - bytes( - ping_delay_disconnect( - ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT - ) - ) - ) + self._enqueue_body(bytes(ping_delay_disconnect(ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT))) self._next_ping = asyncio.get_running_loop().time() + PING_DELAY def _process_mtp_buffer(self, updates: list[Updates]) -> None: @@ -335,9 +329,7 @@ class Sender: else: self._process_bad_message(result) - def _process_update( - self, updates: list[Updates], update: bytes | bytearray | memoryview - ) -> None: + def _process_update(self, updates: list[Updates], update: bytes | bytearray | memoryview) -> None: try: updates.append(Updates.from_bytes(update)) except ValueError: @@ -441,9 +433,7 @@ class Sender: req.state.msg_id == msg_id or req.state.container_msg_id == msg_id ): raise RuntimeError("got response for unsent request") - elif isinstance(req.state, Sent) and ( - req.state.msg_id == msg_id or req.state.container_msg_id == msg_id - ): + elif isinstance(req.state, Sent) and (req.state.msg_id == msg_id or req.state.container_msg_id == msg_id): yield self._requests.pop(i) @property diff --git a/client/src/telethon/_impl/session/chat/hash_cache.py b/client/src/telethon/_impl/session/chat/hash_cache.py index 60117029..2c5f7fe8 100644 --- a/client/src/telethon/_impl/session/chat/hash_cache.py +++ b/client/src/telethon/_impl/session/chat/hash_cache.py @@ -65,9 +65,7 @@ class ChatHashCache: return self._has_peer(peer.peer) elif isinstance(peer, types.NotifyForumTopic): return self._has_peer(peer.peer) - elif isinstance( - peer, (types.NotifyUsers, types.NotifyChats, types.NotifyBroadcasts) - ): + elif isinstance(peer, (types.NotifyUsers, types.NotifyChats, types.NotifyBroadcasts)): return True else: raise RuntimeError("unexpected case") @@ -120,9 +118,7 @@ class ChatHashCache: elif isinstance(participant, types.ChannelParticipantAdmin): return ( self._has(participant.user_id) - and ( - participant.inviter_id is None or self._has(participant.inviter_id) - ) + and (participant.inviter_id is None or self._has(participant.inviter_id)) and self._has(participant.promoted_by) ) elif isinstance(participant, types.ChannelParticipantBanned): diff --git a/client/src/telethon/_impl/session/chat/peer_ref.py b/client/src/telethon/_impl/session/chat/peer_ref.py index 9752e1de..c1bd0070 100644 --- a/client/src/telethon/_impl/session/chat/peer_ref.py +++ b/client/src/telethon/_impl/session/chat/peer_ref.py @@ -43,12 +43,8 @@ class PeerRef(abc.ABC): __slots__ = ("identifier", "authorization") - def __init__( - self, identifier: PeerIdentifier, authorization: PeerAuth = None - ) -> None: - assert ( - identifier >= 0 - ), "PeerRef identifiers must be positive; see the documentation for Peers" + def __init__(self, identifier: PeerIdentifier, authorization: PeerAuth = None) -> None: + assert identifier >= 0, "PeerRef identifiers must be positive; see the documentation for Peers" self.identifier = identifier self.authorization = authorization @@ -83,9 +79,7 @@ class PeerRef(abc.ABC): authorization: Optional[int] = None else: try: - (authorization,) = struct.unpack( - "!q", base64.urlsafe_b64decode(auth.encode("ascii") + b"=") - ) + (authorization,) = struct.unpack("!q", base64.urlsafe_b64decode(auth.encode("ascii") + b"=")) except Exception: raise ValueError(f"invalid PeerRef string: {string!r}") @@ -137,21 +131,14 @@ class PeerRef(abc.ABC): if self.authorization is None: auth = "0" else: - auth = ( - base64.urlsafe_b64encode(struct.pack("!q", self.authorization)) - .decode("ascii") - .rstrip("=") - ) + auth = base64.urlsafe_b64encode(struct.pack("!q", self.authorization)).decode("ascii").rstrip("=") return f"{self.identifier}.{auth}" def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented - return ( - self.identifier == other.identifier - and self.authorization == other.authorization - ) + return self.identifier == other.identifier and self.authorization == other.authorization @property def _ref(self) -> UserRef | GroupRef | ChannelRef: @@ -182,16 +169,12 @@ class UserRef(PeerRef): def _to_input_peer(self) -> abcs.InputPeer: if self.identifier == SELF_USER_SENTINEL_ID: return types.InputPeerSelf() - return types.InputPeerUser( - user_id=self.identifier, access_hash=self.authorization or 0 - ) + return types.InputPeerUser(user_id=self.identifier, access_hash=self.authorization or 0) def _to_input_user(self) -> abcs.InputUser: if self.identifier == SELF_USER_SENTINEL_ID: return types.InputUserSelf() - return types.InputUser( - user_id=self.identifier, access_hash=self.authorization or 0 - ) + return types.InputUser(user_id=self.identifier, access_hash=self.authorization or 0) def __str__(self) -> str: return f"{USER_PREFIX}{self._encode_str()}" @@ -252,14 +235,10 @@ class ChannelRef(PeerRef): return types.PeerChannel(channel_id=self.identifier) def _to_input_peer(self) -> abcs.InputPeer: - return types.InputPeerChannel( - channel_id=self.identifier, access_hash=self.authorization or 0 - ) + return types.InputPeerChannel(channel_id=self.identifier, access_hash=self.authorization or 0) def _to_input_channel(self) -> types.InputChannel: - return types.InputChannel( - channel_id=self.identifier, access_hash=self.authorization or 0 - ) + return types.InputChannel(channel_id=self.identifier, access_hash=self.authorization or 0) def __str__(self) -> str: return f"{CHANNEL_PREFIX}{self._encode_str()}" diff --git a/client/src/telethon/_impl/session/message_box/adaptor.py b/client/src/telethon/_impl/session/message_box/adaptor.py index 2eafb5bd..a2560757 100644 --- a/client/src/telethon/_impl/session/message_box/adaptor.py +++ b/client/src/telethon/_impl/session/message_box/adaptor.py @@ -27,9 +27,7 @@ def update_short(short: types.UpdateShort) -> types.UpdatesCombined: ) -def update_short_message( - short: types.UpdateShortMessage, self_id: int -) -> types.UpdatesCombined: +def update_short_message(short: types.UpdateShortMessage, self_id: int) -> types.UpdatesCombined: return update_short( types.UpdateShort( update=types.UpdateNewMessage( @@ -46,9 +44,7 @@ def update_short_message( noforwards=False, reactions=None, id=short.id, - from_id=types.PeerUser( - user_id=self_id if short.out else short.user_id - ), + from_id=types.PeerUser(user_id=self_id if short.out else short.user_id), peer_id=types.PeerChat( chat_id=short.user_id, ), diff --git a/client/src/telethon/_impl/session/message_box/defs.py b/client/src/telethon/_impl/session/message_box/defs.py index 3033c921..dcfe4b71 100644 --- a/client/src/telethon/_impl/session/message_box/defs.py +++ b/client/src/telethon/_impl/session/message_box/defs.py @@ -38,9 +38,7 @@ class PtsInfo: self.entry = entry def __repr__(self) -> str: - return ( - f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})" - ) + return f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})" class State: @@ -70,9 +68,7 @@ class PossibleGap: self.updates = updates def __repr__(self) -> str: - return ( - f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})" - ) + return f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})" class PrematureEndReason(Enum): diff --git a/client/src/telethon/_impl/session/message_box/messagebox.py b/client/src/telethon/_impl/session/message_box/messagebox.py index a93a01f9..68fb9c37 100644 --- a/client/src/telethon/_impl/session/message_box/messagebox.py +++ b/client/src/telethon/_impl/session/message_box/messagebox.py @@ -91,13 +91,9 @@ class MessageBox: self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline) if state.qts != NO_SEQ: self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline) - self.map.update( - (s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels - ) + self.map.update((s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels) - self.date = datetime.datetime.fromtimestamp( - state.date, tz=datetime.timezone.utc - ) + self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc) self.seq = state.seq self.possible_gaps.clear() self.getting_diff_for.clear() @@ -136,28 +132,18 @@ class MessageBox: default_deadline = next_updates_deadline() if self.possible_gaps: - deadline = min( - default_deadline, *(gap.deadline for gap in self.possible_gaps.values()) - ) + deadline = min(default_deadline, *(gap.deadline for gap in self.possible_gaps.values())) elif self.next_deadline in self.map: deadline = min(default_deadline, self.map[self.next_deadline].deadline) else: deadline = default_deadline if now >= deadline: - self.getting_diff_for.update( - entry - for entry, gap in self.possible_gaps.items() - if now >= gap.deadline - ) - self.getting_diff_for.update( - entry for entry, state in self.map.items() if now >= state.deadline - ) + self.getting_diff_for.update(entry for entry, gap in self.possible_gaps.items() if now >= gap.deadline) + self.getting_diff_for.update(entry for entry, state in self.map.items() if now >= state.deadline) if __debug__: - self._trace( - "deadlines met, now getting diff for: %r", self.getting_diff_for - ) + self._trace("deadlines met, now getting diff for: %r", self.getting_diff_for) for entry in self.getting_diff_for: self.possible_gaps.pop(entry, None) @@ -171,19 +157,12 @@ class MessageBox: entry: Entry = ENTRY_ACCOUNT # for pyright to know it's not unbound for entry in entries: if entry not in self.map: - raise RuntimeError( - "Called reset_deadline on an entry for which we do not have state" - ) + raise RuntimeError("Called reset_deadline on an entry for which we do not have state") self.map[entry].deadline = deadline if self.next_deadline in entries: - self.next_deadline = min( - self.map.items(), key=lambda entry_state: entry_state[1].deadline - )[0] - elif ( - self.next_deadline in self.map - and deadline < self.map[self.next_deadline].deadline - ): + self.next_deadline = min(self.map.items(), key=lambda entry_state: entry_state[1].deadline)[0] + elif self.next_deadline in self.map and deadline < self.map[self.next_deadline].deadline: self.next_deadline = entry def reset_channel_deadline(self, channel_id: int, timeout: Optional[float]) -> None: @@ -200,9 +179,7 @@ class MessageBox: assert isinstance(state, types.updates.State) self.map[ENTRY_ACCOUNT] = State(state.pts, deadline) self.map[ENTRY_SECRET] = State(state.qts, deadline) - self.date = datetime.datetime.fromtimestamp( - state.date, tz=datetime.timezone.utc - ) + self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc) self.seq = state.seq def try_set_channel_state(self, id: int, pts: int) -> None: @@ -215,15 +192,11 @@ class MessageBox: def try_begin_get_diff(self, entry: Entry, reason: str) -> None: if entry not in self.map: if entry in self.possible_gaps: - raise RuntimeError( - "Should not have a possible_gap for an entry not in the state map" - ) + raise RuntimeError("Should not have a possible_gap for an entry not in the state map") return if __debug__: - self._trace( - "marking entry=%r as needing difference because: %s", entry, reason - ) + self._trace("marking entry=%r as needing difference because: %s", entry, reason) self.getting_diff_for.add(entry) self.possible_gaps.pop(entry, None) @@ -231,14 +204,10 @@ class MessageBox: try: self.getting_diff_for.remove(entry) except KeyError: - raise RuntimeError( - "Called end_get_diff on an entry which was not getting diff for" - ) + raise RuntimeError("Called end_get_diff on an entry which was not getting diff for") self.reset_deadlines({entry}, next_updates_deadline()) - assert ( - entry not in self.possible_gaps - ), "gaps shouldn't be created while getting difference" + assert entry not in self.possible_gaps, "gaps shouldn't be created while getting difference" def ensure_known_peer_hashes( self, @@ -246,10 +215,7 @@ class MessageBox: chat_hashes: ChatHashCache, ) -> None: if not chat_hashes.extend_from_updates(updates): - can_recover = ( - not isinstance(updates, types.UpdateShort) - or pts_info_from_update(updates.update) is not None - ) + can_recover = not isinstance(updates, types.UpdateShort) or pts_info_from_update(updates.update) is not None if can_recover: self.try_begin_get_diff(ENTRY_ACCOUNT, "missing hash") raise Gap @@ -275,9 +241,7 @@ class MessageBox: if combined.seq_start != NO_SEQ: if self.seq + 1 > combined.seq_start: if __debug__: - self._trace( - "skipping updates as they should have already been handled" - ) + self._trace("skipping updates as they should have already been handled") return result, combined.users, combined.chats elif self.seq + 1 < combined.seq_start: self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap") @@ -305,17 +269,13 @@ class MessageBox: if __debug__: self._trace("updating seq as local pts was updated too") if combined.date != NO_DATE: - self.date = datetime.datetime.fromtimestamp( - combined.date, tz=datetime.timezone.utc - ) + self.date = datetime.datetime.fromtimestamp(combined.date, tz=datetime.timezone.utc) if combined.seq != NO_SEQ: self.seq = combined.seq if self.possible_gaps: if __debug__: - self._trace( - "trying to re-apply count=%r possible gaps", len(self.possible_gaps) - ) + self._trace("trying to re-apply count=%r possible gaps", len(self.possible_gaps)) for key in list(self.possible_gaps.keys()): self.possible_gaps[key].updates.sort(key=update_sort_key) @@ -332,9 +292,7 @@ class MessageBox: applied, ) - self.possible_gaps = { - entry: gap for entry, gap in self.possible_gaps.items() if gap.updates - } + self.possible_gaps = {entry: gap for entry, gap in self.possible_gaps.items() if gap.updates} return result, combined.users, combined.chats @@ -384,8 +342,7 @@ class MessageBox: ) if pts.entry not in self.possible_gaps: self.possible_gaps[pts.entry] = PossibleGap( - deadline=asyncio.get_running_loop().time() - + POSSIBLE_GAP_TIMEOUT, + deadline=asyncio.get_running_loop().time() + POSSIBLE_GAP_TIMEOUT, updates=[], ) @@ -413,20 +370,14 @@ class MessageBox: for entry in (ENTRY_ACCOUNT, ENTRY_SECRET): if entry in self.getting_diff_for: if entry not in self.map: - raise RuntimeError( - "Should not try to get difference for an entry without known state" - ) + raise RuntimeError("Should not try to get difference for an entry without known state") gd = functions.updates.get_difference( pts=self.map[ENTRY_ACCOUNT].pts, pts_limit=None, pts_total_limit=None, date=int(self.date.timestamp()), - qts=( - self.map[ENTRY_SECRET].pts - if ENTRY_SECRET in self.map - else NO_SEQ - ), + qts=(self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else NO_SEQ), qts_limit=None, ) if __debug__: @@ -447,9 +398,7 @@ class MessageBox: result: tuple[list[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]] if isinstance(diff, types.updates.DifferenceEmpty): finish = True - self.date = datetime.datetime.fromtimestamp( - diff.date, tz=datetime.timezone.utc - ) + self.date = datetime.datetime.fromtimestamp(diff.date, tz=datetime.timezone.utc) self.seq = diff.seq result = [], [], [] elif isinstance(diff, types.updates.Difference): @@ -502,9 +451,7 @@ class MessageBox: assert isinstance(state, types.updates.State) self.map[ENTRY_ACCOUNT].pts = state.pts self.map[ENTRY_SECRET].pts = state.qts - self.date = datetime.datetime.fromtimestamp( - state.date, tz=datetime.timezone.utc - ) + self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc) self.seq = state.seq updates, users, chats = self.process_updates( @@ -560,19 +507,13 @@ class MessageBox: channel=channel, filter=types.ChannelMessagesFilterEmpty(), pts=state.pts, - limit=( - BOT_CHANNEL_DIFF_LIMIT - if chat_hashes.is_self_bot - else USER_CHANNEL_DIFF_LIMIT - ), + limit=(BOT_CHANNEL_DIFF_LIMIT if chat_hashes.is_self_bot else USER_CHANNEL_DIFF_LIMIT), ) if __debug__: self._trace("requesting channel difference: %s", gd) return gd else: - raise RuntimeError( - "should not try to get difference for an entry without known state" - ) + raise RuntimeError("should not try to get difference for an entry without known state") else: self.end_get_diff(entry) self.map.pop(entry, None) @@ -638,9 +579,7 @@ class MessageBox: else: raise RuntimeError("unexpected case") - def end_channel_difference( - self, channel_id: int, reason: PrematureEndReason - ) -> None: + def end_channel_difference(self, channel_id: int, reason: PrematureEndReason) -> None: entry: Entry = channel_id if __debug__: self._trace("ending channel=%r difference: %s", entry, reason) diff --git a/client/src/telethon/_impl/session/storage/sqlite.py b/client/src/telethon/_impl/session/storage/sqlite.py index 2988e336..335ee035 100644 --- a/client/src/telethon/_impl/session/storage/sqlite.py +++ b/client/src/telethon/_impl/session/storage/sqlite.py @@ -38,9 +38,7 @@ class SqliteSession(Storage): if version == 7: session = self._load_v7(c) else: - raise ValueError( - "only migration from sqlite session format 7 supported" - ) + raise ValueError("only migration from sqlite session format 7 supported") self._reset(c) self._get_or_init_version(c) @@ -105,11 +103,7 @@ class SqliteSession(Storage): DataCenter(id=id, ipv4_addr=ipv4_addr, ipv6_addr=ipv6_addr, auth=auth) for (id, ipv4_addr, ipv6_addr, auth) in datacenter ], - user=( - User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3]) - if user - else None - ), + user=(User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3]) if user else None), state=( UpdateState( pts=state[0], @@ -166,9 +160,7 @@ class SqliteSession(Storage): @staticmethod def _get_or_init_version(c: sqlite3.Cursor) -> int: - c.execute( - "select name from sqlite_master where type='table' and name='version'" - ) + c.execute("select name from sqlite_master where type='table' and name='version'") if c.fetchone(): c.execute("select version from version") tup = c.fetchone() diff --git a/client/src/telethon/_impl/tl/core/reader.py b/client/src/telethon/_impl/tl/core/reader.py index 3bb32a4c..36c8622d 100644 --- a/client/src/telethon/_impl/tl/core/reader.py +++ b/client/src/telethon/_impl/tl/core/reader.py @@ -23,9 +23,7 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]: from ..mtproto.layer import TYPE_MAPPING as MTPROTO_TYPES if API_TYPES.keys() & MTPROTO_TYPES.keys(): - raise RuntimeError( - "generated api and mtproto schemas cannot have colliding constructor identifiers" - ) + raise RuntimeError("generated api and mtproto schemas cannot have colliding constructor identifiers") ALL_TYPES = API_TYPES | MTPROTO_TYPES # Signatures don't fully match, but this is a private method @@ -39,9 +37,7 @@ class Reader: __slots__ = ("_view", "_pos", "_len") def __init__(self, buffer: "Buffer") -> None: - self._view = ( - memoryview(buffer) if not isinstance(buffer, memoryview) else buffer - ) + self._view = memoryview(buffer) if not isinstance(buffer, memoryview) else buffer self._pos = 0 self._len = len(self._view) diff --git a/client/src/telethon/_impl/tl/core/request.py b/client/src/telethon/_impl/tl/core/request.py index b053bdd4..83acc8d5 100644 --- a/client/src/telethon/_impl/tl/core/request.py +++ b/client/src/telethon/_impl/tl/core/request.py @@ -14,9 +14,7 @@ def _bootstrap_get_deserializer( from ..mtproto.layer import RESPONSE_MAPPING as MTPROTO_DESER if API_DESER.keys() & MTPROTO_DESER.keys(): - raise RuntimeError( - "generated api and mtproto schemas cannot have colliding constructor identifiers" - ) + raise RuntimeError("generated api and mtproto schemas cannot have colliding constructor identifiers") ALL_DESER = API_DESER | MTPROTO_DESER Request._get_deserializer = ALL_DESER.get # type: ignore [assignment] diff --git a/client/src/telethon/_impl/tl/core/serializable.py b/client/src/telethon/_impl/tl/core/serializable.py index 8af4e56c..9d753e4d 100644 --- a/client/src/telethon/_impl/tl/core/serializable.py +++ b/client/src/telethon/_impl/tl/core/serializable.py @@ -49,9 +49,7 @@ class Serializable(abc.ABC): def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented - return all( - getattr(self, attr) == getattr(other, attr) for attr in self.__slots__ - ) + return all(getattr(self, attr) == getattr(other, attr) for attr in self.__slots__) def serialize_bytes_to(buffer: bytearray, data: bytes | bytearray | memoryview) -> None: diff --git a/client/src/telethon/types/buttons.py b/client/src/telethon/types/buttons.py index 3c857305..b67bb83f 100644 --- a/client/src/telethon/types/buttons.py +++ b/client/src/telethon/types/buttons.py @@ -19,6 +19,7 @@ and those you can define when using :meth:`telethon.Client.send_message`: buttons.Callback('Demo', b'data') ]) """ + from .._impl.client.types.buttons import ( Callback, RequestGeoLocation, diff --git a/client/tests/auth_key_test.py b/client/tests/auth_key_test.py index 978b9554..2a6095dd 100644 --- a/client/tests/auth_key_test.py +++ b/client/tests/auth_key_test.py @@ -26,25 +26,16 @@ def test_auth_key_id() -> None: def test_calc_new_nonce_hash1() -> None: auth_key = get_auth_key() new_nonce = get_new_nonce() - assert ( - auth_key.calc_new_nonce_hash(new_nonce, 1) - == 258944117842285651226187582903746985063 - ) + assert auth_key.calc_new_nonce_hash(new_nonce, 1) == 258944117842285651226187582903746985063 def test_calc_new_nonce_hash2() -> None: auth_key = get_auth_key() new_nonce = get_new_nonce() - assert ( - auth_key.calc_new_nonce_hash(new_nonce, 2) - == 324588944215647649895949797213421233055 - ) + assert auth_key.calc_new_nonce_hash(new_nonce, 2) == 324588944215647649895949797213421233055 def test_calc_new_nonce_hash3() -> None: auth_key = get_auth_key() new_nonce = get_new_nonce() - assert ( - auth_key.calc_new_nonce_hash(new_nonce, 3) - == 100989356540453064705070297823778556733 - ) + assert auth_key.calc_new_nonce_hash(new_nonce, 3) == 100989356540453064705070297823778556733 diff --git a/client/tests/crypto_test.py b/client/tests/crypto_test.py index e94fa11d..faf92d2b 100644 --- a/client/tests/crypto_test.py +++ b/client/tests/crypto_test.py @@ -66,14 +66,8 @@ def test_key_from_nonce() -> None: new_nonce = int.from_bytes(bytes(range(32))) key, iv = generate_key_data_from_nonce(server_nonce, new_nonce) - assert ( - key - == b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6' - ) - assert ( - iv - == b"Z\x84\x10\x8e\x98\x05el\xe8d\x07\x0e\x16nb\x18\xf6x>\x85\x11G\x1aZ\xb7\x80,\xf2\x00\x01\x02\x03" - ) + assert key == b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6' + assert iv == b"Z\x84\x10\x8e\x98\x05el\xe8d\x07\x0e\x16nb\x18\xf6x>\x85\x11G\x1aZ\xb7\x80,\xf2\x00\x01\x02\x03" def test_verify_ige_encryption() -> None: @@ -88,5 +82,7 @@ def test_verify_ige_decryption() -> None: ciphertext = get_test_aes_key_or_iv() key = get_test_aes_key_or_iv() iv = get_test_aes_key_or_iv() - expected = b"\xe5wz\xfa\xcd{,\x16\xf7\xac@\xca\xe6\x1e\xf6\x03\xfe\xe6\t\x8f\xb8\xa8\x86\n\xb9\xeeg,\xd7\xe5\xba\xcc" + expected = ( + b"\xe5wz\xfa\xcd{,\x16\xf7\xac@\xca\xe6\x1e\xf6\x03\xfe\xe6\t\x8f\xb8\xa8\x86\n\xb9\xeeg,\xd7\xe5\xba\xcc" + ) assert decrypt_ige(ciphertext, key, iv) == expected diff --git a/client/tests/parsers_test.py b/client/tests/parsers_test.py index 4541306a..373ed105 100644 --- a/client/tests/parsers_test.py +++ b/client/tests/parsers_test.py @@ -32,10 +32,7 @@ def test_parse_all_entities_markdown() -> None: markdown = "Some **bold** (__strong__), *italics* (_cursive_), inline `code`, a\n```rust\npre\n```\nblock, a [link](https://example.com), and [mentions](tg://user?id=12345678)" text, entities = parse_markdown_message(markdown) - assert ( - text - == "Some bold (strong), italics (cursive), inline code, a\npre\nblock, a link, and mentions" - ) + assert text == "Some bold (strong), italics (cursive), inline code, a\npre\nblock, a link, and mentions" assert entities == [ types.MessageEntityBold(offset=5, length=4), types.MessageEntityBold(offset=11, length=6), @@ -92,10 +89,7 @@ def test_parse_emoji_html() -> None: def test_parse_all_entities_html() -> None: html = 'Some bold (strong), italics (cursive), inline code, a
pre
block, a link,
spoilers
and mentions' text, entities = parse_html_message(html) - assert ( - text - == "Some bold (strong), italics (cursive), inline code, a pre block, a link, spoilers and mentions" - ) + assert text == "Some bold (strong), italics (cursive), inline code, a pre block, a link, spoilers and mentions" assert entities == [ types.MessageEntityBold(offset=5, length=4), types.MessageEntityBold(offset=11, length=6), diff --git a/generator/pyproject.toml b/generator/pyproject.toml index b0166ca2..2228b95c 100644 --- a/generator/pyproject.toml +++ b/generator/pyproject.toml @@ -38,6 +38,10 @@ build-backend = "setuptools.build_meta" version = {attr = "telethon_generator.version.__version__"} [tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["F", "E", "W", "I"] ignore = [ "E501", # formatter takes care of lines that are too long besides documentation ] diff --git a/generator/src/telethon_generator/_impl/codegen/generator.py b/generator/src/telethon_generator/_impl/codegen/generator.py index d3fbe128..c922d90c 100644 --- a/generator/src/telethon_generator/_impl/codegen/generator.py +++ b/generator/src/telethon_generator/_impl/codegen/generator.py @@ -17,9 +17,7 @@ from .serde.deserialization import ( from .serde.serialization import generate_function, generate_write -def generate_init( - writer: SourceWriter, namespaces: set[str], classes: set[str] -) -> None: +def generate_init(writer: SourceWriter, namespaces: set[str], classes: set[str]) -> None: sorted_cls = list(sorted(classes)) sorted_ns = list(sorted(namespaces)) @@ -93,9 +91,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: writer = fs.open(type_path) if type_path not in fs: - writer.write( - "# pyright: reportUnusedImport=false, reportConstantRedefinition=false" - ) + writer.write("# pyright: reportUnusedImport=false, reportConstantRedefinition=false") writer.write("import struct") writer.write("from typing import Optional, Self, Sequence") writer.write("from .. import abcs") @@ -106,9 +102,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: generated_type_names.add(f"{ns}{to_class_name(typedef.name)}") # class Type(BaseType) - writer.write( - f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):" - ) + writer.write(f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):") # __slots__ = ('params', ...) slots = " ".join(f"'{p.name}'," for p in property_params) @@ -121,9 +115,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: # def __init__() if property_params: - params = "".join( - f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params - ) + params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params) writer.write(f" def __init__(_s, *{params}) -> None:") for p in property_params: writer.write(f" _s.{p.name} = {p.name}") @@ -151,9 +143,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: raise ValueError("nested function-namespaces are not supported") elif len(functiondef.namespace) == 1: function_namespaces.add(functiondef.namespace[0]) - function_path = (Path("functions") / functiondef.namespace[0]).with_suffix( - ".py" - ) + function_path = (Path("functions") / functiondef.namespace[0]).with_suffix(".py") else: function_def_names.add(to_method_name(functiondef.name)) function_path = Path("functions/_nons.py") @@ -173,18 +163,14 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params) star = "*" if params else "" return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None)) - writer.write( - f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:" - ) + writer.write(f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:") writer.indent(2) generate_function(writer, functiondef) writer.dedent(2) generate_init(fs.open(Path("abcs/__init__.py")), abc_namespaces, abc_class_names) generate_init(fs.open(Path("types/__init__.py")), type_namespaces, type_class_names) - generate_init( - fs.open(Path("functions/__init__.py")), function_namespaces, function_def_names - ) + generate_init(fs.open(Path("functions/__init__.py")), function_namespaces, function_def_names) writer = fs.open(Path("layer.py")) writer.write("# pyright: reportUnusedImport=false") @@ -194,16 +180,12 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None: ) writer.write("from typing import cast, Type") writer.write(f"LAYER = {tl.layer!r}") - writer.write( - "TYPE_MAPPING = {t.constructor_id(): t for t in cast(tuple[Type[Serializable]], (" - ) + writer.write("TYPE_MAPPING = {t.constructor_id(): t for t in cast(tuple[Type[Serializable]], (") for name in sorted(generated_type_names): writer.write(f" types.{name},") writer.write("))}") writer.write("RESPONSE_MAPPING = {") for functiondef in tl.functiondefs: - writer.write( - f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)}," - ) + writer.write(f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)},") writer.write("}") writer.write("__all__ = ['LAYER', 'TYPE_MAPPING', 'RESPONSE_MAPPING']") diff --git a/generator/src/telethon_generator/_impl/codegen/serde/common.py b/generator/src/telethon_generator/_impl/codegen/serde/common.py index 0269af53..affb263d 100644 --- a/generator/src/telethon_generator/_impl/codegen/serde/common.py +++ b/generator/src/telethon_generator/_impl/codegen/serde/common.py @@ -50,9 +50,7 @@ _TRIVIAL_STRUCT_MAP = {"int": "i", "long": "q", "double": "d", "Bool": "I"} def trivial_struct_fmt(ty: BaseParameter) -> str: try: - return ( - _TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I" - ) + return _TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I" except KeyError: raise ValueError("input param was not trivial") diff --git a/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py b/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py index a5aa2d8b..1d132626 100644 --- a/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py +++ b/generator/src/telethon_generator/_impl/codegen/serde/deserialization.py @@ -38,9 +38,7 @@ def reader_read_fmt(ty: Type, constructor_id: int) -> tuple[str, Optional[str]]: return f"reader.read_serializable({inner_type_fmt(ty)})", "type-abstract" -def generate_normal_param_read( - writer: SourceWriter, name: str, param: NormalParameter, constructor_id: int -) -> None: +def generate_normal_param_read(writer: SourceWriter, name: str, param: NormalParameter, constructor_id: int) -> None: flag_check = f"_{param.flag.name} & {1 << param.flag.index}" if param.flag else None if param.ty.name == "true": if not flag_check: @@ -55,9 +53,7 @@ def generate_normal_param_read( if param.ty.generic_arg: if param.ty.name not in ("Vector", "vector"): - raise ValueError( - "generic_arg deserialization for non-vectors is not supported" - ) + raise ValueError("generic_arg deserialization for non-vectors is not supported") if param.ty.bare: writer.write("__len = reader.read_fmt(' str: def function_deserializer_fmt(defn: Definition) -> str: if defn.ty.generic_arg: if defn.ty.name != ("Vector"): - raise ValueError( - "generic_arg return for non-boxed-vectors is not supported" - ) + raise ValueError("generic_arg return for non-boxed-vectors is not supported") elif defn.ty.generic_ref: raise ValueError("return for generic refs inside vector is not supported") elif is_trivial(NormalParameter(ty=defn.ty.generic_arg, flag=None)): @@ -138,13 +126,9 @@ def function_deserializer_fmt(defn: Definition) -> str: elif defn.ty.generic_arg.name == "long": return "deserialize_i64_list" else: - raise ValueError( - f"return for trivial arg {defn.ty.generic_arg} is not supported" - ) + raise ValueError(f"return for trivial arg {defn.ty.generic_arg} is not supported") elif defn.ty.generic_arg.bare: - raise ValueError( - "return for non-boxed serializables inside a vector is not supported" - ) + raise ValueError("return for non-boxed serializables inside a vector is not supported") else: return f"list_deserializer({inner_type_fmt(defn.ty.generic_arg)})" elif defn.ty.generic_ref: diff --git a/generator/src/telethon_generator/_impl/codegen/serde/serialization.py b/generator/src/telethon_generator/_impl/codegen/serde/serialization.py index 792893c9..1ccf3b11 100644 --- a/generator/src/telethon_generator/_impl/codegen/serde/serialization.py +++ b/generator/src/telethon_generator/_impl/codegen/serde/serialization.py @@ -15,15 +15,11 @@ def param_value_expr(param: Parameter) -> str: return f"{pre}{mid}{suf}" -def generate_buffer_append( - writer: SourceWriter, buffer: str, name: str, ty: Type -) -> None: +def generate_buffer_append(writer: SourceWriter, buffer: str, name: str, ty: Type) -> None: if is_trivial(NormalParameter(ty=ty, flag=None)): fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None)) if ty.name == "Bool": - writer.write( - f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))" - ) + writer.write(f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))") else: writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})") elif ty.generic_ref or ty.name == "Object": @@ -58,9 +54,7 @@ def generate_normal_param_write( if param.ty.generic_arg: if param.ty.name not in ("Vector", "vector"): - raise ValueError( - "generic_arg deserialization for non-vectors is not supported" - ) + raise ValueError("generic_arg deserialization for non-vectors is not supported") if param.ty.bare: writer.write(f"{buffer} += struct.pack(' None: else f"(0 if self.{p.name} is None else {1 << p.ty.flag.index})" ) for p in defn.params - if isinstance(p.ty, NormalParameter) - and p.ty.flag - and p.ty.flag.name == param.name + if isinstance(p.ty, NormalParameter) and p.ty.flag and p.ty.flag.name == param.name ) writer.write(f"_{param.name} = {flags or 0}") @@ -123,9 +113,7 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None: for param in iter: if not isinstance(param.ty, NormalParameter): raise RuntimeError("FlagsParameter should be considered trivial") - generate_normal_param_write( - writer, tmp_names, "buffer", f"self.{param.name}", param.ty - ) + generate_normal_param_write(writer, tmp_names, "buffer", f"self.{param.name}", param.ty) def generate_function(writer: SourceWriter, defn: Definition) -> None: @@ -148,9 +136,7 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None: else f"(0 if {p.name} is None else {1 << p.ty.flag.index})" ) for p in defn.params - if isinstance(p.ty, NormalParameter) - and p.ty.flag - and p.ty.flag.name == param.name + if isinstance(p.ty, NormalParameter) and p.ty.flag and p.ty.flag.name == param.name ) writer.write(f"{param.name} = {flags or 0}") @@ -161,7 +147,5 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None: for param in iter: if not isinstance(param.ty, NormalParameter): raise RuntimeError("FlagsParameter should be considered trivial") - generate_normal_param_write( - writer, tmp_names, "_buffer", param.name, param.ty - ) + generate_normal_param_write(writer, tmp_names, "_buffer", param.name, param.ty) writer.write("return Request(b'' + _buffer)") diff --git a/generator/src/telethon_generator/_impl/tl_parser/loader.py b/generator/src/telethon_generator/_impl/tl_parser/loader.py index 70185de7..80f4e3a3 100644 --- a/generator/src/telethon_generator/_impl/tl_parser/loader.py +++ b/generator/src/telethon_generator/_impl/tl_parser/loader.py @@ -35,6 +35,4 @@ def load_tl_file(path: str | Path) -> ParsedTl: else: functiondefs.append(definition) - return ParsedTl( - layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs) - ) + return ParsedTl(layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)) diff --git a/generator/tests/generator_test.py b/generator/tests/generator_test.py index 949a3179..0fda1f7b 100644 --- a/generator/tests/generator_test.py +++ b/generator/tests/generator_test.py @@ -14,9 +14,7 @@ def gen_py_code( functiondefs: Optional[list[Definition]] = None, ) -> str: fs = FakeFs() - generate( - fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or []) - ) + generate(fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or [])) generated = bytearray() for path, data in fs._files.items(): if path.stem not in ("__init__", "layer"): @@ -27,9 +25,7 @@ def gen_py_code( def test_generic_functions_use_bytes_parameters() -> None: - definitions = get_definitions( - "invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;" - ) + definitions = get_definitions("invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;") result = gen_py_code(functiondefs=definitions) assert "invoke_with_layer" in result assert "query: _bytes" in result diff --git a/generator/tests/parameter_test.py b/generator/tests/parameter_test.py index 6daf7bbd..61829f30 100644 --- a/generator/tests/parameter_test.py +++ b/generator/tests/parameter_test.py @@ -54,18 +54,14 @@ def test_valid_param() -> None: assert Parameter.from_str("foo:!bar") == Parameter( name="foo", ty=NormalParameter( - ty=Type( - namespace=[], name="bar", bare=True, generic_ref=True, generic_arg=None - ), + ty=Type(namespace=[], name="bar", bare=True, generic_ref=True, generic_arg=None), flag=None, ), ) assert Parameter.from_str("foo:bar.1?baz") == Parameter( name="foo", ty=NormalParameter( - ty=Type( - namespace=[], name="baz", bare=True, generic_ref=False, generic_arg=None - ), + ty=Type(namespace=[], name="baz", bare=True, generic_ref=False, generic_arg=None), flag=Flag( name="bar", index=1, diff --git a/generator/tests/ty_test.py b/generator/tests/ty_test.py index 3effb0bf..a797c759 100644 --- a/generator/tests/ty_test.py +++ b/generator/tests/ty_test.py @@ -11,9 +11,7 @@ def test_empty_simple() -> None: def test_simple() -> None: - assert Type.from_str("foo") == Type( - namespace=[], name="foo", bare=True, generic_ref=False, generic_arg=None - ) + assert Type.from_str("foo") == Type(namespace=[], name="foo", bare=True, generic_ref=False, generic_arg=None) @mark.parametrize("ty", [".", "..", ".foo", "foo.", "foo..foo", ".foo."]) diff --git a/tools/check.py b/tools/check.py index 3fdf3174..c587d974 100644 --- a/tools/check.py +++ b/tools/check.py @@ -1,6 +1,7 @@ """ Check formatting, type-check and run offline tests. """ + import subprocess import sys import tempfile @@ -15,9 +16,7 @@ def run(*args: str) -> int: def main() -> None: with tempfile.TemporaryDirectory() as tmp_dir: exit( - run("isort", ".", "-c", "--profile", "black", "--gitignore") - or run("black", ".", "--check", "--extend-exclude", BLACK_IGNORE) - or run("mypy", "--strict", ".") + run("mypy", "--strict", ".") or run("ruff", "check", ".") or run("sphinx", "-M", "dummy", "client/doc", tmp_dir, "-n", "-W") or run("pytest", ".", "-m", "not net") diff --git a/tools/codegen.py b/tools/codegen.py index 64a4dc16..316d195c 100644 --- a/tools/codegen.py +++ b/tools/codegen.py @@ -2,6 +2,7 @@ Run `telethon_generator.codegen` on both `api.tl` and `mtproto.tl` to output corresponding Python code in the default directories under the `client/`. """ + import subprocess import sys diff --git a/tools/copy_client_signatures.py b/tools/copy_client_signatures.py index d3e48520..219de610 100644 --- a/tools/copy_client_signatures.py +++ b/tools/copy_client_signatures.py @@ -110,13 +110,13 @@ def main() -> None: function.args.args[0].annotation = None if isinstance(function, ast.AsyncFunctionDef): - call = ast.Await(value=call) + call = ast.Await(value=call) # type: ignore [arg-type] match function.returns: case ast.Constant(value=None): - call = ast.Expr(value=call) + call = ast.Expr(value=call) # type: ignore [arg-type] case _: - call = ast.Return(value=call) + call = ast.Return(value=call) # type: ignore [arg-type] function.body.append(call) class_body.append(function) diff --git a/tools/docgen.py b/tools/docgen.py index 29ed054f..916a3895 100644 --- a/tools/docgen.py +++ b/tools/docgen.py @@ -1,6 +1,7 @@ """ Run `sphinx-build` to create HTML documentation and detect errors. """ + import subprocess import sys diff --git a/tools/fmt.py b/tools/fmt.py index fde86123..08cbb06e 100644 --- a/tools/fmt.py +++ b/tools/fmt.py @@ -1,6 +1,7 @@ """ Sort imports and format code. """ + import subprocess import sys