Upgrade ruff and mypy version, format files

This commit is contained in:
Jahongir Qurbonov 2024-10-04 10:32:11 +05:00
parent 918f719ab2
commit b25ec41a1f
73 changed files with 320 additions and 1007 deletions

View File

@ -15,9 +15,7 @@ def make_link_node(rawtext, app, name, options):
base += "/" base += "/"
set_classes(options) set_classes(options)
node = nodes.reference( node = nodes.reference(rawtext, utils.unescape(name), refuri="{}?q={}".format(base, name), **options)
rawtext, utils.unescape(name), refuri="{}?q={}".format(base, name), **options
)
return node return node

View File

@ -28,10 +28,8 @@ dynamic = ["version"]
[project.optional-dependencies] [project.optional-dependencies]
cryptg = ["cryptg~=0.4"] cryptg = ["cryptg~=0.4"]
dev = [ dev = [
"isort~=5.12", "mypy~=1.11.2",
"black~=23.3.0", "ruff~=0.6.8",
"mypy~=1.3",
"ruff~=0.0.292",
"pytest~=7.3", "pytest~=7.3",
"pytest-asyncio~=0.21", "pytest-asyncio~=0.21",
] ]
@ -55,6 +53,10 @@ backend-path = ["build_backend"]
version = {attr = "telethon.version.__version__"} version = {attr = "telethon.version.__version__"}
[tool.ruff] [tool.ruff]
line-length = 120
[tool.ruff.lint]
select = ["F", "E", "W", "I"]
ignore = [ ignore = [
"E501", # formatter takes care of lines that are too long besides documentation "E501", # formatter takes care of lines that are too long besides documentation
] ]

View File

@ -31,9 +31,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
assert isinstance(auth, types.auth.Authorization) assert isinstance(auth, types.auth.Authorization)
assert isinstance(auth.user, types.User) assert isinstance(auth.user, types.User)
user = User._from_raw(auth.user) user = User._from_raw(auth.user)
client._session.user = SessionUser( client._session.user = SessionUser(id=user.id, dc=client._sender.dc_id, bot=user.bot, username=user.username)
id=user.id, dc=client._sender.dc_id, bot=user.bot, username=user.username
)
client._chat_hashes.set_self_user(user.id, user.bot) client._chat_hashes.set_self_user(user.id, user.bot)
@ -56,9 +54,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
assert dc_id is not None assert dc_id is not None
sender, client._session.dcs = await connect_sender( sender, client._session.dcs = await connect_sender(client._config, client._session.dcs, DataCenter(id=dc_id))
client._config, client._session.dcs, DataCenter(id=dc_id)
)
async with client._sender_lock: async with client._sender_lock:
client._sender = sender client._sender = sender
@ -177,9 +173,7 @@ async def interactive_login(
user = await self.check_password(user_or_token, password) user = await self.check_password(user_or_token, password)
else: else:
while True: while True:
print( print("Please enter your password (prompt is hidden; type and press enter)")
"Please enter your password (prompt is hidden; type and press enter)"
)
password = getpass.getpass(": ") password = getpass.getpass(": ")
try: try:
user = await self.check_password(user_or_token, password) user = await self.check_password(user_or_token, password)
@ -208,13 +202,9 @@ async def get_password_information(client: Client) -> PasswordToken:
return PasswordToken._new(result) return PasswordToken._new(result)
async def check_password( async def check_password(self: Client, token: PasswordToken, password: str | bytes) -> User:
self: Client, token: PasswordToken, password: str | bytes
) -> User:
algo = token._password.current_algo algo = token._password.current_algo
if not isinstance( if not isinstance(algo, types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow):
algo, types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow
):
raise RuntimeError("unrecognised 2FA algorithm") raise RuntimeError("unrecognised 2FA algorithm")
if not two_factor_auth.check_p_and_g(algo.p, algo.g): if not two_factor_auth.check_p_and_g(algo.p, algo.g):

View File

@ -38,11 +38,7 @@ class InlineResults(metaclass=NoPublicConstructor):
result = await self._client( result = await self._client(
functions.messages.get_inline_bot_results( functions.messages.get_inline_bot_results(
bot=self._bot, bot=self._bot,
peer=( peer=(self._peer._to_input_peer() if self._peer else types.InputPeerEmpty()),
self._peer._to_input_peer()
if self._peer
else types.InputPeerEmpty()
),
geo_point=None, geo_point=None,
query=self._query, query=self._query,
offset=self._offset, offset=self._offset,
@ -51,12 +47,8 @@ class InlineResults(metaclass=NoPublicConstructor):
assert isinstance(result, types.messages.BotResults) assert isinstance(result, types.messages.BotResults)
self._offset = result.next_offset self._offset = result.next_offset
for r in reversed(result.results): for r in reversed(result.results):
assert isinstance( assert isinstance(r, (types.BotInlineMediaResult, types.BotInlineResult))
r, (types.BotInlineMediaResult, types.BotInlineResult) self._buffer.append(InlineResult._create(self._client, result, r, self._peer))
)
self._buffer.append(
InlineResult._create(self._client, result, r, self._peer)
)
if not self._buffer: if not self._buffer:
self._offset = None self._offset = None

View File

@ -53,9 +53,7 @@ class ParticipantList(AsyncList[Participant]):
seen_count = len(self._seen) seen_count = len(self._seen)
for p in chanp.participants: for p in chanp.participants:
part = Participant._from_raw_channel( part = Participant._from_raw_channel(self._client, self._peer, p, chat_map)
self._client, self._peer, p, chat_map
)
pid = part._peer_id() pid = part._peer_id()
if pid not in self._seen: if pid not in self._seen:
self._seen.add(pid) self._seen.add(pid)
@ -66,9 +64,7 @@ class ParticipantList(AsyncList[Participant]):
self._done = len(self._seen) == seen_count self._done = len(self._seen) == seen_count
else: else:
chatp = await self._client( chatp = await self._client(functions.messages.get_full_chat(chat_id=self._peer._to_input_chat()))
functions.messages.get_full_chat(chat_id=self._peer._to_input_chat())
)
assert isinstance(chatp, types.messages.ChatFull) assert isinstance(chatp, types.messages.ChatFull)
assert isinstance(chatp.full_chat, types.ChatFull) assert isinstance(chatp.full_chat, types.ChatFull)
@ -87,17 +83,14 @@ class ParticipantList(AsyncList[Participant]):
) )
elif isinstance(participants, types.ChatParticipants): elif isinstance(participants, types.ChatParticipants):
self._buffer.extend( self._buffer.extend(
Participant._from_raw_chat(self._client, self._peer, p, chat_map) Participant._from_raw_chat(self._client, self._peer, p, chat_map) for p in participants.participants
for p in participants.participants
) )
self._total = len(self._buffer) self._total = len(self._buffer)
self._done = True self._done = True
def get_participants( def get_participants(self: Client, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[Participant]:
self: Client, chat: Group | Channel | GroupRef | ChannelRef, /
) -> AsyncList[Participant]:
return ParticipantList(self, chat._ref) return ParticipantList(self, chat._ref)
@ -137,9 +130,7 @@ class RecentActionList(AsyncList[RecentAction]):
self._offset = min(e.id for e in self._buffer) self._offset = min(e.id for e in self._buffer)
def get_admin_log( def get_admin_log(self: Client, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[RecentAction]:
self: Client, chat: Group | Channel | GroupRef | ChannelRef, /
) -> AsyncList[RecentAction]:
return RecentActionList(self, chat._ref) return RecentActionList(self, chat._ref)
@ -174,11 +165,7 @@ class ProfilePhotoList(AsyncList[File]):
else: else:
raise RuntimeError("unexpected case") raise RuntimeError("unexpected case")
self._buffer.extend( self._buffer.extend(filter(None, (File._try_from_raw_photo(self._client, p) for p in photos)))
filter(
None, (File._try_from_raw_photo(self._client, p) for p in photos)
)
)
def get_profile_photos(self: Client, peer: Peer | PeerRef, /) -> AsyncList[File]: def get_profile_photos(self: Client, peer: Peer | PeerRef, /) -> AsyncList[File]:
@ -256,11 +243,7 @@ async def set_chat_default_restrictions(
*, *,
until: Optional[datetime.datetime] = None, until: Optional[datetime.datetime] = None,
) -> None: ) -> None:
banned_rights = ChatRestriction._set_to_raw( banned_rights = ChatRestriction._set_to_raw(set(restrictions), int(until.timestamp()) if until else 0x7FFFFFFF)
set(restrictions), int(until.timestamp()) if until else 0x7FFFFFFF
)
await self( await self(
functions.messages.edit_chat_default_banned_rights( functions.messages.edit_chat_default_banned_rights(peer=chat._ref._to_input_peer(), banned_rights=banned_rights)
peer=chat._ref._to_input_peer(), banned_rights=banned_rights
)
) )

View File

@ -237,9 +237,7 @@ class Client:
lang_code=lang_code or "en", lang_code=lang_code or "en",
catch_up=catch_up or False, catch_up=catch_up or False,
datacenter=datacenter, datacenter=datacenter,
flood_sleep_threshold=( flood_sleep_threshold=(60 if flood_sleep_threshold is None else flood_sleep_threshold),
60 if flood_sleep_threshold is None else flood_sleep_threshold
),
update_queue_limit=update_queue_limit, update_queue_limit=update_queue_limit,
base_logger=base_logger, base_logger=base_logger,
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
@ -250,8 +248,8 @@ class Client:
self._message_box = MessageBox(base_logger=base_logger) self._message_box = MessageBox(base_logger=base_logger)
self._chat_hashes = ChatHashCache(None) self._chat_hashes = ChatHashCache(None)
self._last_update_limit_warn: Optional[float] = None self._last_update_limit_warn: Optional[float] = None
self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = ( self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = asyncio.Queue(
asyncio.Queue(maxsize=self._config.update_queue_limit or 0) maxsize=self._config.update_queue_limit or 0
) )
self._dispatcher: Optional[asyncio.Task[None]] = None self._dispatcher: Optional[asyncio.Task[None]] = None
self._handlers: dict[ self._handlers: dict[
@ -411,9 +409,7 @@ class Client:
""" """
await delete_dialog(self, dialog) await delete_dialog(self, dialog)
async def delete_messages( async def delete_messages(self, chat: Peer | PeerRef, /, message_ids: list[int], *, revoke: bool = True) -> int:
self, chat: Peer | PeerRef, /, message_ids: list[int], *, revoke: bool = True
) -> int:
""" """
Delete messages. Delete messages.
@ -653,9 +649,7 @@ class Client:
""" """
return await forward_messages(self, target, message_ids, source) return await forward_messages(self, target, message_ids, source)
def get_admin_log( def get_admin_log(self, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[RecentAction]:
self, chat: Group | Channel | GroupRef | ChannelRef, /
) -> AsyncList[RecentAction]:
""" """
Get the recent actions from the administrator's log. Get the recent actions from the administrator's log.
@ -757,9 +751,7 @@ class Client:
""" """
return get_file_bytes(self, media) return get_file_bytes(self, media)
def get_handler_filter( def get_handler_filter(self, handler: Callable[[Event], Awaitable[Any]], /) -> Optional[FilterType]:
self, handler: Callable[[Event], Awaitable[Any]], /
) -> Optional[FilterType]:
""" """
Get the filter associated to the given event handler. Get the filter associated to the given event handler.
@ -861,13 +853,9 @@ class Client:
async for message in reversed(client.get_messages(chat)): async for message in reversed(client.get_messages(chat)):
print(message.sender.name, ':', message.markdown_text) print(message.sender.name, ':', message.markdown_text)
""" """
return get_messages( return get_messages(self, chat, limit, offset_id=offset_id, offset_date=offset_date)
self, chat, limit, offset_id=offset_id, offset_date=offset_date
)
def get_messages_with_ids( def get_messages_with_ids(self, chat: Peer | PeerRef, /, message_ids: list[int]) -> AsyncList[Message]:
self, chat: Peer | PeerRef, /, message_ids: list[int]
) -> AsyncList[Message]:
""" """
Get the full message objects from the corresponding message identifiers. Get the full message objects from the corresponding message identifiers.
@ -891,9 +879,7 @@ class Client:
""" """
return get_messages_with_ids(self, chat, message_ids) return get_messages_with_ids(self, chat, message_ids)
def get_participants( def get_participants(self, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[Participant]:
self, chat: Group | Channel | GroupRef | ChannelRef, /
) -> AsyncList[Participant]:
""" """
Get the participants in a group or channel, along with their permissions. Get the participants in a group or channel, along with their permissions.
@ -981,9 +967,7 @@ class Client:
""" """
return await inline_query(self, bot, query, peer=peer) return await inline_query(self, bot, query, peer=peer)
async def interactive_login( async def interactive_login(self, phone_or_token: Optional[str] = None, *, password: Optional[str] = None) -> User:
self, phone_or_token: Optional[str] = None, *, password: Optional[str] = None
) -> User:
""" """
Begin an interactive login if needed. Begin an interactive login if needed.
If the account was already logged-in, this method simply returns :term:`yourself`. If the account was already logged-in, this method simply returns :term:`yourself`.
@ -1035,9 +1019,7 @@ class Client:
def on( def on(
self, event_cls: Type[Event], /, filter: Optional[FilterType] = None self, event_cls: Type[Event], /, filter: Optional[FilterType] = None
) -> Callable[ ) -> Callable[[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]]:
[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]
]:
""" """
Register the decorated function to be invoked when the provided event type occurs. Register the decorated function to be invoked when the provided event type occurs.
@ -1125,9 +1107,7 @@ class Client:
""" """
return prepare_album(self) return prepare_album(self)
async def read_message( async def read_message(self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
) -> None:
""" """
Mark messages as read. Mark messages as read.
@ -1158,9 +1138,7 @@ class Client:
""" """
await read_message(self, chat, message_id) await read_message(self, chat, message_id)
def remove_event_handler( def remove_event_handler(self, handler: Callable[[Event], Awaitable[Any]], /) -> None:
self, handler: Callable[[Event], Awaitable[Any]], /
) -> None:
""" """
Remove the handler as a function to be called when events occur. Remove the handler as a function to be called when events occur.
This is simply the opposite of :meth:`add_event_handler`. This is simply the opposite of :meth:`add_event_handler`.
@ -1324,9 +1302,7 @@ class Client:
async for message in client.search_all_messages(query='hello'): async for message in client.search_all_messages(query='hello'):
print(message.text) print(message.text)
""" """
return search_all_messages( return search_all_messages(self, limit, query=query, offset_id=offset_id, offset_date=offset_date)
self, limit, query=query, offset_id=offset_id, offset_date=offset_date
)
def search_messages( def search_messages(
self, self,
@ -1372,9 +1348,7 @@ class Client:
async for message in client.search_messages(chat, query='hello'): async for message in client.search_messages(chat, query='hello'):
print(message.text) print(message.text)
""" """
return search_messages( return search_messages(self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date)
self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date
)
async def send_audio( async def send_audio(
self, self,
@ -1975,9 +1949,7 @@ class Client:
:meth:`telethon.types.Participant.set_restrictions` :meth:`telethon.types.Participant.set_restrictions`
""" """
await set_participant_restrictions( await set_participant_restrictions(self, chat, participant, restrictions, until=until)
self, chat, participant, restrictions, until=until
)
async def sign_in(self, token: LoginToken, code: str) -> User | PasswordToken: async def sign_in(self, token: LoginToken, code: str) -> User | PasswordToken:
""" """
@ -2024,9 +1996,7 @@ class Client:
""" """
await sign_out(self) await sign_out(self)
async def unpin_message( async def unpin_message(self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
) -> None:
""" """
Unpin one or all messages from the top. Unpin one or all messages from the top.

View File

@ -51,9 +51,7 @@ class DialogList(AsyncList[Dialog]):
chat_map = build_chat_map(self._client, result.users, result.chats) chat_map = build_chat_map(self._client, result.users, result.chats)
msg_map = build_msg_map(self._client, result.messages, chat_map) msg_map = build_msg_map(self._client, result.messages, chat_map)
self._buffer.extend( self._buffer.extend(Dialog._from_raw(self._client, d, chat_map, msg_map) for d in result.dialogs)
Dialog._from_raw(self._client, d, chat_map, msg_map) for d in result.dialogs
)
def get_dialogs(self: Client) -> AsyncList[Dialog]: def get_dialogs(self: Client) -> AsyncList[Dialog]:
@ -130,9 +128,7 @@ async def edit_draft(
reply_to: Optional[int] = None, reply_to: Optional[int] = None,
) -> Draft: ) -> Draft:
peer = peer._ref peer = peer._ref
message, entities = parse_message( message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False)
text=text, markdown=markdown, html=html, allow_empty=False
)
result = await self( result = await self(
functions.messages.save_draft( functions.messages.save_draft(

View File

@ -189,15 +189,11 @@ async def send_file(
reply_to: Optional[int] = None, reply_to: Optional[int] = None,
keyboard: Optional[KeyboardType] = None, keyboard: Optional[KeyboardType] = None,
) -> Message: ) -> Message:
message, entities = parse_message( message, entities = parse_message(text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True)
text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True
)
# Re-send existing file. # Re-send existing file.
if isinstance(file, File): if isinstance(file, File):
return await do_send_file( return await do_send_file(self, chat, file._input_media, message, entities, reply_to, keyboard)
self, chat, file._input_media, message, entities, reply_to, keyboard
)
# URLs are handled early as they can't use any other attributes either. # URLs are handled early as they can't use any other attributes either.
input_media: abcs.InputMedia input_media: abcs.InputMedia
@ -212,16 +208,10 @@ async def send_file(
else: else:
as_photo = False as_photo = False
if as_photo: if as_photo:
input_media = types.InputMediaPhotoExternal( input_media = types.InputMediaPhotoExternal(spoiler=False, url=file, ttl_seconds=None)
spoiler=False, url=file, ttl_seconds=None
)
else: else:
input_media = types.InputMediaDocumentExternal( input_media = types.InputMediaDocumentExternal(spoiler=False, url=file, ttl_seconds=None)
spoiler=False, url=file, ttl_seconds=None return await do_send_file(self, chat, input_media, message, entities, reply_to, keyboard)
)
return await do_send_file(
self, chat, input_media, message, entities, reply_to, keyboard
)
input_file, name = await upload(self, file, size, name) input_file, name = await upload(self, file, size, name)
@ -285,9 +275,7 @@ async def send_file(
ttl_seconds=None, ttl_seconds=None,
) )
return await do_send_file( return await do_send_file(self, chat, input_media, message, entities, reply_to, keyboard)
self, chat, input_media, message, entities, reply_to, keyboard
)
async def do_send_file( async def do_send_file(
@ -309,11 +297,7 @@ async def do_send_file(
noforwards=False, noforwards=False,
update_stickersets_order=False, update_stickersets_order=False,
peer=chat._ref._to_input_peer(), peer=chat._ref._to_input_peer(),
reply_to=( reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None),
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
if reply_to
else None
),
media=input_media, media=input_media,
message=message, message=message,
random_id=random_id, random_id=random_id,
@ -395,11 +379,7 @@ async def do_upload(
) )
) )
else: else:
await client( await client(functions.upload.save_file_part(file_id=file_id, file_part=part, bytes=to_store))
functions.upload.save_file_part(
file_id=file_id, file_part=part, bytes=to_store
)
)
hash_md5.update(to_store) hash_md5.update(to_store)
buffer.clear() buffer.clear()
@ -458,9 +438,7 @@ def get_file_bytes(self: Client, media: File, /) -> AsyncList[bytes]:
return FileBytesList(self, media) return FileBytesList(self, media)
async def download( async def download(self: Client, media: File, /, file: str | Path | OutFileLike) -> None:
self: Client, media: File, /, file: str | Path | OutFileLike
) -> None:
fd = OutWrapper(file) fd = OutWrapper(file)
try: try:
async for chunk in get_file_bytes(self, media): async for chunk in get_file_bytes(self, media):

View File

@ -46,9 +46,7 @@ async def send_message(
update_stickersets_order=False, update_stickersets_order=False,
peer=chat._ref._to_input_peer(), peer=chat._ref._to_input_peer(),
reply_to=( reply_to=(
types.InputReplyToMessage( types.InputReplyToMessage(reply_to_msg_id=text.replied_message_id, top_msg_id=None)
reply_to_msg_id=text.replied_message_id, top_msg_id=None
)
if text.replied_message_id if text.replied_message_id
else None else None
), ),
@ -60,9 +58,7 @@ async def send_message(
send_as=None, send_as=None,
) )
else: else:
message, entities = parse_message( message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False)
text=text, markdown=markdown, html=html, allow_empty=False
)
request = functions.messages.send_message( request = functions.messages.send_message(
no_webpage=not link_preview, no_webpage=not link_preview,
silent=False, silent=False,
@ -71,11 +67,7 @@ async def send_message(
noforwards=False, noforwards=False,
update_stickersets_order=False, update_stickersets_order=False,
peer=chat._ref._to_input_peer(), peer=chat._ref._to_input_peer(),
reply_to=( reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None),
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
if reply_to
else None
),
message=message, message=message,
random_id=random_id, random_id=random_id,
reply_markup=keyboard._raw if keyboard else None, reply_markup=keyboard._raw if keyboard else None,
@ -91,11 +83,7 @@ async def send_message(
{}, {},
out=result.out, out=result.out,
id=result.id, id=result.id,
from_id=( from_id=(types.PeerUser(user_id=self._session.user.id) if self._session.user else None),
types.PeerUser(user_id=self._session.user.id)
if self._session.user
else None
),
peer_id=chat._ref._to_peer(), peer_id=chat._ref._to_peer(),
reply_to=( reply_to=(
types.MessageReplyHeader( types.MessageReplyHeader(
@ -130,9 +118,7 @@ async def edit_message(
link_preview: bool = False, link_preview: bool = False,
keyboard: Optional[KeyboardType] = None, keyboard: Optional[KeyboardType] = None,
) -> Message: ) -> Message:
message, entities = parse_message( message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False)
text=text, markdown=markdown, html=html, allow_empty=False
)
return self._build_message_map( return self._build_message_map(
await self( await self(
functions.messages.edit_message( functions.messages.edit_message(
@ -160,15 +146,9 @@ async def delete_messages(
) -> int: ) -> int:
peer = chat._ref peer = chat._ref
if isinstance(peer, ChannelRef): if isinstance(peer, ChannelRef):
affected = await self( affected = await self(functions.channels.delete_messages(channel=peer._to_input_channel(), id=message_ids))
functions.channels.delete_messages(
channel=peer._to_input_channel(), id=message_ids
)
)
else: else:
affected = await self( affected = await self(functions.messages.delete_messages(revoke=revoke, id=message_ids))
functions.messages.delete_messages(revoke=revoke, id=message_ids)
)
assert isinstance(affected, types.messages.AffectedMessages) assert isinstance(affected, types.messages.AffectedMessages)
return affected.pts_count return affected.pts_count
@ -205,9 +185,7 @@ class MessageList(AsyncList[Message]):
super().__init__() super().__init__()
self._reversed = False self._reversed = False
def _extend_buffer( def _extend_buffer(self, client: Client, messages: abcs.messages.Messages) -> dict[int, Peer]:
self, client: Client, messages: abcs.messages.Messages
) -> dict[int, Peer]:
if isinstance(messages, types.messages.MessagesNotModified): if isinstance(messages, types.messages.MessagesNotModified):
self._total = messages.count self._total = messages.count
return {} return {}
@ -215,9 +193,7 @@ class MessageList(AsyncList[Message]):
if isinstance(messages, types.messages.Messages): if isinstance(messages, types.messages.Messages):
self._total = len(messages.messages) self._total = len(messages.messages)
self._done = True self._done = True
elif isinstance( elif isinstance(messages, (types.messages.MessagesSlice, types.messages.ChannelMessages)):
messages, (types.messages.MessagesSlice, types.messages.ChannelMessages)
):
self._total = messages.count self._total = messages.count
else: else:
raise RuntimeError("unexpected case") raise RuntimeError("unexpected case")
@ -225,9 +201,7 @@ class MessageList(AsyncList[Message]):
chat_map = build_chat_map(client, messages.users, messages.chats) chat_map = build_chat_map(client, messages.users, messages.chats)
self._buffer.extend( self._buffer.extend(
Message._from_raw(client, m, chat_map) Message._from_raw(client, m, chat_map)
for m in ( for m in (reversed(messages.messages) if self._reversed else messages.messages)
reversed(messages.messages) if self._reversed else messages.messages
)
) )
return chat_map return chat_map
@ -235,11 +209,7 @@ class MessageList(AsyncList[Message]):
self, self,
) -> types.Message | types.MessageService | types.MessageEmpty: ) -> types.Message | types.MessageService | types.MessageEmpty:
return next( return next(
( (m._raw for m in reversed(self._buffer) if not isinstance(m._raw, types.MessageEmpty)),
m._raw
for m in reversed(self._buffer)
if not isinstance(m._raw, types.MessageEmpty)
),
types.MessageEmpty(id=0, peer_id=None), types.MessageEmpty(id=0, peer_id=None),
) )
@ -335,14 +305,10 @@ class CherryPickedList(MessageList):
if isinstance(self._peer, ChannelRef): if isinstance(self._peer, ChannelRef):
result = await self._client( result = await self._client(
functions.channels.get_messages( functions.channels.get_messages(channel=self._peer._to_input_channel(), id=self._ids[:100])
channel=self._peer._to_input_channel(), id=self._ids[:100]
)
) )
else: else:
result = await self._client( result = await self._client(functions.messages.get_messages(id=self._ids[:100]))
functions.messages.get_messages(id=self._ids[:100])
)
self._extend_buffer(self._client, result) self._extend_buffer(self._client, result)
self._ids = self._ids[100:] self._ids = self._ids[100:]
@ -492,9 +458,7 @@ def search_all_messages(
) )
async def pin_message( async def pin_message(self: Client, chat: Peer | PeerRef, /, message_id: int) -> Message:
self: Client, chat: Peer | PeerRef, /, message_id: int
) -> Message:
return self._build_message_map( return self._build_message_map(
await self( await self(
functions.messages.update_pinned_message( functions.messages.update_pinned_message(
@ -509,9 +473,7 @@ async def pin_message(
).get_single() ).get_single()
async def unpin_message( async def unpin_message(self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
) -> None:
if message_id == "all": if message_id == "all":
await self( await self(
functions.messages.unpin_all_messages( functions.messages.unpin_all_messages(
@ -531,25 +493,15 @@ async def unpin_message(
) )
async def read_message( async def read_message(self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
) -> None:
if message_id == "all": if message_id == "all":
message_id = 0 message_id = 0
peer = chat._ref peer = chat._ref
if isinstance(peer, ChannelRef): if isinstance(peer, ChannelRef):
await self( await self(functions.channels.read_history(channel=peer._to_input_channel(), max_id=message_id))
functions.channels.read_history(
channel=peer._to_input_channel(), max_id=message_id
)
)
else: else:
await self( await self(functions.messages.read_history(peer=peer._ref._to_input_peer(), max_id=message_id))
functions.messages.read_history(
peer=peer._ref._to_input_peer(), max_id=message_id
)
)
class MessageMap: class MessageMap:
@ -584,9 +536,7 @@ class MessageMap:
def _empty(self, id: int = 0) -> Message: def _empty(self, id: int = 0) -> Message:
return Message._from_raw( return Message._from_raw(
self._client, self._client,
types.MessageEmpty( types.MessageEmpty(id=id, peer_id=self._peer._to_peer() if self._peer else None),
id=id, peer_id=self._peer._to_peer() if self._peer else None
),
{}, {},
) )

View File

@ -82,16 +82,9 @@ async def connect_sender(
# Only the ID of the input DC may be known. # Only the ID of the input DC may be known.
# Find the corresponding address and authentication key if needed. # Find the corresponding address and authentication key if needed.
addr = dc.ipv4_addr or next( addr = dc.ipv4_addr or next(
d.ipv4_addr d.ipv4_addr for d in itertools.chain(known_dcs, KNOWN_DCS) if d.id == dc.id and d.ipv4_addr
for d in itertools.chain(known_dcs, KNOWN_DCS)
if d.id == dc.id and d.ipv4_addr
)
auth = (
None
if force_auth_gen
else dc.auth
or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
) )
auth = None if force_auth_gen else dc.auth or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
sender = await do_connect_sender( sender = await do_connect_sender(
Full(), Full(),
@ -122,9 +115,7 @@ async def connect_sender(
) )
except BadStatus as e: except BadStatus as e:
if e.status == 404 and auth: if e.status == 404 and auth:
dc = DataCenter( dc = DataCenter(id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None)
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None
)
config.base_logger.warning( config.base_logger.warning(
"datacenter could not find stored auth; will retry generating a new one: %s", "datacenter could not find stored auth; will retry generating a new one: %s",
dc, dc,
@ -174,12 +165,8 @@ async def connect(self: Client) -> None:
if session := await self._storage.load(): if session := await self._storage.load():
self._session = session self._session = session
datacenter = self._config.datacenter or DataCenter( datacenter = self._config.datacenter or DataCenter(id=self._session.user.dc if self._session.user else DEFAULT_DC)
id=self._session.user.dc if self._session.user else DEFAULT_DC self._sender, self._session.dcs = await connect_sender(self._config, self._session.dcs, datacenter)
)
self._sender, self._session.dcs = await connect_sender(
self._config, self._session.dcs, datacenter
)
if self._message_box.is_empty() and self._session.user: if self._message_box.is_empty() and self._session.user:
try: try:
@ -193,9 +180,7 @@ async def connect(self: Client) -> None:
me = await self.get_me() me = await self.get_me()
assert me is not None assert me is not None
self._chat_hashes.set_self_user(me.id, me.bot) self._chat_hashes.set_self_user(me.id, me.bot)
self._session.user = SessionUser( self._session.user = SessionUser(id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username)
id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username
)
self._dispatcher = asyncio.create_task(dispatcher(self)) self._dispatcher = asyncio.create_task(dispatcher(self))
@ -214,18 +199,14 @@ async def disconnect(self: Client) -> None:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception: except Exception:
self._config.base_logger.exception( self._config.base_logger.exception("unhandled exception when cancelling dispatcher; this is a bug")
"unhandled exception when cancelling dispatcher; this is a bug"
)
finally: finally:
self._dispatcher = None self._dispatcher = None
try: try:
await sender.disconnect() await sender.disconnect()
except Exception: except Exception:
self._config.base_logger.exception( self._config.base_logger.exception("unhandled exception during disconnect; this is a bug")
"unhandled exception during disconnect; this is a bug"
)
try: try:
if self._session.user: if self._session.user:

View File

@ -40,9 +40,7 @@ def add_event_handler(
self._handlers.setdefault(event_cls, []).append((handler, filter)) self._handlers.setdefault(event_cls, []).append((handler, filter))
def remove_event_handler( def remove_event_handler(self: Client, handler: Callable[[Event], Awaitable[Any]], /) -> None:
self: Client, handler: Callable[[Event], Awaitable[Any]], /
) -> None:
for event_cls, handlers in tuple(self._handlers.items()): for event_cls, handlers in tuple(self._handlers.items()):
for i in reversed(range(len(handlers))): for i in reversed(range(len(handlers))):
if handlers[i][0] == handler: if handlers[i][0] == handler:
@ -51,9 +49,7 @@ def remove_event_handler(
del self._handlers[event_cls] del self._handlers[event_cls]
def get_handler_filter( def get_handler_filter(self: Client, handler: Callable[[Event], Awaitable[Any]], /) -> Optional[FilterType]:
self: Client, handler: Callable[[Event], Awaitable[Any]], /
) -> Optional[FilterType]:
for handlers in self._handlers.values(): for handlers in self._handlers.values():
for h, f in handlers: for h, f in handlers:
if h == handler: if h == handler:
@ -84,9 +80,7 @@ def process_socket_updates(client: Client, all_updates: list[abcs.Updates]) -> N
return return
try: try:
result, users, chats = client._message_box.process_updates( result, users, chats = client._message_box.process_updates(updates, client._chat_hashes)
updates, client._chat_hashes
)
except Gap: except Gap:
return return
@ -107,8 +101,7 @@ def extend_update_queue(
except asyncio.QueueFull: except asyncio.QueueFull:
now = asyncio.get_running_loop().time() now = asyncio.get_running_loop().time()
if client._last_update_limit_warn is None or ( if client._last_update_limit_warn is None or (
now - client._last_update_limit_warn now - client._last_update_limit_warn > UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN
> UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN
): ):
client._config.base_logger.warning( client._config.base_logger.warning(
"updates are being dropped because limit=%d has been reached", "updates are being dropped because limit=%d has been reached",

View File

@ -53,15 +53,11 @@ def resolved_peer_to_chat(client: Client, resolved: abcs.contacts.ResolvedPeer)
async def resolve_phone(self: Client, phone: str, /) -> Peer: async def resolve_phone(self: Client, phone: str, /) -> Peer:
return resolved_peer_to_chat( return resolved_peer_to_chat(self, await self(functions.contacts.resolve_phone(phone=phone)))
self, await self(functions.contacts.resolve_phone(phone=phone))
)
async def resolve_username(self: Client, username: str, /) -> Peer: async def resolve_username(self: Client, username: str, /) -> Peer:
return resolved_peer_to_chat( return resolved_peer_to_chat(self, await self(functions.contacts.resolve_username(username=username)))
self, await self(functions.contacts.resolve_username(username=username))
)
async def resolve_peers(self: Client, peers: Sequence[Peer | PeerRef], /) -> list[Peer]: async def resolve_peers(self: Client, peers: Sequence[Peer | PeerRef], /) -> list[Peer]:
@ -99,8 +95,4 @@ async def resolve_peers(self: Client, peers: Sequence[Peer | PeerRef], /) -> lis
chats.extend(ret_chats.chats) chats.extend(ret_chats.chats)
chat_map = build_chat_map(self, users, chats) chat_map = build_chat_map(self, users, chats)
return [ return [chat_map.get(ref.identifier) or expand_peer(self, ref._to_peer(), broadcast=None) for ref in refs]
chat_map.get(ref.identifier)
or expand_peer(self, ref._to_peer(), broadcast=None)
for ref in refs
]

View File

@ -34,17 +34,13 @@ def from_name(name: str, *, _cache: dict[str, Type[RpcError]] = {}) -> Type[RpcE
return _cache[name] return _cache[name]
def adapt_rpc( def adapt_rpc(error: RpcError, *, _cache: dict[tuple[int, str], Type[RpcError]] = {}) -> RpcError:
error: RpcError, *, _cache: dict[tuple[int, str], Type[RpcError]] = {}
) -> RpcError:
code = canonicalize_code(error.code) code = canonicalize_code(error.code)
name = canonicalize_name(error.name) name = canonicalize_name(error.name)
tup = code, name tup = code, name
if tup not in _cache: if tup not in _cache:
_cache[tup] = type(pretty_name(name), (from_code(code), from_name(name)), {}) _cache[tup] = type(pretty_name(name), (from_code(code), from_name(name)), {})
return _cache[tup]( return _cache[tup](code=error.code, name=error.name, value=error.value, caused_by=error._caused_by)
code=error.code, name=error.name, value=error.value, caused_by=error._caused_by
)
class ErrorFactory: class ErrorFactory:
@ -55,9 +51,7 @@ class ErrorFactory:
return from_code(int(m[1])) return from_code(int(m[1]))
else: else:
adapted = adapt_user_name(name) adapted = adapt_user_name(name)
if pretty_name(canonicalize_name(adapted)) != name or re.match( if pretty_name(canonicalize_name(adapted)) != name or re.match(r"[A-Z]{2}", name):
r"[A-Z]{2}", name
):
raise AttributeError(f"error subclass names must be CamelCase: {name}") raise AttributeError(f"error subclass names must be CamelCase: {name}")
return from_name(adapted) return from_name(adapted)

View File

@ -24,9 +24,7 @@ class Event(abc.ABC, metaclass=NoPublicConstructor):
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def _try_from_update( def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
pass pass
@ -52,9 +50,7 @@ class Raw(Event):
self._chat_map = chat_map self._chat_map = chat_map
@classmethod @classmethod
def _try_from_update( def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
return cls._create(client, update, chat_map) return cls._create(client, update, chat_map)

View File

@ -65,9 +65,7 @@ class Any(Combinable):
__slots__ = ("_filters",) __slots__ = ("_filters",)
def __init__( def __init__(self, filter1: FilterType, filter2: FilterType, *filters: FilterType) -> None:
self, filter1: FilterType, filter2: FilterType, *filters: FilterType
) -> None:
self._filters = (filter1, filter2, *filters) self._filters = (filter1, filter2, *filters)
@property @property
@ -111,9 +109,7 @@ class All(Combinable):
__slots__ = ("_filters",) __slots__ = ("_filters",)
def __init__( def __init__(self, filter1: FilterType, filter2: FilterType, *filters: FilterType) -> None:
self, filter1: FilterType, filter2: FilterType, *filters: FilterType
) -> None:
self._filters = (filter1, filter2, *filters) self._filters = (filter1, filter2, *filters)
@property @property

View File

@ -24,15 +24,11 @@ class NewMessage(Event, Message):
""" """
@classmethod @classmethod
def _try_from_update( def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
if isinstance(update, (types.UpdateNewMessage, types.UpdateNewChannelMessage)): if isinstance(update, (types.UpdateNewMessage, types.UpdateNewChannelMessage)):
if isinstance(update.message, types.Message): if isinstance(update.message, types.Message):
return cls._from_raw(client, update.message, chat_map) return cls._from_raw(client, update.message, chat_map)
elif isinstance( elif isinstance(update, (types.UpdateShortMessage, types.UpdateShortChatMessage)):
update, (types.UpdateShortMessage, types.UpdateShortChatMessage)
):
raise RuntimeError("should have been handled by adaptor") raise RuntimeError("should have been handled by adaptor")
return None return None
@ -46,12 +42,8 @@ class MessageEdited(Event, Message):
""" """
@classmethod @classmethod
def _try_from_update( def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer] if isinstance(update, (types.UpdateEditMessage, types.UpdateEditChannelMessage)):
) -> Optional[Self]:
if isinstance(
update, (types.UpdateEditMessage, types.UpdateEditChannelMessage)
):
return cls._from_raw(client, update.message, chat_map) return cls._from_raw(client, update.message, chat_map)
else: else:
return None return None
@ -74,9 +66,7 @@ class MessageDeleted(Event):
self._channel_id = channel_id self._channel_id = channel_id
@classmethod @classmethod
def _try_from_update( def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
if isinstance(update, types.UpdateDeleteMessages): if isinstance(update, types.UpdateDeleteMessages):
return cls._create(update.messages, None) return cls._create(update.messages, None)
elif isinstance(update, types.UpdateDeleteChannelMessages): elif isinstance(update, types.UpdateDeleteChannelMessages):
@ -122,9 +112,7 @@ class MessageRead(Event):
self._chat_map = chat_map self._chat_map = chat_map
@classmethod @classmethod
def _try_from_update( def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
if isinstance( if isinstance(
update, update,
( (
@ -139,9 +127,7 @@ class MessageRead(Event):
return None return None
def _peer(self) -> abcs.Peer: def _peer(self) -> abcs.Peer:
if isinstance( if isinstance(self._raw, (types.UpdateReadHistoryInbox, types.UpdateReadHistoryOutbox)):
self._raw, (types.UpdateReadHistoryInbox, types.UpdateReadHistoryOutbox)
):
return self._raw.peer return self._raw.peer
else: else:
return types.PeerChannel(channel_id=self._raw.channel_id) return types.PeerChannel(channel_id=self._raw.channel_id)
@ -154,9 +140,7 @@ class MessageRead(Event):
peer = self._peer() peer = self._peer()
pid = peer_id(peer) pid = peer_id(peer)
if pid not in self._chat_map: if pid not in self._chat_map:
self._chat_map[pid] = expand_peer( self._chat_map[pid] = expand_peer(self._client, peer, broadcast=getattr(self._raw, "post", None))
self._client, peer, broadcast=getattr(self._raw, "post", None)
)
return self._chat_map[pid] return self._chat_map[pid]
@property @property

View File

@ -31,9 +31,7 @@ class ButtonCallback(Event):
self._chat_map = chat_map self._chat_map = chat_map
@classmethod @classmethod
def _try_from_update( def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
if isinstance(update, types.UpdateBotCallbackQuery) and update.data is not None: if isinstance(update, types.UpdateBotCallbackQuery) and update.data is not None:
return cls._create(client, update, chat_map) return cls._create(client, update, chat_map)
else: else:
@ -83,11 +81,7 @@ class ButtonCallback(Event):
chat = self._chat_map.get(pid) or PeerRef._empty_from_peer(self._raw.peer) chat = self._chat_map.get(pid) or PeerRef._empty_from_peer(self._raw.peer)
lst = CherryPickedList(self._client, chat._ref, []) lst = CherryPickedList(self._client, chat._ref, [])
lst._ids.append( lst._ids.append(types.InputMessageCallbackQuery(id=self._raw.msg_id, query_id=self._raw.query_id))
types.InputMessageCallbackQuery(
id=self._raw.msg_id, query_id=self._raw.query_id
)
)
message = (await lst)[0] message = (await lst)[0]
@ -105,9 +99,7 @@ class InlineQuery(Event):
self._raw = update self._raw = update
@classmethod @classmethod
def _try_from_update( def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
if isinstance(update, types.UpdateBotInlineQuery): if isinstance(update, types.UpdateBotInlineQuery):
return cls._create(update) return cls._create(update)
else: else:

View File

@ -133,9 +133,7 @@ def parse(html: str) -> tuple[str, list[MessageEntity]]:
return del_surrogate(parser.text), parser.entities return del_surrogate(parser.text), parser.entities
ENTITY_TO_FORMATTER: dict[ ENTITY_TO_FORMATTER: dict[Type[MessageEntity], tuple[str, str] | Callable[[Any, str], tuple[str, str]]] = {
Type[MessageEntity], tuple[str, str] | Callable[[Any, str], tuple[str, str]]
] = {
MessageEntityBold: ("<strong>", "</strong>"), MessageEntityBold: ("<strong>", "</strong>"),
MessageEntityItalic: ("<em>", "</em>"), MessageEntityItalic: ("<em>", "</em>"),
MessageEntityCode: ("<code>", "</code>"), MessageEntityCode: ("<code>", "</code>"),
@ -196,12 +194,7 @@ def unparse(text: str, entities: Iterable[MessageEntity]) -> str:
while within_surrogate(text, at): while within_surrogate(text, at):
at += 1 at += 1
text = ( text = text[:at] + what + escape(text[at:next_escape_bound]) + text[next_escape_bound:]
text[:at]
+ what
+ escape(text[at:next_escape_bound])
+ text[next_escape_bound:]
)
next_escape_bound = at next_escape_bound = at
text = escape(text[:next_escape_bound]) + text[next_escape_bound:] text = escape(text[:next_escape_bound]) + text[next_escape_bound:]

View File

@ -82,9 +82,7 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]:
else: else:
for entity in reversed(entities): for entity in reversed(entities):
if isinstance(entity, ty): if isinstance(entity, ty):
setattr( setattr(entity, "length", len(message) - getattr(entity, "offset", 0))
entity, "length", len(message) - getattr(entity, "offset", 0)
)
break break
parsed = MARKDOWN.parse(add_surrogate(message.strip())) parsed = MARKDOWN.parse(add_surrogate(message.strip()))
@ -103,25 +101,15 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]:
if token.type in ("blockquote_close", "blockquote_open"): if token.type in ("blockquote_close", "blockquote_open"):
push(MessageEntityBlockquote) push(MessageEntityBlockquote)
elif token.type == "code_block": elif token.type == "code_block":
entities.append( entities.append(MessageEntityPre(offset=len(message), length=len(token.content), language=""))
MessageEntityPre(
offset=len(message), length=len(token.content), language=""
)
)
message += token.content message += token.content
elif token.type == "code_inline": elif token.type == "code_inline":
entities.append( entities.append(MessageEntityCode(offset=len(message), length=len(token.content)))
MessageEntityCode(offset=len(message), length=len(token.content))
)
message += token.content message += token.content
elif token.type in ("em_close", "em_open"): elif token.type in ("em_close", "em_open"):
push(MessageEntityItalic) push(MessageEntityItalic)
elif token.type == "fence": elif token.type == "fence":
entities.append( entities.append(MessageEntityPre(offset=len(message), length=len(token.content), language=token.info))
MessageEntityPre(
offset=len(message), length=len(token.content), language=token.info
)
)
message += token.content[:-1] # remove a single trailing newline message += token.content[:-1] # remove a single trailing newline
elif token.type == "hardbreak": elif token.type == "hardbreak":
message += "\n" message += "\n"
@ -130,9 +118,7 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]:
elif token.type == "hr": elif token.type == "hr":
message += "\u2015\n\n" message += "\u2015\n\n"
elif token.type in ("link_close", "link_open"): elif token.type in ("link_close", "link_open"):
if ( if token.markup != "autolink": # telegram already picks up on these automatically
token.markup != "autolink"
): # telegram already picks up on these automatically
push(MessageEntityTextUrl, url=token.attrs.get("href")) push(MessageEntityTextUrl, url=token.attrs.get("href"))
elif token.type in ("s_close", "s_open"): elif token.type in ("s_close", "s_open"):
push(MessageEntityStrike) push(MessageEntityStrike)

View File

@ -6,11 +6,7 @@ def add_surrogate(text: str) -> str:
return "".join( return "".join(
# SMP -> Surrogate Pairs (Telegram offsets are calculated with these). # SMP -> Surrogate Pairs (Telegram offsets are calculated with these).
# See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more. # See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more.
( ("".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le"))) if (0x10000 <= ord(x) <= 0x10FFFF) else x)
"".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
if (0x10000 <= ord(x) <= 0x10FFFF)
else x
)
for x in text for x in text
) )

View File

@ -52,20 +52,12 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
input_media: abcs.InputMedia input_media: abcs.InputMedia
if try_get_url_path(file) is not None: if try_get_url_path(file) is not None:
assert isinstance(file, str) assert isinstance(file, str)
input_media = types.InputMediaPhotoExternal( input_media = types.InputMediaPhotoExternal(spoiler=False, url=file, ttl_seconds=None)
spoiler=False, url=file, ttl_seconds=None
)
else: else:
input_file, _ = await self._client._upload(file, size, "a.jpg") input_file, _ = await self._client._upload(file, size, "a.jpg")
input_media = types.InputMediaUploadedPhoto( input_media = types.InputMediaUploadedPhoto(spoiler=False, file=input_file, stickers=None, ttl_seconds=None)
spoiler=False, file=input_file, stickers=None, ttl_seconds=None
)
media = await self._client( media = await self._client(functions.messages.upload_media(peer=types.InputPeerSelf(), media=input_media))
functions.messages.upload_media(
peer=types.InputPeerSelf(), media=input_media
)
)
assert isinstance(media, types.MessageMediaPhoto) assert isinstance(media, types.MessageMediaPhoto)
assert isinstance(media.photo, types.Photo) assert isinstance(media.photo, types.Photo)
input_media = types.InputMediaPhoto( input_media = types.InputMediaPhoto(
@ -77,9 +69,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
), ),
ttl_seconds=media.ttl_seconds, ttl_seconds=media.ttl_seconds,
) )
message, entities = parse_message( message, entities = parse_message(text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True)
text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True
)
self._medias.append( self._medias.append(
types.InputSingleMedia( types.InputSingleMedia(
media=input_media, media=input_media,
@ -132,9 +122,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
input_media: abcs.InputMedia input_media: abcs.InputMedia
if try_get_url_path(file) is not None: if try_get_url_path(file) is not None:
assert isinstance(file, str) assert isinstance(file, str)
input_media = types.InputMediaDocumentExternal( input_media = types.InputMediaDocumentExternal(spoiler=False, url=file, ttl_seconds=None)
spoiler=False, url=file, ttl_seconds=None
)
else: else:
input_file, name = await self._client._upload(file, size, name) input_file, name = await self._client._upload(file, size, name)
if mime_type is None: if mime_type is None:
@ -168,11 +156,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
ttl_seconds=None, ttl_seconds=None,
) )
media = await self._client( media = await self._client(functions.messages.upload_media(peer=types.InputPeerEmpty(), media=input_media))
functions.messages.upload_media(
peer=types.InputPeerEmpty(), media=input_media
)
)
assert isinstance(media, types.MessageMediaDocument) assert isinstance(media, types.MessageMediaDocument)
assert isinstance(media.document, types.Document) assert isinstance(media.document, types.Document)
input_media = types.InputMediaDocument( input_media = types.InputMediaDocument(
@ -185,9 +169,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
ttl_seconds=media.ttl_seconds, ttl_seconds=media.ttl_seconds,
query=None, query=None,
) )
message, entities = parse_message( message, entities = parse_message(text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True)
text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True
)
self._medias.append( self._medias.append(
types.InputSingleMedia( types.InputSingleMedia(
media=input_media, media=input_media,
@ -197,9 +179,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
) )
) )
async def send( async def send(self, peer: Peer | PeerRef, *, reply_to: Optional[int] = None) -> list[Message]:
self, peer: Peer | PeerRef, *, reply_to: Optional[int] = None
) -> list[Message]:
""" """
Send the album. Send the album.
@ -225,11 +205,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
update_stickersets_order=False, update_stickersets_order=False,
peer=peer._ref._to_input_peer(), peer=peer._ref._to_input_peer(),
reply_to=( reply_to=(
types.InputReplyToMessage( types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None
reply_to_msg_id=reply_to, top_msg_id=None
)
if reply_to
else None
), ),
multi_media=self._medias, multi_media=self._medias,
schedule_date=None, schedule_date=None,

View File

@ -50,9 +50,7 @@ class Button(abc.ABC):
def __init__(self, text: str) -> None: def __init__(self, text: str) -> None:
if self.__class__ == Button: if self.__class__ == Button:
raise TypeError( raise TypeError(f"Can't instantiate abstract class {self.__class__.__name__}")
f"Can't instantiate abstract class {self.__class__.__name__}"
)
self._raw: RawButtonType = types.KeyboardButton(text=text) self._raw: RawButtonType = types.KeyboardButton(text=text)
self._msg: Optional[weakref.ReferenceType[Message]] = None self._msg: Optional[weakref.ReferenceType[Message]] = None

View File

@ -29,8 +29,6 @@ class InlineButton(Button, abc.ABC):
def __init__(self, text: str) -> None: def __init__(self, text: str) -> None:
if self.__class__ == InlineButton: if self.__class__ == InlineButton:
raise TypeError( raise TypeError(f"Can't instantiate abstract class {self.__class__.__name__}")
f"Can't instantiate abstract class {self.__class__.__name__}"
)
else: else:
super().__init__(text) super().__init__(text)

View File

@ -14,9 +14,7 @@ class SwitchInline(InlineButton):
def __init__(self, text: str, query: Optional[str] = None) -> None: def __init__(self, text: str, query: Optional[str] = None) -> None:
super().__init__(text) super().__init__(text)
self._raw = types.KeyboardButtonSwitchInline( self._raw = types.KeyboardButtonSwitchInline(same_peer=False, text=text, query=query or "", peer_types=None)
same_peer=False, text=text, query=query or "", peer_types=None
)
@property @property
def query(self) -> str: def query(self) -> str:

View File

@ -111,9 +111,7 @@ class ChatRestriction(Enum):
return set(filter(None, iter(restrictions))) return set(filter(None, iter(restrictions)))
@classmethod @classmethod
def _set_to_raw( def _set_to_raw(cls, restrictions: set[ChatRestriction], until_date: int) -> types.ChatBannedRights:
cls, restrictions: set[ChatRestriction], until_date: int
) -> types.ChatBannedRights:
return types.ChatBannedRights( return types.ChatBannedRights(
view_messages=cls.VIEW_MESSAGES in restrictions, view_messages=cls.VIEW_MESSAGES in restrictions,
send_messages=cls.SEND_MESSAGES in restrictions, send_messages=cls.SEND_MESSAGES in restrictions,

View File

@ -90,9 +90,6 @@ class Dialog(metaclass=NoPublicConstructor):
if isinstance(self._raw, types.Dialog): if isinstance(self._raw, types.Dialog):
return self._raw.unread_count return self._raw.unread_count
elif isinstance(self._raw, types.DialogPeerFolder): elif isinstance(self._raw, types.DialogPeerFolder):
return ( return self._raw.unread_unmuted_messages_count + self._raw.unread_muted_messages_count
self._raw.unread_unmuted_messages_count
+ self._raw.unread_muted_messages_count
)
else: else:
raise RuntimeError("unexpected case") raise RuntimeError("unexpected case")

View File

@ -37,9 +37,7 @@ class Draft(metaclass=NoPublicConstructor):
self._chat_map = chat_map self._chat_map = chat_map
@classmethod @classmethod
def _from_raw_update( def _from_raw_update(cls, client: Client, draft: types.UpdateDraftMessage, chat_map: dict[int, Peer]) -> Self:
cls, client: Client, draft: types.UpdateDraftMessage, chat_map: dict[int, Peer]
) -> Self:
return cls._create(client, draft.peer, draft.top_msg_id, draft.draft, chat_map) return cls._create(client, draft.peer, draft.top_msg_id, draft.draft, chat_map)
@classmethod @classmethod
@ -60,9 +58,7 @@ class Draft(metaclass=NoPublicConstructor):
This is also the chat where the message will be sent to by :meth:`send`. This is also the chat where the message will be sent to by :meth:`send`.
""" """
return self._chat_map.get(peer_id(self._peer)) or expand_peer( return self._chat_map.get(peer_id(self._peer)) or expand_peer(self._client, self._peer, broadcast=None)
self._client, self._peer, broadcast=None
)
@property @property
def link_preview(self) -> bool: def link_preview(self) -> bool:
@ -91,9 +87,7 @@ class Draft(metaclass=NoPublicConstructor):
The :attr:`~Message.text_html` of the message that will be sent. The :attr:`~Message.text_html` of the message that will be sent.
""" """
if text := getattr(self._raw, "message", None): if text := getattr(self._raw, "message", None):
return generate_html_message( return generate_html_message(text, getattr(self._raw, "entities", None) or [])
text, getattr(self._raw, "entities", None) or []
)
else: else:
return None return None
@ -103,9 +97,7 @@ class Draft(metaclass=NoPublicConstructor):
The :attr:`~Message.text_markdown` of the message that will be sent. The :attr:`~Message.text_markdown` of the message that will be sent.
""" """
if text := getattr(self._raw, "message", None): if text := getattr(self._raw, "message", None):
return generate_markdown_message( return generate_markdown_message(text, getattr(self._raw, "entities", None) or [])
text, getattr(self._raw, "entities", None) or []
)
else: else:
return None return None
@ -115,11 +107,7 @@ class Draft(metaclass=NoPublicConstructor):
The date when the draft was last updated. The date when the draft was last updated.
""" """
date = getattr(self._raw, "date", None) date = getattr(self._raw, "date", None)
return ( return datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) if date is not None else None
datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc)
if date is not None
else None
)
async def edit( async def edit(
self, self,
@ -192,11 +180,7 @@ class Draft(metaclass=NoPublicConstructor):
noforwards=False, noforwards=False,
update_stickersets_order=False, update_stickersets_order=False,
peer=self._peer_ref()._to_input_peer(), peer=self._peer_ref()._to_input_peer(),
reply_to=( reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None),
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
if reply_to
else None
),
message=message, message=message,
random_id=random_id, random_id=random_id,
reply_markup=None, reply_markup=None,
@ -211,11 +195,7 @@ class Draft(metaclass=NoPublicConstructor):
{}, {},
out=result.out, out=result.out,
id=result.id, id=result.id,
from_id=( from_id=(types.PeerUser(user_id=self._client._session.user.id) if self._client._session.user else None),
types.PeerUser(user_id=self._client._session.user.id)
if self._client._session.user
else None
),
peer_id=self._peer_ref()._to_peer(), peer_id=self._peer_ref()._to_peer(),
reply_to=( reply_to=(
types.MessageReplyHeader( types.MessageReplyHeader(
@ -235,9 +215,7 @@ class Draft(metaclass=NoPublicConstructor):
ttl_period=result.ttl_period, ttl_period=result.ttl_period,
) )
else: else:
return self._client._build_message_map( return self._client._build_message_map(result, self._peer_ref()).with_random_id(random_id)
result, self._peer_ref()
).with_random_id(random_id)
async def delete(self) -> None: async def delete(self) -> None:
""" """

View File

@ -29,11 +29,7 @@ def photo_size_byte_count(size: abcs.PhotoSize) -> int:
elif isinstance(size, types.PhotoSizeProgressive): elif isinstance(size, types.PhotoSizeProgressive):
return max(size.sizes) return max(size.sizes)
elif isinstance(size, types.PhotoStrippedSize): elif isinstance(size, types.PhotoStrippedSize):
return ( return len(stripped_size_header) + (len(size.bytes) - 3) + len(stripped_size_footer)
len(stripped_size_header)
+ (len(size.bytes) - 3)
+ len(stripped_size_footer)
)
else: else:
raise RuntimeError("unexpected case") raise RuntimeError("unexpected case")
@ -180,9 +176,7 @@ class File(metaclass=NoPublicConstructor):
self._client = client self._client = client
@classmethod @classmethod
def _try_from_raw_message_media( def _try_from_raw_message_media(cls, client: Client, raw: abcs.MessageMedia) -> Optional[Self]:
cls, client: Client, raw: abcs.MessageMedia
) -> Optional[Self]:
if isinstance(raw, types.MessageMediaDocument): if isinstance(raw, types.MessageMediaDocument):
if raw.document: if raw.document:
return cls._try_from_raw_document( return cls._try_from_raw_document(
@ -204,13 +198,9 @@ class File(metaclass=NoPublicConstructor):
elif isinstance(raw, types.MessageMediaWebPage): elif isinstance(raw, types.MessageMediaWebPage):
if isinstance(raw.webpage, types.WebPage): if isinstance(raw.webpage, types.WebPage):
if raw.webpage.document: if raw.webpage.document:
return cls._try_from_raw_document( return cls._try_from_raw_document(client, raw.webpage.document, orig_raw=raw)
client, raw.webpage.document, orig_raw=raw
)
if raw.webpage.photo: if raw.webpage.photo:
return cls._try_from_raw_photo( return cls._try_from_raw_photo(client, raw.webpage.photo, orig_raw=raw)
client, raw.webpage.photo, orig_raw=raw
)
return None return None
@ -229,21 +219,13 @@ class File(metaclass=NoPublicConstructor):
attributes=raw.attributes, attributes=raw.attributes,
size=raw.size, size=raw.size,
name=next( name=next(
( (a.file_name for a in raw.attributes if isinstance(a, types.DocumentAttributeFilename)),
a.file_name
for a in raw.attributes
if isinstance(a, types.DocumentAttributeFilename)
),
"", "",
), ),
mime=raw.mime_type, mime=raw.mime_type,
photo=False, photo=False,
muted=next( muted=next(
( (a.nosound for a in raw.attributes if isinstance(a, types.DocumentAttributeVideo)),
a.nosound
for a in raw.attributes
if isinstance(a, types.DocumentAttributeVideo)
),
False, False,
), ),
input_media=types.InputMediaDocument( input_media=types.InputMediaDocument(
@ -361,9 +343,7 @@ class File(metaclass=NoPublicConstructor):
return dim.w return dim.w
for attr in self._attributes: for attr in self._attributes:
if isinstance( if isinstance(attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)):
attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)
):
return attr.w return attr.w
return None return None
@ -377,9 +357,7 @@ class File(metaclass=NoPublicConstructor):
return dim.h return dim.h
for attr in self._attributes: for attr in self._attributes:
if isinstance( if isinstance(attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)):
attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)
):
return attr.h return attr.h
return None return None
@ -410,9 +388,7 @@ class File(metaclass=NoPublicConstructor):
id=self._input_media.id.id, id=self._input_media.id.id,
access_hash=self._input_media.id.access_hash, access_hash=self._input_media.id.access_hash,
file_reference=self._input_media.id.file_reference, file_reference=self._input_media.id.file_reference,
thumb_size=( thumb_size=(self._thumb.type if isinstance(self._thumb, thumb_types) else ""),
self._thumb.type if isinstance(self._thumb, thumb_types) else ""
),
) )
elif isinstance(self._input_media, types.InputMediaPhoto): elif isinstance(self._input_media, types.InputMediaPhoto):
assert isinstance(self._input_media.id, types.InputPhoto) assert isinstance(self._input_media.id, types.InputPhoto)

View File

@ -12,16 +12,11 @@ def _build_keyboard_rows(
) -> list[abcs.KeyboardButtonRow]: ) -> list[abcs.KeyboardButtonRow]:
# list[button] -> list[list[button]] # list[button] -> list[list[button]]
# This does allow for "invalid" inputs (mixing lists and non-lists), but that's acceptable. # This does allow for "invalid" inputs (mixing lists and non-lists), but that's acceptable.
buttons_lists_iter = [ buttons_lists_iter = [button if isinstance(button, list) else [button] for button in (btns or [])]
button if isinstance(button, list) else [button] for button in (btns or [])
]
# Remove empty rows (also making it easy to check if all-empty). # Remove empty rows (also making it easy to check if all-empty).
buttons_lists = [bs for bs in buttons_lists_iter if bs] buttons_lists = [bs for bs in buttons_lists_iter if bs]
return [ return [types.KeyboardButtonRow(buttons=[btn._raw for btn in btns]) for btns in buttons_lists]
types.KeyboardButtonRow(buttons=[btn._raw for btn in btns])
for btns in buttons_lists
]
class Keyboard: class Keyboard:
@ -49,9 +44,7 @@ class Keyboard:
class InlineKeyboard: class InlineKeyboard:
__slots__ = ("_raw",) __slots__ = ("_raw",)
def __init__( def __init__(self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]]) -> None:
self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]]
) -> None:
self._raw = types.ReplyInlineMarkup(rows=_build_keyboard_rows(buttons)) self._raw = types.ReplyInlineMarkup(rows=_build_keyboard_rows(buttons))

View File

@ -34,11 +34,7 @@ def generate_random_id() -> int:
def adapt_date(date: Optional[int]) -> Optional[datetime.datetime]: def adapt_date(date: Optional[int]) -> Optional[datetime.datetime]:
return ( return datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) if date is not None else None
datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc)
if date is not None
else None
)
class Message(metaclass=NoPublicConstructor): class Message(metaclass=NoPublicConstructor):
@ -59,20 +55,14 @@ class Message(metaclass=NoPublicConstructor):
print('Found empty message with ID', message.id) print('Found empty message with ID', message.id)
""" """
def __init__( def __init__(self, client: Client, message: abcs.Message, chat_map: dict[int, Peer]) -> None:
self, client: Client, message: abcs.Message, chat_map: dict[int, Peer] assert isinstance(message, (types.Message, types.MessageService, types.MessageEmpty))
) -> None:
assert isinstance(
message, (types.Message, types.MessageService, types.MessageEmpty)
)
self._client = client self._client = client
self._raw = message self._raw = message
self._chat_map = chat_map self._chat_map = chat_map
@classmethod @classmethod
def _from_raw( def _from_raw(cls, client: Client, message: abcs.Message, chat_map: dict[int, Peer]) -> Self:
cls, client: Client, message: abcs.Message, chat_map: dict[int, Peer]
) -> Self:
return cls._create(client, message, chat_map) return cls._create(client, message, chat_map)
@classmethod @classmethod
@ -158,9 +148,7 @@ class Message(metaclass=NoPublicConstructor):
See :ref:`formatting` to learn the HTML elements used. See :ref:`formatting` to learn the HTML elements used.
""" """
if text := getattr(self._raw, "message", None): if text := getattr(self._raw, "message", None):
return generate_html_message( return generate_html_message(text, getattr(self._raw, "entities", None) or [])
text, getattr(self._raw, "entities", None) or []
)
else: else:
return None return None
@ -172,9 +160,7 @@ class Message(metaclass=NoPublicConstructor):
See :ref:`formatting` to learn the formatting characters used. See :ref:`formatting` to learn the formatting characters used.
""" """
if text := getattr(self._raw, "message", None): if text := getattr(self._raw, "message", None):
return generate_markdown_message( return generate_markdown_message(text, getattr(self._raw, "entities", None) or [])
text, getattr(self._raw, "entities", None) or []
)
else: else:
return None return None
@ -193,9 +179,7 @@ class Message(metaclass=NoPublicConstructor):
peer = self._raw.peer_id or types.PeerUser(user_id=0) peer = self._raw.peer_id or types.PeerUser(user_id=0)
pid = peer_id(peer) pid = peer_id(peer)
if pid not in self._chat_map: if pid not in self._chat_map:
self._chat_map[pid] = expand_peer( self._chat_map[pid] = expand_peer(self._client, peer, broadcast=getattr(self._raw, "post", None))
self._client, peer, broadcast=getattr(self._raw, "post", None)
)
return self._chat_map[pid] return self._chat_map[pid]
@property @property
@ -239,14 +223,7 @@ class Message(metaclass=NoPublicConstructor):
This can also be used as a way to check that the message media is an audio. This can also be used as a way to check that the message media is an audio.
""" """
audio = self._file() audio = self._file()
return ( return audio if audio and any(isinstance(a, types.DocumentAttributeAudio) for a in audio._attributes) else None
audio
if audio
and any(
isinstance(a, types.DocumentAttributeAudio) for a in audio._attributes
)
else None
)
@property @property
def video(self) -> Optional[File]: def video(self) -> Optional[File]:
@ -256,14 +233,7 @@ class Message(metaclass=NoPublicConstructor):
This can also be used as a way to check that the message media is a video. This can also be used as a way to check that the message media is a video.
""" """
audio = self._file() audio = self._file()
return ( return audio if audio and any(isinstance(a, types.DocumentAttributeVideo) for a in audio._attributes) else None
audio
if audio
and any(
isinstance(a, types.DocumentAttributeVideo) for a in audio._attributes
)
else None
)
@property @property
def file(self) -> Optional[File]: def file(self) -> Optional[File]:
@ -477,10 +447,7 @@ class Message(metaclass=NoPublicConstructor):
return None return None
return [ return [
[ [create_button(self, button) for button in cast(types.KeyboardButtonRow, row).buttons]
create_button(self, button)
for button in cast(types.KeyboardButtonRow, row).buttons
]
for row in markup.rows for row in markup.rows
] ]
@ -506,13 +473,8 @@ class Message(metaclass=NoPublicConstructor):
return not isinstance(self._raw, types.MessageEmpty) return not isinstance(self._raw, types.MessageEmpty)
def build_msg_map( def build_msg_map(client: Client, messages: Sequence[abcs.Message], chat_map: dict[int, Peer]) -> dict[int, Message]:
client: Client, messages: Sequence[abcs.Message], chat_map: dict[int, Peer] return {msg.id: msg for msg in (Message._from_raw(client, m, chat_map) for m in messages)}
) -> dict[int, Message]:
return {
msg.id: msg
for msg in (Message._from_raw(client, m, chat_map) for m in messages)
}
def parse_message( def parse_message(

View File

@ -16,23 +16,16 @@ class Final(abc.ABCMeta):
cls_namespace: dict[str, object], cls_namespace: dict[str, object],
) -> "Final": ) -> "Final":
# Allow subclassing while within telethon._impl (or other package names). # Allow subclassing while within telethon._impl (or other package names).
allowed_base = Final.__module__[ allowed_base = Final.__module__[: Final.__module__.find(".", Final.__module__.find(".") + 1)]
: Final.__module__.find(".", Final.__module__.find(".") + 1)
]
for base in bases: for base in bases:
if isinstance(base, Final) and not base.__module__.startswith(allowed_base): if isinstance(base, Final) and not base.__module__.startswith(allowed_base):
raise TypeError( raise TypeError(f"{base.__module__}.{base.__qualname__} does not support" " subclassing")
f"{base.__module__}.{base.__qualname__} does not support"
" subclassing"
)
return super().__new__(cls, name, bases, cls_namespace) return super().__new__(cls, name, bases, cls_namespace)
class NoPublicConstructor(Final): class NoPublicConstructor(Final):
def __call__(cls, *args: Any, **kwds: Any) -> Any: def __call__(cls, *args: Any, **kwds: Any) -> Any:
raise TypeError( raise TypeError(f"{cls.__module__}.{cls.__qualname__} has no public constructor")
f"{cls.__module__}.{cls.__qualname__} has no public constructor"
)
@property @property
def _create(cls: Type[T]) -> Type[T]: def _create(cls: Type[T]) -> Type[T]:

View File

@ -157,22 +157,16 @@ class Participant(metaclass=NoPublicConstructor):
""" """
:data:`True` if the participant is the creator of the chat. :data:`True` if the participant is the creator of the chat.
""" """
return isinstance( return isinstance(self._raw, (types.ChannelParticipantCreator, types.ChatParticipantCreator))
self._raw, (types.ChannelParticipantCreator, types.ChatParticipantCreator)
)
@property @property
def admin_rights(self) -> Optional[set[AdminRight]]: def admin_rights(self) -> Optional[set[AdminRight]]:
""" """
The set of administrator rights this participant has been granted, if they are an administrator. The set of administrator rights this participant has been granted, if they are an administrator.
""" """
if isinstance( if isinstance(self._raw, (types.ChannelParticipantCreator, types.ChannelParticipantAdmin)):
self._raw, (types.ChannelParticipantCreator, types.ChannelParticipantAdmin)
):
return AdminRight._from_raw(self._raw.admin_rights) return AdminRight._from_raw(self._raw.admin_rights)
elif isinstance( elif isinstance(self._raw, (types.ChatParticipantCreator, types.ChatParticipantAdmin)):
self._raw, (types.ChatParticipantCreator, types.ChatParticipantAdmin)
):
return AdminRight._chat_rights() return AdminRight._chat_rights()
else: else:
return None return None
@ -194,13 +188,9 @@ class Participant(metaclass=NoPublicConstructor):
participant = self.user or self.banned or self.left participant = self.user or self.banned or self.left
assert participant assert participant
if isinstance(participant, User): if isinstance(participant, User):
await self._client.set_participant_admin_rights( await self._client.set_participant_admin_rights(self._chat, participant, rights)
self._chat, participant, rights
)
else: else:
raise TypeError( raise TypeError(f"participant of type {participant.__class__.__name__} cannot be made admin")
f"participant of type {participant.__class__.__name__} cannot be made admin"
)
async def set_restrictions( async def set_restrictions(
self, self,
@ -213,6 +203,4 @@ class Participant(metaclass=NoPublicConstructor):
""" """
participant = self.user or self.banned or self.left participant = self.user or self.banned or self.left
assert participant assert participant
await self._client.set_participant_restrictions( await self._client.set_participant_restrictions(self._chat, participant, restrictions, until=until)
self._chat, participant, restrictions, until=until
)

View File

@ -15,9 +15,7 @@ if TYPE_CHECKING:
from ...client.client import Client from ...client.client import Client
def build_chat_map( def build_chat_map(client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]) -> dict[int, Peer]:
client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]
) -> dict[int, Peer]:
users_iter = (User._from_raw(u) for u in users) users_iter = (User._from_raw(u) for u in users)
chats_iter = ( chats_iter = (
( (
@ -45,9 +43,7 @@ def build_chat_map(
for x in v: for x in v:
print(x, file=sys.stderr) print(x, file=sys.stderr)
raise RuntimeError( raise RuntimeError(f"chat identifier collision: {k}; please report this")
f"chat identifier collision: {k}; please report this"
)
return result return result
@ -81,11 +77,7 @@ def expand_peer(client: Client, peer: abcs.Peer, *, broadcast: Optional[bool]) -
until_date=None, until_date=None,
) )
return ( return Channel._from_raw(channel) if broadcast else Group._from_raw(client, channel)
Channel._from_raw(channel)
if broadcast
else Group._from_raw(client, channel)
)
else: else:
raise RuntimeError("unexpected case") raise RuntimeError("unexpected case")

View File

@ -24,13 +24,7 @@ class Group(Peer, metaclass=NoPublicConstructor):
def __init__( def __init__(
self, self,
client: Client, client: Client,
chat: ( chat: (types.ChatEmpty | types.Chat | types.ChatForbidden | types.Channel | types.ChannelForbidden),
types.ChatEmpty
| types.Chat
| types.ChatForbidden
| types.Channel
| types.ChannelForbidden
),
) -> None: ) -> None:
self._client = client self._client = client
self._raw = chat self._raw = chat
@ -96,6 +90,4 @@ class Group(Peer, metaclass=NoPublicConstructor):
""" """
Alias for :meth:`telethon.Client.set_chat_default_restrictions`. Alias for :meth:`telethon.Client.set_chat_default_restrictions`.
""" """
await self._client.set_chat_default_restrictions( await self._client.set_chat_default_restrictions(self, restrictions, until=until)
self, restrictions, until=until
)

View File

@ -1,16 +1,10 @@
try: try:
import cryptg import cryptg # type: ignore [import-untyped]
def ige_encrypt( def ige_encrypt(plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: # noqa: F811
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes return cryptg.encrypt_ige(bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv)
) -> bytes: # noqa: F811
return cryptg.encrypt_ige(
bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv
)
def ige_decrypt( def ige_decrypt(ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: # noqa: F811
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes: # noqa: F811
return cryptg.decrypt_ige( return cryptg.decrypt_ige(
bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext, bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext,
key, key,
@ -18,11 +12,9 @@ try:
) )
except ImportError: except ImportError:
import pyaes import pyaes # type: ignore [import-untyped]
def ige_encrypt( def ige_encrypt(plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes:
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes:
assert len(plaintext) % 16 == 0 assert len(plaintext) % 16 == 0
assert len(iv) == 32 assert len(iv) == 32
@ -35,10 +27,7 @@ except ImportError:
for block_offset in range(0, len(plaintext), 16): for block_offset in range(0, len(plaintext), 16):
plaintext_block = plaintext[block_offset : block_offset + 16] plaintext_block = plaintext[block_offset : block_offset + 16]
ciphertext_block = bytes( ciphertext_block = bytes(
a ^ b a ^ b for a, b in zip(aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2)
for a, b in zip(
aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2
)
) )
iv1 = ciphertext_block iv1 = ciphertext_block
iv2 = plaintext_block iv2 = plaintext_block
@ -47,9 +36,7 @@ except ImportError:
return bytes(ciphertext) return bytes(ciphertext)
def ige_decrypt( def ige_decrypt(ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes:
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes:
assert len(ciphertext) % 16 == 0 assert len(ciphertext) % 16 == 0
assert len(iv) == 32 assert len(iv) == 32
@ -62,10 +49,7 @@ except ImportError:
for block_offset in range(0, len(ciphertext), 16): for block_offset in range(0, len(ciphertext), 16):
ciphertext_block = ciphertext[block_offset : block_offset + 16] ciphertext_block = ciphertext[block_offset : block_offset + 16]
plaintext_block = bytes( plaintext_block = bytes(
a ^ b a ^ b for a, b in zip(aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1)
for a, b in zip(
aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1
)
) )
iv1 = ciphertext_block iv1 = ciphertext_block
iv2 = plaintext_block iv2 = plaintext_block

View File

@ -20,8 +20,4 @@ class AuthKey:
return self.data return self.data
def calc_new_nonce_hash(self, new_nonce: int, number: int) -> int: def calc_new_nonce_hash(self, new_nonce: int, number: int) -> int:
return int.from_bytes( return int.from_bytes(sha1(new_nonce.to_bytes(32) + number.to_bytes(1) + self.aux_hash).digest()[4:])
sha1(new_nonce.to_bytes(32) + number.to_bytes(1) + self.aux_hash).digest()[
4:
]
)

View File

@ -19,9 +19,7 @@ class CalcKey(NamedTuple):
# https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector # https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
def calc_key( def calc_key(auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side) -> CalcKey:
auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side
) -> CalcKey:
x = int(side) x = int(side)
# sha256_a = SHA256 (msg_key + substr (auth_key, x, 36)) # sha256_a = SHA256 (msg_key + substr (auth_key, x, 36))
@ -43,12 +41,8 @@ def determine_padding_v2_length(length: int) -> int:
return 16 + (16 - (length % 16)) return 16 + (16 - (length % 16))
def _do_encrypt_data_v2( def _do_encrypt_data_v2(plaintext: bytes, auth_key: AuthKey, random_padding: bytes) -> bytes:
plaintext: bytes, auth_key: AuthKey, random_padding: bytes padded_plaintext = plaintext + random_padding[: determine_padding_v2_length(len(plaintext))]
) -> bytes:
padded_plaintext = (
plaintext + random_padding[: determine_padding_v2_length(len(plaintext))]
)
side = Side.CLIENT side = Side.CLIENT
x = int(side) x = int(side)
@ -70,9 +64,7 @@ def encrypt_data_v2(plaintext: bytes, auth_key: AuthKey) -> bytes:
return _do_encrypt_data_v2(plaintext, auth_key, random_padding) return _do_encrypt_data_v2(plaintext, auth_key, random_padding)
def decrypt_data_v2( def decrypt_data_v2(ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey) -> bytes:
ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey
) -> bytes:
side = Side.SERVER side = Side.SERVER
x = int(side) x = int(side)

View File

@ -34,17 +34,13 @@ def encrypt_hashed(data: bytes, key: PublicKey, random_bytes: bytes) -> bytes:
temp_key = random_bytes[192 + 32 * attempt : 192 + 32 * attempt + 32] temp_key = random_bytes[192 + 32 * attempt : 192 + 32 * attempt + 32]
# data_with_hash := data_pad_reversed + SHA256(temp_key + data_with_padding); -- after this assignment, data_with_hash is exactly 224 bytes long. # data_with_hash := data_pad_reversed + SHA256(temp_key + data_with_padding); -- after this assignment, data_with_hash is exactly 224 bytes long.
data_with_hash = ( data_with_hash = data_pad_reversed + sha256(temp_key + data_with_padding).digest()
data_pad_reversed + sha256(temp_key + data_with_padding).digest()
)
# aes_encrypted := AES256_IGE(data_with_hash, temp_key, 0); -- AES256-IGE encryption with zero IV. # aes_encrypted := AES256_IGE(data_with_hash, temp_key, 0); -- AES256-IGE encryption with zero IV.
aes_encrypted = ige_encrypt(data_with_hash, temp_key, bytes(32)) aes_encrypted = ige_encrypt(data_with_hash, temp_key, bytes(32))
# temp_key_xor := temp_key XOR SHA256(aes_encrypted); -- adjusted key, 32 bytes # temp_key_xor := temp_key XOR SHA256(aes_encrypted); -- adjusted key, 32 bytes
temp_key_xor = bytes( temp_key_xor = bytes(a ^ b for a, b in zip(temp_key, sha256(aes_encrypted).digest()))
a ^ b for a, b in zip(temp_key, sha256(aes_encrypted).digest())
)
# key_aes_encrypted := temp_key_xor + aes_encrypted; -- exactly 256 bytes (2048 bits) long # key_aes_encrypted := temp_key_xor + aes_encrypted; -- exactly 256 bytes (2048 bits) long
key_aes_encrypted = temp_key_xor + aes_encrypted key_aes_encrypted = temp_key_xor + aes_encrypted
@ -87,6 +83,4 @@ j4WcDuXc2CTHgH8gFTNhp/Y8/SpDOhvn9QIDAQAB
) )
RSA_KEYS = { RSA_KEYS = {compute_fingerprint(key): key for key in (PRODUCTION_RSA_KEY, TESTMODE_RSA_KEY)}
compute_fingerprint(key): key for key in (PRODUCTION_RSA_KEY, TESTMODE_RSA_KEY)
}

View File

@ -20,9 +20,7 @@ def h(*data: bytes | bytearray | memoryview) -> bytes:
# SH(data, salt) := H(salt | data | salt) # SH(data, salt) := H(salt | data | salt)
def sh( def sh(data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview) -> bytes:
data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview
) -> bytes:
return h(salt, data, salt) return h(salt, data, salt)

View File

@ -108,13 +108,9 @@ def _do_step2(data: Step1, response: bytes, random_bytes: bytes) -> tuple[bytes,
) )
try: try:
fingerprint = next( fingerprint = next(fp for fp in res_pq.server_public_key_fingerprints if fp in RSA_KEYS)
fp for fp in res_pq.server_public_key_fingerprints if fp in RSA_KEYS
)
except StopIteration: except StopIteration:
raise ValueError( raise ValueError(f"unknown fingerprints: {res_pq.server_public_key_fingerprints}")
f"unknown fingerprints: {res_pq.server_public_key_fingerprints}"
)
key = RSA_KEYS[fingerprint] key = RSA_KEYS[fingerprint]
ciphertext = encrypt_hashed(pq_inner_data, key, random_bytes) ciphertext = encrypt_hashed(pq_inner_data, key, random_bytes)
@ -133,9 +129,7 @@ def step2(data: Step1, response: bytes) -> tuple[bytes, Step2]:
return _do_step2(data, response, os.urandom(288)) return _do_step2(data, response, os.urandom(288))
def _do_step3( def _do_step3(data: Step2, response: bytes, random_bytes: bytes, now: int) -> tuple[bytes, Step3]:
data: Step2, response: bytes, random_bytes: bytes, now: int
) -> tuple[bytes, Step3]:
assert len(random_bytes) == 272 assert len(random_bytes) == 272
nonce = data.nonce nonce = data.nonce
@ -158,9 +152,7 @@ def _do_step3(
check_server_nonce(server_dh_params.server_nonce, server_nonce) check_server_nonce(server_dh_params.server_nonce, server_nonce)
if len(server_dh_params.encrypted_answer) % 16 != 0: if len(server_dh_params.encrypted_answer) % 16 != 0:
raise ValueError( raise ValueError(f"encrypted response not padded with size: {len(server_dh_params.encrypted_answer)}")
f"encrypted response not padded with size: {len(server_dh_params.encrypted_answer)}"
)
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce) key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
assert isinstance(server_dh_params.encrypted_answer, bytes) assert isinstance(server_dh_params.encrypted_answer, bytes)
@ -172,9 +164,7 @@ def _do_step3(
server_dh_inner = AbcServerDhInnerData._read_from(plain_text_reader) server_dh_inner = AbcServerDhInnerData._read_from(plain_text_reader)
assert isinstance(server_dh_inner, ServerDhInnerData) assert isinstance(server_dh_inner, ServerDhInnerData)
expected_answer_hash = sha1( expected_answer_hash = sha1(plain_text_answer[20 : 20 + plain_text_reader._pos]).digest()
plain_text_answer[20 : 20 + plain_text_reader._pos]
).digest()
if got_answer_hash != expected_answer_hash: if got_answer_hash != expected_answer_hash:
raise ValueError("invalid answer hash") raise ValueError("invalid answer hash")
@ -213,15 +203,11 @@ def _do_step3(
) )
client_dh_inner_hashed = sha1(client_dh_inner).digest() + client_dh_inner client_dh_inner_hashed = sha1(client_dh_inner).digest() + client_dh_inner
client_dh_inner_hashed += random_bytes[ client_dh_inner_hashed += random_bytes[: (16 - (len(client_dh_inner_hashed) % 16)) % 16]
: (16 - (len(client_dh_inner_hashed) % 16)) % 16
]
client_dh_encrypted = encrypt_ige(client_dh_inner_hashed, key, iv) client_dh_encrypted = encrypt_ige(client_dh_inner_hashed, key, iv)
return set_client_dh_params( return set_client_dh_params(nonce=nonce, server_nonce=server_nonce, encrypted_data=client_dh_encrypted), Step3(
nonce=nonce, server_nonce=server_nonce, encrypted_data=client_dh_encrypted
), Step3(
nonce=nonce, nonce=nonce,
server_nonce=server_nonce, server_nonce=server_nonce,
new_nonce=new_nonce, new_nonce=new_nonce,
@ -277,10 +263,7 @@ def create_key(data: Step3, response: bytes) -> CreatedKey:
first_salt = struct.unpack( first_salt = struct.unpack(
"<q", "<q",
bytes( bytes(a ^ b for a, b in zip(new_nonce.to_bytes(32)[:8], server_nonce.to_bytes(16)[:8])),
a ^ b
for a, b in zip(new_nonce.to_bytes(32)[:8], server_nonce.to_bytes(16)[:8])
),
)[0] )[0]
if dh_gen.nonce_number == 1: if dh_gen.nonce_number == 1:
@ -310,4 +293,4 @@ def check_new_nonce_hash(got: int, expected: int) -> None:
def check_g_in_range(value: int, low: int, high: int) -> None: def check_g_in_range(value: int, low: int, high: int) -> None:
if not (low < value < high): if not (low < value < high):
raise ValueError(f"g parameter {value} not in range({low+1}, {high})") raise ValueError(f"g parameter {value} not in range({low + 1}, {high})")

View File

@ -107,9 +107,7 @@ class Encrypted(Mtp):
) -> None: ) -> None:
self._auth_key = auth_key self._auth_key = auth_key
self._time_offset: int = time_offset or 0 self._time_offset: int = time_offset or 0
self._salts: list[FutureSalt] = [ self._salts: list[FutureSalt] = [FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)]
FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)
]
self._start_salt_time: Optional[tuple[int, float]] = None self._start_salt_time: Optional[tuple[int, float]] = None
self._compression_threshold = compression_threshold self._compression_threshold = compression_threshold
self._deserialization: list[Deserialization] = [] self._deserialization: list[Deserialization] = []
@ -203,9 +201,7 @@ class Encrypted(Mtp):
if self._msg_count == 1: if self._msg_count == 1:
del self._buffer[:CONTAINER_HEADER_LEN] del self._buffer[:CONTAINER_HEADER_LEN]
self._buffer[:HEADER_LEN] = struct.pack( self._buffer[:HEADER_LEN] = struct.pack("<qq", self._get_current_salt(), self._client_id)
"<qq", self._get_current_salt(), self._client_id
)
if self._msg_count != 1: if self._msg_count != 1:
self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack( self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
@ -283,11 +279,7 @@ class Encrypted(Mtp):
if isinstance(bad_msg, BadServerSalt): if isinstance(bad_msg, BadServerSalt):
self._salts.clear() self._salts.clear()
self._salts.append( self._salts.append(FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt))
FutureSalt(
valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt
)
)
self._salt_request_msg_id = None self._salt_request_msg_id = None
elif bad_msg.error_code in (16, 17): elif bad_msg.error_code in (16, 17):
self._correct_time_offset(message.msg_id) self._correct_time_offset(message.msg_id)
@ -324,9 +316,7 @@ class Encrypted(Mtp):
# Response to internal request, do not propagate. # Response to internal request, do not propagate.
self._salt_request_msg_id = None self._salt_request_msg_id = None
else: else:
self._deserialization.append( self._deserialization.append(RpcResult(MsgId(salts.req_msg_id), message.body))
RpcResult(MsgId(salts.req_msg_id), message.body)
)
self._start_salt_time = (salts.now, self._adjusted_now()) self._start_salt_time = (salts.now, self._adjusted_now())
self._salts = list(salts.salts) self._salts = list(salts.salts)
@ -346,11 +336,7 @@ class Encrypted(Mtp):
def _handle_new_session_created(self, message: Message) -> None: def _handle_new_session_created(self, message: Message) -> None:
new_session = NewSessionCreated.from_bytes(message.body) new_session = NewSessionCreated.from_bytes(message.body)
self._salts.clear() self._salts.clear()
self._salts.append( self._salts.append(FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=new_session.server_salt))
FutureSalt(
valid_since=0, valid_until=0x7FFFFFFF, salt=new_session.server_salt
)
)
def _handle_container(self, message: Message) -> None: def _handle_container(self, message: Message) -> None:
container = MsgContainer.from_bytes(message.body) container = MsgContainer.from_bytes(message.body)
@ -376,11 +362,7 @@ class Encrypted(Mtp):
self._deserialization.append(Update(message.body)) self._deserialization.append(Update(message.body))
def _try_request_salts(self) -> None: def _try_request_salts(self) -> None:
if ( if len(self._salts) == 1 and self._salt_request_msg_id is None and self._get_current_salt() != 0:
len(self._salts) == 1
and self._salt_request_msg_id is None
and self._get_current_salt() != 0
):
# If salts are requested in a container leading to bad_msg, # If salts are requested in a container leading to bad_msg,
# the bad_msg_id will refer to the container, not the salts request. # the bad_msg_id will refer to the container, not the salts request.
# #
@ -388,9 +370,7 @@ class Encrypted(Mtp):
# This would break, because we couldn't identify the response. # This would break, because we couldn't identify the response.
# #
# So salts are only requested once we have a valid salt to reduce the chances of this happening. # So salts are only requested once we have a valid salt to reduce the chances of this happening.
self._salt_request_msg_id = self._serialize_msg( self._salt_request_msg_id = self._serialize_msg(bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True)
bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True
)
def push(self, request: bytes) -> Optional[MsgId]: def push(self, request: bytes) -> Optional[MsgId]:
if self._start_salt_time and len(self._salts) >= 2: if self._start_salt_time and len(self._salts) >= 2:
@ -435,9 +415,7 @@ class Encrypted(Mtp):
return MsgId(self._last_msg_id), encrypt_data_v2(result, self._auth_key) return MsgId(self._last_msg_id), encrypt_data_v2(result, self._auth_key)
def deserialize( def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]:
self, payload: bytes | bytearray | memoryview
) -> list[Deserialization]:
check_message_buffer(payload) check_message_buffer(payload)
plaintext = decrypt_data_v2(payload, self._auth_key) plaintext = decrypt_data_v2(payload, self._auth_key)

View File

@ -31,9 +31,7 @@ class Plain(Mtp):
self._buffer.clear() self._buffer.clear()
return MsgId(0), result return MsgId(0), result
def deserialize( def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]:
self, payload: bytes | bytearray | memoryview
) -> list[Deserialization]:
check_message_buffer(payload) check_message_buffer(payload)
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload) auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
@ -48,8 +46,6 @@ class Plain(Mtp):
raise ValueError(f"bad length: expected >= 0, got: {length}") raise ValueError(f"bad length: expected >= 0, got: {length}")
if 20 + length > len(payload): if 20 + length > len(payload):
raise ValueError( raise ValueError(f"message too short, expected: {20 + length}, got {len(payload)}")
f"message too short, expected: {20 + length}, got {len(payload)}"
)
return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))] return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))]

View File

@ -116,11 +116,7 @@ class RpcError(ValueError):
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return ( return self._code == other._code and self._name == other._name and self._value == other._value
self._code == other._code
and self._name == other._name
and self._value == other._value
)
# https://core.telegram.org/mtproto/service_messages_about_messages # https://core.telegram.org/mtproto/service_messages_about_messages
@ -156,9 +152,7 @@ class BadMessage(ValueError):
self.msg_id = msg_id self.msg_id = msg_id
self._code = code self._code = code
self._caused_by = caused_by self._caused_by = caused_by
self.severity = ( self.severity = logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR
logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR
)
@property @property
def code(self) -> int: def code(self) -> int:
@ -201,9 +195,7 @@ class Mtp(ABC):
""" """
@abstractmethod @abstractmethod
def deserialize( def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]:
self, payload: bytes | bytearray | memoryview
) -> list[Deserialization]:
""" """
Deserialize incoming buffer payload. Deserialize incoming buffer payload.
""" """

View File

@ -42,10 +42,7 @@ class Intermediate(Transport):
raise MissingBytes(expected=length, got=len(input)) raise MissingBytes(expected=length, got=len(input))
if length <= 4: if length <= 4:
if ( if length >= 4 and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0:
length >= 4
and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0
):
raise BadStatus(status=-status) raise BadStatus(status=-status)
raise ValueError(f"bad length, expected > 0, got: {length}") raise ValueError(f"bad length, expected > 0, got: {length}")

View File

@ -12,9 +12,7 @@ MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
def check_message_buffer(message: bytes | bytearray | memoryview) -> None: def check_message_buffer(message: bytes | bytearray | memoryview) -> None:
if len(message) < 20: if len(message) < 20:
raise ValueError( raise ValueError(f"server payload is too small to be a valid message: {message.hex()}")
f"server payload is too small to be a valid message: {message.hex()}"
)
# https://core.telegram.org/mtproto/description#content-related-message # https://core.telegram.org/mtproto/description#content-related-message

View File

@ -313,13 +313,7 @@ class Sender:
def _on_ping_timeout(self) -> None: def _on_ping_timeout(self) -> None:
ping_id = generate_random_id() ping_id = generate_random_id()
self._enqueue_body( self._enqueue_body(bytes(ping_delay_disconnect(ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT)))
bytes(
ping_delay_disconnect(
ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT
)
)
)
self._next_ping = asyncio.get_running_loop().time() + PING_DELAY self._next_ping = asyncio.get_running_loop().time() + PING_DELAY
def _process_mtp_buffer(self, updates: list[Updates]) -> None: def _process_mtp_buffer(self, updates: list[Updates]) -> None:
@ -335,9 +329,7 @@ class Sender:
else: else:
self._process_bad_message(result) self._process_bad_message(result)
def _process_update( def _process_update(self, updates: list[Updates], update: bytes | bytearray | memoryview) -> None:
self, updates: list[Updates], update: bytes | bytearray | memoryview
) -> None:
try: try:
updates.append(Updates.from_bytes(update)) updates.append(Updates.from_bytes(update))
except ValueError: except ValueError:
@ -441,9 +433,7 @@ class Sender:
req.state.msg_id == msg_id or req.state.container_msg_id == msg_id req.state.msg_id == msg_id or req.state.container_msg_id == msg_id
): ):
raise RuntimeError("got response for unsent request") raise RuntimeError("got response for unsent request")
elif isinstance(req.state, Sent) and ( elif isinstance(req.state, Sent) and (req.state.msg_id == msg_id or req.state.container_msg_id == msg_id):
req.state.msg_id == msg_id or req.state.container_msg_id == msg_id
):
yield self._requests.pop(i) yield self._requests.pop(i)
@property @property

View File

@ -65,9 +65,7 @@ class ChatHashCache:
return self._has_peer(peer.peer) return self._has_peer(peer.peer)
elif isinstance(peer, types.NotifyForumTopic): elif isinstance(peer, types.NotifyForumTopic):
return self._has_peer(peer.peer) return self._has_peer(peer.peer)
elif isinstance( elif isinstance(peer, (types.NotifyUsers, types.NotifyChats, types.NotifyBroadcasts)):
peer, (types.NotifyUsers, types.NotifyChats, types.NotifyBroadcasts)
):
return True return True
else: else:
raise RuntimeError("unexpected case") raise RuntimeError("unexpected case")
@ -120,9 +118,7 @@ class ChatHashCache:
elif isinstance(participant, types.ChannelParticipantAdmin): elif isinstance(participant, types.ChannelParticipantAdmin):
return ( return (
self._has(participant.user_id) self._has(participant.user_id)
and ( and (participant.inviter_id is None or self._has(participant.inviter_id))
participant.inviter_id is None or self._has(participant.inviter_id)
)
and self._has(participant.promoted_by) and self._has(participant.promoted_by)
) )
elif isinstance(participant, types.ChannelParticipantBanned): elif isinstance(participant, types.ChannelParticipantBanned):

View File

@ -43,12 +43,8 @@ class PeerRef(abc.ABC):
__slots__ = ("identifier", "authorization") __slots__ = ("identifier", "authorization")
def __init__( def __init__(self, identifier: PeerIdentifier, authorization: PeerAuth = None) -> None:
self, identifier: PeerIdentifier, authorization: PeerAuth = None assert identifier >= 0, "PeerRef identifiers must be positive; see the documentation for Peers"
) -> None:
assert (
identifier >= 0
), "PeerRef identifiers must be positive; see the documentation for Peers"
self.identifier = identifier self.identifier = identifier
self.authorization = authorization self.authorization = authorization
@ -83,9 +79,7 @@ class PeerRef(abc.ABC):
authorization: Optional[int] = None authorization: Optional[int] = None
else: else:
try: try:
(authorization,) = struct.unpack( (authorization,) = struct.unpack("!q", base64.urlsafe_b64decode(auth.encode("ascii") + b"="))
"!q", base64.urlsafe_b64decode(auth.encode("ascii") + b"=")
)
except Exception: except Exception:
raise ValueError(f"invalid PeerRef string: {string!r}") raise ValueError(f"invalid PeerRef string: {string!r}")
@ -137,21 +131,14 @@ class PeerRef(abc.ABC):
if self.authorization is None: if self.authorization is None:
auth = "0" auth = "0"
else: else:
auth = ( auth = base64.urlsafe_b64encode(struct.pack("!q", self.authorization)).decode("ascii").rstrip("=")
base64.urlsafe_b64encode(struct.pack("!q", self.authorization))
.decode("ascii")
.rstrip("=")
)
return f"{self.identifier}.{auth}" return f"{self.identifier}.{auth}"
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return ( return self.identifier == other.identifier and self.authorization == other.authorization
self.identifier == other.identifier
and self.authorization == other.authorization
)
@property @property
def _ref(self) -> UserRef | GroupRef | ChannelRef: def _ref(self) -> UserRef | GroupRef | ChannelRef:
@ -182,16 +169,12 @@ class UserRef(PeerRef):
def _to_input_peer(self) -> abcs.InputPeer: def _to_input_peer(self) -> abcs.InputPeer:
if self.identifier == SELF_USER_SENTINEL_ID: if self.identifier == SELF_USER_SENTINEL_ID:
return types.InputPeerSelf() return types.InputPeerSelf()
return types.InputPeerUser( return types.InputPeerUser(user_id=self.identifier, access_hash=self.authorization or 0)
user_id=self.identifier, access_hash=self.authorization or 0
)
def _to_input_user(self) -> abcs.InputUser: def _to_input_user(self) -> abcs.InputUser:
if self.identifier == SELF_USER_SENTINEL_ID: if self.identifier == SELF_USER_SENTINEL_ID:
return types.InputUserSelf() return types.InputUserSelf()
return types.InputUser( return types.InputUser(user_id=self.identifier, access_hash=self.authorization or 0)
user_id=self.identifier, access_hash=self.authorization or 0
)
def __str__(self) -> str: def __str__(self) -> str:
return f"{USER_PREFIX}{self._encode_str()}" return f"{USER_PREFIX}{self._encode_str()}"
@ -252,14 +235,10 @@ class ChannelRef(PeerRef):
return types.PeerChannel(channel_id=self.identifier) return types.PeerChannel(channel_id=self.identifier)
def _to_input_peer(self) -> abcs.InputPeer: def _to_input_peer(self) -> abcs.InputPeer:
return types.InputPeerChannel( return types.InputPeerChannel(channel_id=self.identifier, access_hash=self.authorization or 0)
channel_id=self.identifier, access_hash=self.authorization or 0
)
def _to_input_channel(self) -> types.InputChannel: def _to_input_channel(self) -> types.InputChannel:
return types.InputChannel( return types.InputChannel(channel_id=self.identifier, access_hash=self.authorization or 0)
channel_id=self.identifier, access_hash=self.authorization or 0
)
def __str__(self) -> str: def __str__(self) -> str:
return f"{CHANNEL_PREFIX}{self._encode_str()}" return f"{CHANNEL_PREFIX}{self._encode_str()}"

View File

@ -27,9 +27,7 @@ def update_short(short: types.UpdateShort) -> types.UpdatesCombined:
) )
def update_short_message( def update_short_message(short: types.UpdateShortMessage, self_id: int) -> types.UpdatesCombined:
short: types.UpdateShortMessage, self_id: int
) -> types.UpdatesCombined:
return update_short( return update_short(
types.UpdateShort( types.UpdateShort(
update=types.UpdateNewMessage( update=types.UpdateNewMessage(
@ -46,9 +44,7 @@ def update_short_message(
noforwards=False, noforwards=False,
reactions=None, reactions=None,
id=short.id, id=short.id,
from_id=types.PeerUser( from_id=types.PeerUser(user_id=self_id if short.out else short.user_id),
user_id=self_id if short.out else short.user_id
),
peer_id=types.PeerChat( peer_id=types.PeerChat(
chat_id=short.user_id, chat_id=short.user_id,
), ),

View File

@ -38,9 +38,7 @@ class PtsInfo:
self.entry = entry self.entry = entry
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})"
f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})"
)
class State: class State:
@ -70,9 +68,7 @@ class PossibleGap:
self.updates = updates self.updates = updates
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})"
f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})"
)
class PrematureEndReason(Enum): class PrematureEndReason(Enum):

View File

@ -91,13 +91,9 @@ class MessageBox:
self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline) self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline)
if state.qts != NO_SEQ: if state.qts != NO_SEQ:
self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline) self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline)
self.map.update( self.map.update((s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels)
(s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels
)
self.date = datetime.datetime.fromtimestamp( self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc)
state.date, tz=datetime.timezone.utc
)
self.seq = state.seq self.seq = state.seq
self.possible_gaps.clear() self.possible_gaps.clear()
self.getting_diff_for.clear() self.getting_diff_for.clear()
@ -136,28 +132,18 @@ class MessageBox:
default_deadline = next_updates_deadline() default_deadline = next_updates_deadline()
if self.possible_gaps: if self.possible_gaps:
deadline = min( deadline = min(default_deadline, *(gap.deadline for gap in self.possible_gaps.values()))
default_deadline, *(gap.deadline for gap in self.possible_gaps.values())
)
elif self.next_deadline in self.map: elif self.next_deadline in self.map:
deadline = min(default_deadline, self.map[self.next_deadline].deadline) deadline = min(default_deadline, self.map[self.next_deadline].deadline)
else: else:
deadline = default_deadline deadline = default_deadline
if now >= deadline: if now >= deadline:
self.getting_diff_for.update( self.getting_diff_for.update(entry for entry, gap in self.possible_gaps.items() if now >= gap.deadline)
entry self.getting_diff_for.update(entry for entry, state in self.map.items() if now >= state.deadline)
for entry, gap in self.possible_gaps.items()
if now >= gap.deadline
)
self.getting_diff_for.update(
entry for entry, state in self.map.items() if now >= state.deadline
)
if __debug__: if __debug__:
self._trace( self._trace("deadlines met, now getting diff for: %r", self.getting_diff_for)
"deadlines met, now getting diff for: %r", self.getting_diff_for
)
for entry in self.getting_diff_for: for entry in self.getting_diff_for:
self.possible_gaps.pop(entry, None) self.possible_gaps.pop(entry, None)
@ -171,19 +157,12 @@ class MessageBox:
entry: Entry = ENTRY_ACCOUNT # for pyright to know it's not unbound entry: Entry = ENTRY_ACCOUNT # for pyright to know it's not unbound
for entry in entries: for entry in entries:
if entry not in self.map: if entry not in self.map:
raise RuntimeError( raise RuntimeError("Called reset_deadline on an entry for which we do not have state")
"Called reset_deadline on an entry for which we do not have state"
)
self.map[entry].deadline = deadline self.map[entry].deadline = deadline
if self.next_deadline in entries: if self.next_deadline in entries:
self.next_deadline = min( self.next_deadline = min(self.map.items(), key=lambda entry_state: entry_state[1].deadline)[0]
self.map.items(), key=lambda entry_state: entry_state[1].deadline elif self.next_deadline in self.map and deadline < self.map[self.next_deadline].deadline:
)[0]
elif (
self.next_deadline in self.map
and deadline < self.map[self.next_deadline].deadline
):
self.next_deadline = entry self.next_deadline = entry
def reset_channel_deadline(self, channel_id: int, timeout: Optional[float]) -> None: def reset_channel_deadline(self, channel_id: int, timeout: Optional[float]) -> None:
@ -200,9 +179,7 @@ class MessageBox:
assert isinstance(state, types.updates.State) assert isinstance(state, types.updates.State)
self.map[ENTRY_ACCOUNT] = State(state.pts, deadline) self.map[ENTRY_ACCOUNT] = State(state.pts, deadline)
self.map[ENTRY_SECRET] = State(state.qts, deadline) self.map[ENTRY_SECRET] = State(state.qts, deadline)
self.date = datetime.datetime.fromtimestamp( self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc)
state.date, tz=datetime.timezone.utc
)
self.seq = state.seq self.seq = state.seq
def try_set_channel_state(self, id: int, pts: int) -> None: def try_set_channel_state(self, id: int, pts: int) -> None:
@ -215,15 +192,11 @@ class MessageBox:
def try_begin_get_diff(self, entry: Entry, reason: str) -> None: def try_begin_get_diff(self, entry: Entry, reason: str) -> None:
if entry not in self.map: if entry not in self.map:
if entry in self.possible_gaps: if entry in self.possible_gaps:
raise RuntimeError( raise RuntimeError("Should not have a possible_gap for an entry not in the state map")
"Should not have a possible_gap for an entry not in the state map"
)
return return
if __debug__: if __debug__:
self._trace( self._trace("marking entry=%r as needing difference because: %s", entry, reason)
"marking entry=%r as needing difference because: %s", entry, reason
)
self.getting_diff_for.add(entry) self.getting_diff_for.add(entry)
self.possible_gaps.pop(entry, None) self.possible_gaps.pop(entry, None)
@ -231,14 +204,10 @@ class MessageBox:
try: try:
self.getting_diff_for.remove(entry) self.getting_diff_for.remove(entry)
except KeyError: except KeyError:
raise RuntimeError( raise RuntimeError("Called end_get_diff on an entry which was not getting diff for")
"Called end_get_diff on an entry which was not getting diff for"
)
self.reset_deadlines({entry}, next_updates_deadline()) self.reset_deadlines({entry}, next_updates_deadline())
assert ( assert entry not in self.possible_gaps, "gaps shouldn't be created while getting difference"
entry not in self.possible_gaps
), "gaps shouldn't be created while getting difference"
def ensure_known_peer_hashes( def ensure_known_peer_hashes(
self, self,
@ -246,10 +215,7 @@ class MessageBox:
chat_hashes: ChatHashCache, chat_hashes: ChatHashCache,
) -> None: ) -> None:
if not chat_hashes.extend_from_updates(updates): if not chat_hashes.extend_from_updates(updates):
can_recover = ( can_recover = not isinstance(updates, types.UpdateShort) or pts_info_from_update(updates.update) is not None
not isinstance(updates, types.UpdateShort)
or pts_info_from_update(updates.update) is not None
)
if can_recover: if can_recover:
self.try_begin_get_diff(ENTRY_ACCOUNT, "missing hash") self.try_begin_get_diff(ENTRY_ACCOUNT, "missing hash")
raise Gap raise Gap
@ -275,9 +241,7 @@ class MessageBox:
if combined.seq_start != NO_SEQ: if combined.seq_start != NO_SEQ:
if self.seq + 1 > combined.seq_start: if self.seq + 1 > combined.seq_start:
if __debug__: if __debug__:
self._trace( self._trace("skipping updates as they should have already been handled")
"skipping updates as they should have already been handled"
)
return result, combined.users, combined.chats return result, combined.users, combined.chats
elif self.seq + 1 < combined.seq_start: elif self.seq + 1 < combined.seq_start:
self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap") self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap")
@ -305,17 +269,13 @@ class MessageBox:
if __debug__: if __debug__:
self._trace("updating seq as local pts was updated too") self._trace("updating seq as local pts was updated too")
if combined.date != NO_DATE: if combined.date != NO_DATE:
self.date = datetime.datetime.fromtimestamp( self.date = datetime.datetime.fromtimestamp(combined.date, tz=datetime.timezone.utc)
combined.date, tz=datetime.timezone.utc
)
if combined.seq != NO_SEQ: if combined.seq != NO_SEQ:
self.seq = combined.seq self.seq = combined.seq
if self.possible_gaps: if self.possible_gaps:
if __debug__: if __debug__:
self._trace( self._trace("trying to re-apply count=%r possible gaps", len(self.possible_gaps))
"trying to re-apply count=%r possible gaps", len(self.possible_gaps)
)
for key in list(self.possible_gaps.keys()): for key in list(self.possible_gaps.keys()):
self.possible_gaps[key].updates.sort(key=update_sort_key) self.possible_gaps[key].updates.sort(key=update_sort_key)
@ -332,9 +292,7 @@ class MessageBox:
applied, applied,
) )
self.possible_gaps = { self.possible_gaps = {entry: gap for entry, gap in self.possible_gaps.items() if gap.updates}
entry: gap for entry, gap in self.possible_gaps.items() if gap.updates
}
return result, combined.users, combined.chats return result, combined.users, combined.chats
@ -384,8 +342,7 @@ class MessageBox:
) )
if pts.entry not in self.possible_gaps: if pts.entry not in self.possible_gaps:
self.possible_gaps[pts.entry] = PossibleGap( self.possible_gaps[pts.entry] = PossibleGap(
deadline=asyncio.get_running_loop().time() deadline=asyncio.get_running_loop().time() + POSSIBLE_GAP_TIMEOUT,
+ POSSIBLE_GAP_TIMEOUT,
updates=[], updates=[],
) )
@ -413,20 +370,14 @@ class MessageBox:
for entry in (ENTRY_ACCOUNT, ENTRY_SECRET): for entry in (ENTRY_ACCOUNT, ENTRY_SECRET):
if entry in self.getting_diff_for: if entry in self.getting_diff_for:
if entry not in self.map: if entry not in self.map:
raise RuntimeError( raise RuntimeError("Should not try to get difference for an entry without known state")
"Should not try to get difference for an entry without known state"
)
gd = functions.updates.get_difference( gd = functions.updates.get_difference(
pts=self.map[ENTRY_ACCOUNT].pts, pts=self.map[ENTRY_ACCOUNT].pts,
pts_limit=None, pts_limit=None,
pts_total_limit=None, pts_total_limit=None,
date=int(self.date.timestamp()), date=int(self.date.timestamp()),
qts=( qts=(self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else NO_SEQ),
self.map[ENTRY_SECRET].pts
if ENTRY_SECRET in self.map
else NO_SEQ
),
qts_limit=None, qts_limit=None,
) )
if __debug__: if __debug__:
@ -447,9 +398,7 @@ class MessageBox:
result: tuple[list[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]] result: tuple[list[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]
if isinstance(diff, types.updates.DifferenceEmpty): if isinstance(diff, types.updates.DifferenceEmpty):
finish = True finish = True
self.date = datetime.datetime.fromtimestamp( self.date = datetime.datetime.fromtimestamp(diff.date, tz=datetime.timezone.utc)
diff.date, tz=datetime.timezone.utc
)
self.seq = diff.seq self.seq = diff.seq
result = [], [], [] result = [], [], []
elif isinstance(diff, types.updates.Difference): elif isinstance(diff, types.updates.Difference):
@ -502,9 +451,7 @@ class MessageBox:
assert isinstance(state, types.updates.State) assert isinstance(state, types.updates.State)
self.map[ENTRY_ACCOUNT].pts = state.pts self.map[ENTRY_ACCOUNT].pts = state.pts
self.map[ENTRY_SECRET].pts = state.qts self.map[ENTRY_SECRET].pts = state.qts
self.date = datetime.datetime.fromtimestamp( self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc)
state.date, tz=datetime.timezone.utc
)
self.seq = state.seq self.seq = state.seq
updates, users, chats = self.process_updates( updates, users, chats = self.process_updates(
@ -560,19 +507,13 @@ class MessageBox:
channel=channel, channel=channel,
filter=types.ChannelMessagesFilterEmpty(), filter=types.ChannelMessagesFilterEmpty(),
pts=state.pts, pts=state.pts,
limit=( limit=(BOT_CHANNEL_DIFF_LIMIT if chat_hashes.is_self_bot else USER_CHANNEL_DIFF_LIMIT),
BOT_CHANNEL_DIFF_LIMIT
if chat_hashes.is_self_bot
else USER_CHANNEL_DIFF_LIMIT
),
) )
if __debug__: if __debug__:
self._trace("requesting channel difference: %s", gd) self._trace("requesting channel difference: %s", gd)
return gd return gd
else: else:
raise RuntimeError( raise RuntimeError("should not try to get difference for an entry without known state")
"should not try to get difference for an entry without known state"
)
else: else:
self.end_get_diff(entry) self.end_get_diff(entry)
self.map.pop(entry, None) self.map.pop(entry, None)
@ -638,9 +579,7 @@ class MessageBox:
else: else:
raise RuntimeError("unexpected case") raise RuntimeError("unexpected case")
def end_channel_difference( def end_channel_difference(self, channel_id: int, reason: PrematureEndReason) -> None:
self, channel_id: int, reason: PrematureEndReason
) -> None:
entry: Entry = channel_id entry: Entry = channel_id
if __debug__: if __debug__:
self._trace("ending channel=%r difference: %s", entry, reason) self._trace("ending channel=%r difference: %s", entry, reason)

View File

@ -38,9 +38,7 @@ class SqliteSession(Storage):
if version == 7: if version == 7:
session = self._load_v7(c) session = self._load_v7(c)
else: else:
raise ValueError( raise ValueError("only migration from sqlite session format 7 supported")
"only migration from sqlite session format 7 supported"
)
self._reset(c) self._reset(c)
self._get_or_init_version(c) self._get_or_init_version(c)
@ -105,11 +103,7 @@ class SqliteSession(Storage):
DataCenter(id=id, ipv4_addr=ipv4_addr, ipv6_addr=ipv6_addr, auth=auth) DataCenter(id=id, ipv4_addr=ipv4_addr, ipv6_addr=ipv6_addr, auth=auth)
for (id, ipv4_addr, ipv6_addr, auth) in datacenter for (id, ipv4_addr, ipv6_addr, auth) in datacenter
], ],
user=( user=(User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3]) if user else None),
User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3])
if user
else None
),
state=( state=(
UpdateState( UpdateState(
pts=state[0], pts=state[0],
@ -166,9 +160,7 @@ class SqliteSession(Storage):
@staticmethod @staticmethod
def _get_or_init_version(c: sqlite3.Cursor) -> int: def _get_or_init_version(c: sqlite3.Cursor) -> int:
c.execute( c.execute("select name from sqlite_master where type='table' and name='version'")
"select name from sqlite_master where type='table' and name='version'"
)
if c.fetchone(): if c.fetchone():
c.execute("select version from version") c.execute("select version from version")
tup = c.fetchone() tup = c.fetchone()

View File

@ -23,9 +23,7 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
from ..mtproto.layer import TYPE_MAPPING as MTPROTO_TYPES from ..mtproto.layer import TYPE_MAPPING as MTPROTO_TYPES
if API_TYPES.keys() & MTPROTO_TYPES.keys(): if API_TYPES.keys() & MTPROTO_TYPES.keys():
raise RuntimeError( raise RuntimeError("generated api and mtproto schemas cannot have colliding constructor identifiers")
"generated api and mtproto schemas cannot have colliding constructor identifiers"
)
ALL_TYPES = API_TYPES | MTPROTO_TYPES ALL_TYPES = API_TYPES | MTPROTO_TYPES
# Signatures don't fully match, but this is a private method # Signatures don't fully match, but this is a private method
@ -39,9 +37,7 @@ class Reader:
__slots__ = ("_view", "_pos", "_len") __slots__ = ("_view", "_pos", "_len")
def __init__(self, buffer: "Buffer") -> None: def __init__(self, buffer: "Buffer") -> None:
self._view = ( self._view = memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
)
self._pos = 0 self._pos = 0
self._len = len(self._view) self._len = len(self._view)

View File

@ -14,9 +14,7 @@ def _bootstrap_get_deserializer(
from ..mtproto.layer import RESPONSE_MAPPING as MTPROTO_DESER from ..mtproto.layer import RESPONSE_MAPPING as MTPROTO_DESER
if API_DESER.keys() & MTPROTO_DESER.keys(): if API_DESER.keys() & MTPROTO_DESER.keys():
raise RuntimeError( raise RuntimeError("generated api and mtproto schemas cannot have colliding constructor identifiers")
"generated api and mtproto schemas cannot have colliding constructor identifiers"
)
ALL_DESER = API_DESER | MTPROTO_DESER ALL_DESER = API_DESER | MTPROTO_DESER
Request._get_deserializer = ALL_DESER.get # type: ignore [assignment] Request._get_deserializer = ALL_DESER.get # type: ignore [assignment]

View File

@ -49,9 +49,7 @@ class Serializable(abc.ABC):
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return all( return all(getattr(self, attr) == getattr(other, attr) for attr in self.__slots__)
getattr(self, attr) == getattr(other, attr) for attr in self.__slots__
)
def serialize_bytes_to(buffer: bytearray, data: bytes | bytearray | memoryview) -> None: def serialize_bytes_to(buffer: bytearray, data: bytes | bytearray | memoryview) -> None:

View File

@ -19,6 +19,7 @@ and those you can define when using :meth:`telethon.Client.send_message`:
buttons.Callback('Demo', b'data') buttons.Callback('Demo', b'data')
]) ])
""" """
from .._impl.client.types.buttons import ( from .._impl.client.types.buttons import (
Callback, Callback,
RequestGeoLocation, RequestGeoLocation,

View File

@ -26,25 +26,16 @@ def test_auth_key_id() -> None:
def test_calc_new_nonce_hash1() -> None: def test_calc_new_nonce_hash1() -> None:
auth_key = get_auth_key() auth_key = get_auth_key()
new_nonce = get_new_nonce() new_nonce = get_new_nonce()
assert ( assert auth_key.calc_new_nonce_hash(new_nonce, 1) == 258944117842285651226187582903746985063
auth_key.calc_new_nonce_hash(new_nonce, 1)
== 258944117842285651226187582903746985063
)
def test_calc_new_nonce_hash2() -> None: def test_calc_new_nonce_hash2() -> None:
auth_key = get_auth_key() auth_key = get_auth_key()
new_nonce = get_new_nonce() new_nonce = get_new_nonce()
assert ( assert auth_key.calc_new_nonce_hash(new_nonce, 2) == 324588944215647649895949797213421233055
auth_key.calc_new_nonce_hash(new_nonce, 2)
== 324588944215647649895949797213421233055
)
def test_calc_new_nonce_hash3() -> None: def test_calc_new_nonce_hash3() -> None:
auth_key = get_auth_key() auth_key = get_auth_key()
new_nonce = get_new_nonce() new_nonce = get_new_nonce()
assert ( assert auth_key.calc_new_nonce_hash(new_nonce, 3) == 100989356540453064705070297823778556733
auth_key.calc_new_nonce_hash(new_nonce, 3)
== 100989356540453064705070297823778556733
)

View File

@ -66,14 +66,8 @@ def test_key_from_nonce() -> None:
new_nonce = int.from_bytes(bytes(range(32))) new_nonce = int.from_bytes(bytes(range(32)))
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce) key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
assert ( assert key == b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6'
key assert iv == b"Z\x84\x10\x8e\x98\x05el\xe8d\x07\x0e\x16nb\x18\xf6x>\x85\x11G\x1aZ\xb7\x80,\xf2\x00\x01\x02\x03"
== b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6'
)
assert (
iv
== b"Z\x84\x10\x8e\x98\x05el\xe8d\x07\x0e\x16nb\x18\xf6x>\x85\x11G\x1aZ\xb7\x80,\xf2\x00\x01\x02\x03"
)
def test_verify_ige_encryption() -> None: def test_verify_ige_encryption() -> None:
@ -88,5 +82,7 @@ def test_verify_ige_decryption() -> None:
ciphertext = get_test_aes_key_or_iv() ciphertext = get_test_aes_key_or_iv()
key = get_test_aes_key_or_iv() key = get_test_aes_key_or_iv()
iv = get_test_aes_key_or_iv() iv = get_test_aes_key_or_iv()
expected = b"\xe5wz\xfa\xcd{,\x16\xf7\xac@\xca\xe6\x1e\xf6\x03\xfe\xe6\t\x8f\xb8\xa8\x86\n\xb9\xeeg,\xd7\xe5\xba\xcc" expected = (
b"\xe5wz\xfa\xcd{,\x16\xf7\xac@\xca\xe6\x1e\xf6\x03\xfe\xe6\t\x8f\xb8\xa8\x86\n\xb9\xeeg,\xd7\xe5\xba\xcc"
)
assert decrypt_ige(ciphertext, key, iv) == expected assert decrypt_ige(ciphertext, key, iv) == expected

View File

@ -32,10 +32,7 @@ def test_parse_all_entities_markdown() -> None:
markdown = "Some **bold** (__strong__), *italics* (_cursive_), inline `code`, a\n```rust\npre\n```\nblock, a [link](https://example.com), and [mentions](tg://user?id=12345678)" markdown = "Some **bold** (__strong__), *italics* (_cursive_), inline `code`, a\n```rust\npre\n```\nblock, a [link](https://example.com), and [mentions](tg://user?id=12345678)"
text, entities = parse_markdown_message(markdown) text, entities = parse_markdown_message(markdown)
assert ( assert text == "Some bold (strong), italics (cursive), inline code, a\npre\nblock, a link, and mentions"
text
== "Some bold (strong), italics (cursive), inline code, a\npre\nblock, a link, and mentions"
)
assert entities == [ assert entities == [
types.MessageEntityBold(offset=5, length=4), types.MessageEntityBold(offset=5, length=4),
types.MessageEntityBold(offset=11, length=6), types.MessageEntityBold(offset=11, length=6),
@ -92,10 +89,7 @@ def test_parse_emoji_html() -> None:
def test_parse_all_entities_html() -> None: def test_parse_all_entities_html() -> None:
html = 'Some <b>bold</b> (<strong>strong</strong>), <i>italics</i> (<em>cursive</em>), inline <code>code</code>, a <pre>pre</pre> block, a <a href="https://example.com">link</a>, <details>spoilers</details> and <a href="tg://user?id=12345678">mentions</a>' html = 'Some <b>bold</b> (<strong>strong</strong>), <i>italics</i> (<em>cursive</em>), inline <code>code</code>, a <pre>pre</pre> block, a <a href="https://example.com">link</a>, <details>spoilers</details> and <a href="tg://user?id=12345678">mentions</a>'
text, entities = parse_html_message(html) text, entities = parse_html_message(html)
assert ( assert text == "Some bold (strong), italics (cursive), inline code, a pre block, a link, spoilers and mentions"
text
== "Some bold (strong), italics (cursive), inline code, a pre block, a link, spoilers and mentions"
)
assert entities == [ assert entities == [
types.MessageEntityBold(offset=5, length=4), types.MessageEntityBold(offset=5, length=4),
types.MessageEntityBold(offset=11, length=6), types.MessageEntityBold(offset=11, length=6),

View File

@ -38,6 +38,10 @@ build-backend = "setuptools.build_meta"
version = {attr = "telethon_generator.version.__version__"} version = {attr = "telethon_generator.version.__version__"}
[tool.ruff] [tool.ruff]
line-length = 120
[tool.ruff.lint]
select = ["F", "E", "W", "I"]
ignore = [ ignore = [
"E501", # formatter takes care of lines that are too long besides documentation "E501", # formatter takes care of lines that are too long besides documentation
] ]

View File

@ -17,9 +17,7 @@ from .serde.deserialization import (
from .serde.serialization import generate_function, generate_write from .serde.serialization import generate_function, generate_write
def generate_init( def generate_init(writer: SourceWriter, namespaces: set[str], classes: set[str]) -> None:
writer: SourceWriter, namespaces: set[str], classes: set[str]
) -> None:
sorted_cls = list(sorted(classes)) sorted_cls = list(sorted(classes))
sorted_ns = list(sorted(namespaces)) sorted_ns = list(sorted(namespaces))
@ -93,9 +91,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
writer = fs.open(type_path) writer = fs.open(type_path)
if type_path not in fs: if type_path not in fs:
writer.write( writer.write("# pyright: reportUnusedImport=false, reportConstantRedefinition=false")
"# pyright: reportUnusedImport=false, reportConstantRedefinition=false"
)
writer.write("import struct") writer.write("import struct")
writer.write("from typing import Optional, Self, Sequence") writer.write("from typing import Optional, Self, Sequence")
writer.write("from .. import abcs") writer.write("from .. import abcs")
@ -106,9 +102,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
generated_type_names.add(f"{ns}{to_class_name(typedef.name)}") generated_type_names.add(f"{ns}{to_class_name(typedef.name)}")
# class Type(BaseType) # class Type(BaseType)
writer.write( writer.write(f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):")
f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):"
)
# __slots__ = ('params', ...) # __slots__ = ('params', ...)
slots = " ".join(f"'{p.name}'," for p in property_params) slots = " ".join(f"'{p.name}'," for p in property_params)
@ -121,9 +115,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
# def __init__() # def __init__()
if property_params: if property_params:
params = "".join( params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params)
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
)
writer.write(f" def __init__(_s, *{params}) -> None:") writer.write(f" def __init__(_s, *{params}) -> None:")
for p in property_params: for p in property_params:
writer.write(f" _s.{p.name} = {p.name}") writer.write(f" _s.{p.name} = {p.name}")
@ -151,9 +143,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
raise ValueError("nested function-namespaces are not supported") raise ValueError("nested function-namespaces are not supported")
elif len(functiondef.namespace) == 1: elif len(functiondef.namespace) == 1:
function_namespaces.add(functiondef.namespace[0]) function_namespaces.add(functiondef.namespace[0])
function_path = (Path("functions") / functiondef.namespace[0]).with_suffix( function_path = (Path("functions") / functiondef.namespace[0]).with_suffix(".py")
".py"
)
else: else:
function_def_names.add(to_method_name(functiondef.name)) function_def_names.add(to_method_name(functiondef.name))
function_path = Path("functions/_nons.py") function_path = Path("functions/_nons.py")
@ -173,18 +163,14 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params) params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params)
star = "*" if params else "" star = "*" if params else ""
return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None)) return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None))
writer.write( writer.write(f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:")
f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:"
)
writer.indent(2) writer.indent(2)
generate_function(writer, functiondef) generate_function(writer, functiondef)
writer.dedent(2) writer.dedent(2)
generate_init(fs.open(Path("abcs/__init__.py")), abc_namespaces, abc_class_names) generate_init(fs.open(Path("abcs/__init__.py")), abc_namespaces, abc_class_names)
generate_init(fs.open(Path("types/__init__.py")), type_namespaces, type_class_names) generate_init(fs.open(Path("types/__init__.py")), type_namespaces, type_class_names)
generate_init( generate_init(fs.open(Path("functions/__init__.py")), function_namespaces, function_def_names)
fs.open(Path("functions/__init__.py")), function_namespaces, function_def_names
)
writer = fs.open(Path("layer.py")) writer = fs.open(Path("layer.py"))
writer.write("# pyright: reportUnusedImport=false") writer.write("# pyright: reportUnusedImport=false")
@ -194,16 +180,12 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
) )
writer.write("from typing import cast, Type") writer.write("from typing import cast, Type")
writer.write(f"LAYER = {tl.layer!r}") writer.write(f"LAYER = {tl.layer!r}")
writer.write( writer.write("TYPE_MAPPING = {t.constructor_id(): t for t in cast(tuple[Type[Serializable]], (")
"TYPE_MAPPING = {t.constructor_id(): t for t in cast(tuple[Type[Serializable]], ("
)
for name in sorted(generated_type_names): for name in sorted(generated_type_names):
writer.write(f" types.{name},") writer.write(f" types.{name},")
writer.write("))}") writer.write("))}")
writer.write("RESPONSE_MAPPING = {") writer.write("RESPONSE_MAPPING = {")
for functiondef in tl.functiondefs: for functiondef in tl.functiondefs:
writer.write( writer.write(f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)},")
f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)},"
)
writer.write("}") writer.write("}")
writer.write("__all__ = ['LAYER', 'TYPE_MAPPING', 'RESPONSE_MAPPING']") writer.write("__all__ = ['LAYER', 'TYPE_MAPPING', 'RESPONSE_MAPPING']")

View File

@ -50,9 +50,7 @@ _TRIVIAL_STRUCT_MAP = {"int": "i", "long": "q", "double": "d", "Bool": "I"}
def trivial_struct_fmt(ty: BaseParameter) -> str: def trivial_struct_fmt(ty: BaseParameter) -> str:
try: try:
return ( return _TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I"
_TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I"
)
except KeyError: except KeyError:
raise ValueError("input param was not trivial") raise ValueError("input param was not trivial")

View File

@ -38,9 +38,7 @@ def reader_read_fmt(ty: Type, constructor_id: int) -> tuple[str, Optional[str]]:
return f"reader.read_serializable({inner_type_fmt(ty)})", "type-abstract" return f"reader.read_serializable({inner_type_fmt(ty)})", "type-abstract"
def generate_normal_param_read( def generate_normal_param_read(writer: SourceWriter, name: str, param: NormalParameter, constructor_id: int) -> None:
writer: SourceWriter, name: str, param: NormalParameter, constructor_id: int
) -> None:
flag_check = f"_{param.flag.name} & {1 << param.flag.index}" if param.flag else None flag_check = f"_{param.flag.name} & {1 << param.flag.index}" if param.flag else None
if param.ty.name == "true": if param.ty.name == "true":
if not flag_check: if not flag_check:
@ -55,9 +53,7 @@ def generate_normal_param_read(
if param.ty.generic_arg: if param.ty.generic_arg:
if param.ty.name not in ("Vector", "vector"): if param.ty.name not in ("Vector", "vector"):
raise ValueError( raise ValueError("generic_arg deserialization for non-vectors is not supported")
"generic_arg deserialization for non-vectors is not supported"
)
if param.ty.bare: if param.ty.bare:
writer.write("__len = reader.read_fmt('<i', 4)[0]") writer.write("__len = reader.read_fmt('<i', 4)[0]")
@ -70,18 +66,12 @@ def generate_normal_param_read(
if is_trivial(generic): if is_trivial(generic):
fmt = trivial_struct_fmt(generic) fmt = trivial_struct_fmt(generic)
size = struct.calcsize(f"<{fmt}") size = struct.calcsize(f"<{fmt}")
writer.write( writer.write(f"_{name} = [*reader.read_fmt(f'<{{__len}}{fmt}', __len * {size})]")
f"_{name} = [*reader.read_fmt(f'<{{__len}}{fmt}', __len * {size})]"
)
if param.ty.generic_arg.name == "Bool": if param.ty.generic_arg.name == "Bool":
writer.write( writer.write(f"assert all(__x in (0xbc799737, 0x0x997275b5) for __x in _{name})")
f"assert all(__x in (0xbc799737, 0x0x997275b5) for __x in _{name})"
)
writer.write(f"_{name} = [_{name} == 0x997275b5]") writer.write(f"_{name} = [_{name} == 0x997275b5]")
else: else:
fmt_read, type_ignore = reader_read_fmt( fmt_read, type_ignore = reader_read_fmt(param.ty.generic_arg, constructor_id)
param.ty.generic_arg, constructor_id
)
comment = f" # type: ignore [{type_ignore}]" if type_ignore else "" comment = f" # type: ignore [{type_ignore}]" if type_ignore else ""
writer.write(f"_{name} = [{fmt_read} for _ in range(__len)]{comment}") writer.write(f"_{name} = [{fmt_read} for _ in range(__len)]{comment}")
else: else:
@ -127,9 +117,7 @@ def param_value_fmt(param: Parameter) -> str:
def function_deserializer_fmt(defn: Definition) -> str: def function_deserializer_fmt(defn: Definition) -> str:
if defn.ty.generic_arg: if defn.ty.generic_arg:
if defn.ty.name != ("Vector"): if defn.ty.name != ("Vector"):
raise ValueError( raise ValueError("generic_arg return for non-boxed-vectors is not supported")
"generic_arg return for non-boxed-vectors is not supported"
)
elif defn.ty.generic_ref: elif defn.ty.generic_ref:
raise ValueError("return for generic refs inside vector is not supported") raise ValueError("return for generic refs inside vector is not supported")
elif is_trivial(NormalParameter(ty=defn.ty.generic_arg, flag=None)): elif is_trivial(NormalParameter(ty=defn.ty.generic_arg, flag=None)):
@ -138,13 +126,9 @@ def function_deserializer_fmt(defn: Definition) -> str:
elif defn.ty.generic_arg.name == "long": elif defn.ty.generic_arg.name == "long":
return "deserialize_i64_list" return "deserialize_i64_list"
else: else:
raise ValueError( raise ValueError(f"return for trivial arg {defn.ty.generic_arg} is not supported")
f"return for trivial arg {defn.ty.generic_arg} is not supported"
)
elif defn.ty.generic_arg.bare: elif defn.ty.generic_arg.bare:
raise ValueError( raise ValueError("return for non-boxed serializables inside a vector is not supported")
"return for non-boxed serializables inside a vector is not supported"
)
else: else:
return f"list_deserializer({inner_type_fmt(defn.ty.generic_arg)})" return f"list_deserializer({inner_type_fmt(defn.ty.generic_arg)})"
elif defn.ty.generic_ref: elif defn.ty.generic_ref:

View File

@ -15,15 +15,11 @@ def param_value_expr(param: Parameter) -> str:
return f"{pre}{mid}{suf}" return f"{pre}{mid}{suf}"
def generate_buffer_append( def generate_buffer_append(writer: SourceWriter, buffer: str, name: str, ty: Type) -> None:
writer: SourceWriter, buffer: str, name: str, ty: Type
) -> None:
if is_trivial(NormalParameter(ty=ty, flag=None)): if is_trivial(NormalParameter(ty=ty, flag=None)):
fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None)) fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None))
if ty.name == "Bool": if ty.name == "Bool":
writer.write( writer.write(f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))")
f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))"
)
else: else:
writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})") writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})")
elif ty.generic_ref or ty.name == "Object": elif ty.generic_ref or ty.name == "Object":
@ -58,9 +54,7 @@ def generate_normal_param_write(
if param.ty.generic_arg: if param.ty.generic_arg:
if param.ty.name not in ("Vector", "vector"): if param.ty.name not in ("Vector", "vector"):
raise ValueError( raise ValueError("generic_arg deserialization for non-vectors is not supported")
"generic_arg deserialization for non-vectors is not supported"
)
if param.ty.bare: if param.ty.bare:
writer.write(f"{buffer} += struct.pack('<i', len({name}))") writer.write(f"{buffer} += struct.pack('<i', len({name}))")
@ -76,9 +70,7 @@ def generate_normal_param_write(
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *(0x997275b5 if {tmp} else 0xbc799737 for {tmp} in {name}))" f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *(0x997275b5 if {tmp} else 0xbc799737 for {tmp} in {name}))"
) )
else: else:
writer.write( writer.write(f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *{name})")
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *{name})"
)
else: else:
tmp = next(tmp_names) tmp = next(tmp_names)
writer.write(f"for {tmp} in {name}:") writer.write(f"for {tmp} in {name}:")
@ -110,9 +102,7 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None:
else f"(0 if self.{p.name} is None else {1 << p.ty.flag.index})" else f"(0 if self.{p.name} is None else {1 << p.ty.flag.index})"
) )
for p in defn.params for p in defn.params
if isinstance(p.ty, NormalParameter) if isinstance(p.ty, NormalParameter) and p.ty.flag and p.ty.flag.name == param.name
and p.ty.flag
and p.ty.flag.name == param.name
) )
writer.write(f"_{param.name} = {flags or 0}") writer.write(f"_{param.name} = {flags or 0}")
@ -123,9 +113,7 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None:
for param in iter: for param in iter:
if not isinstance(param.ty, NormalParameter): if not isinstance(param.ty, NormalParameter):
raise RuntimeError("FlagsParameter should be considered trivial") raise RuntimeError("FlagsParameter should be considered trivial")
generate_normal_param_write( generate_normal_param_write(writer, tmp_names, "buffer", f"self.{param.name}", param.ty)
writer, tmp_names, "buffer", f"self.{param.name}", param.ty
)
def generate_function(writer: SourceWriter, defn: Definition) -> None: def generate_function(writer: SourceWriter, defn: Definition) -> None:
@ -148,9 +136,7 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None:
else f"(0 if {p.name} is None else {1 << p.ty.flag.index})" else f"(0 if {p.name} is None else {1 << p.ty.flag.index})"
) )
for p in defn.params for p in defn.params
if isinstance(p.ty, NormalParameter) if isinstance(p.ty, NormalParameter) and p.ty.flag and p.ty.flag.name == param.name
and p.ty.flag
and p.ty.flag.name == param.name
) )
writer.write(f"{param.name} = {flags or 0}") writer.write(f"{param.name} = {flags or 0}")
@ -161,7 +147,5 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None:
for param in iter: for param in iter:
if not isinstance(param.ty, NormalParameter): if not isinstance(param.ty, NormalParameter):
raise RuntimeError("FlagsParameter should be considered trivial") raise RuntimeError("FlagsParameter should be considered trivial")
generate_normal_param_write( generate_normal_param_write(writer, tmp_names, "_buffer", param.name, param.ty)
writer, tmp_names, "_buffer", param.name, param.ty
)
writer.write("return Request(b'' + _buffer)") writer.write("return Request(b'' + _buffer)")

View File

@ -35,6 +35,4 @@ def load_tl_file(path: str | Path) -> ParsedTl:
else: else:
functiondefs.append(definition) functiondefs.append(definition)
return ParsedTl( return ParsedTl(layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs))
layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)
)

View File

@ -14,9 +14,7 @@ def gen_py_code(
functiondefs: Optional[list[Definition]] = None, functiondefs: Optional[list[Definition]] = None,
) -> str: ) -> str:
fs = FakeFs() fs = FakeFs()
generate( generate(fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or []))
fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or [])
)
generated = bytearray() generated = bytearray()
for path, data in fs._files.items(): for path, data in fs._files.items():
if path.stem not in ("__init__", "layer"): if path.stem not in ("__init__", "layer"):
@ -27,9 +25,7 @@ def gen_py_code(
def test_generic_functions_use_bytes_parameters() -> None: def test_generic_functions_use_bytes_parameters() -> None:
definitions = get_definitions( definitions = get_definitions("invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;")
"invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;"
)
result = gen_py_code(functiondefs=definitions) result = gen_py_code(functiondefs=definitions)
assert "invoke_with_layer" in result assert "invoke_with_layer" in result
assert "query: _bytes" in result assert "query: _bytes" in result

View File

@ -54,18 +54,14 @@ def test_valid_param() -> None:
assert Parameter.from_str("foo:!bar") == Parameter( assert Parameter.from_str("foo:!bar") == Parameter(
name="foo", name="foo",
ty=NormalParameter( ty=NormalParameter(
ty=Type( ty=Type(namespace=[], name="bar", bare=True, generic_ref=True, generic_arg=None),
namespace=[], name="bar", bare=True, generic_ref=True, generic_arg=None
),
flag=None, flag=None,
), ),
) )
assert Parameter.from_str("foo:bar.1?baz") == Parameter( assert Parameter.from_str("foo:bar.1?baz") == Parameter(
name="foo", name="foo",
ty=NormalParameter( ty=NormalParameter(
ty=Type( ty=Type(namespace=[], name="baz", bare=True, generic_ref=False, generic_arg=None),
namespace=[], name="baz", bare=True, generic_ref=False, generic_arg=None
),
flag=Flag( flag=Flag(
name="bar", name="bar",
index=1, index=1,

View File

@ -11,9 +11,7 @@ def test_empty_simple() -> None:
def test_simple() -> None: def test_simple() -> None:
assert Type.from_str("foo") == Type( assert Type.from_str("foo") == Type(namespace=[], name="foo", bare=True, generic_ref=False, generic_arg=None)
namespace=[], name="foo", bare=True, generic_ref=False, generic_arg=None
)
@mark.parametrize("ty", [".", "..", ".foo", "foo.", "foo..foo", ".foo."]) @mark.parametrize("ty", [".", "..", ".foo", "foo.", "foo..foo", ".foo."])

View File

@ -1,6 +1,7 @@
""" """
Check formatting, type-check and run offline tests. Check formatting, type-check and run offline tests.
""" """
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
@ -15,9 +16,7 @@ def run(*args: str) -> int:
def main() -> None: def main() -> None:
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
exit( exit(
run("isort", ".", "-c", "--profile", "black", "--gitignore") run("mypy", "--strict", ".")
or run("black", ".", "--check", "--extend-exclude", BLACK_IGNORE)
or run("mypy", "--strict", ".")
or run("ruff", "check", ".") or run("ruff", "check", ".")
or run("sphinx", "-M", "dummy", "client/doc", tmp_dir, "-n", "-W") or run("sphinx", "-M", "dummy", "client/doc", tmp_dir, "-n", "-W")
or run("pytest", ".", "-m", "not net") or run("pytest", ".", "-m", "not net")

View File

@ -2,6 +2,7 @@
Run `telethon_generator.codegen` on both `api.tl` and `mtproto.tl` to output Run `telethon_generator.codegen` on both `api.tl` and `mtproto.tl` to output
corresponding Python code in the default directories under the `client/`. corresponding Python code in the default directories under the `client/`.
""" """
import subprocess import subprocess
import sys import sys

View File

@ -110,13 +110,13 @@ def main() -> None:
function.args.args[0].annotation = None function.args.args[0].annotation = None
if isinstance(function, ast.AsyncFunctionDef): if isinstance(function, ast.AsyncFunctionDef):
call = ast.Await(value=call) call = ast.Await(value=call) # type: ignore [arg-type]
match function.returns: match function.returns:
case ast.Constant(value=None): case ast.Constant(value=None):
call = ast.Expr(value=call) call = ast.Expr(value=call) # type: ignore [arg-type]
case _: case _:
call = ast.Return(value=call) call = ast.Return(value=call) # type: ignore [arg-type]
function.body.append(call) function.body.append(call)
class_body.append(function) class_body.append(function)

View File

@ -1,6 +1,7 @@
""" """
Run `sphinx-build` to create HTML documentation and detect errors. Run `sphinx-build` to create HTML documentation and detect errors.
""" """
import subprocess import subprocess
import sys import sys

View File

@ -1,6 +1,7 @@
""" """
Sort imports and format code. Sort imports and format code.
""" """
import subprocess import subprocess
import sys import sys