Rollback line length

This commit is contained in:
Jahongir Qurbonov 2024-10-06 14:02:23 +05:00
parent 41b848db38
commit 053bd9c02b
67 changed files with 995 additions and 305 deletions

View File

@ -15,7 +15,9 @@ def make_link_node(rawtext, app, name, options):
base += "/"
set_classes(options)
node = nodes.reference(rawtext, utils.unescape(name), refuri="{}?q={}".format(base, name), **options)
node = nodes.reference(
rawtext, utils.unescape(name), refuri="{}?q={}".format(base, name), **options
)
return node

View File

@ -52,9 +52,6 @@ backend-path = ["build_backend"]
[tool.setuptools.dynamic]
version = {attr = "telethon.version.__version__"}
[tool.ruff]
line-length = 120
[tool.ruff.lint]
select = ["F", "E", "W", "I"]
ignore = [

View File

@ -31,7 +31,9 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
assert isinstance(auth, types.auth.Authorization)
assert isinstance(auth.user, types.User)
user = User._from_raw(auth.user)
client._session.user = SessionUser(id=user.id, dc=client._sender.dc_id, bot=user.bot, username=user.username)
client._session.user = SessionUser(
id=user.id, dc=client._sender.dc_id, bot=user.bot, username=user.username
)
client._chat_hashes.set_self_user(user.id, user.bot)
@ -54,7 +56,9 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
assert dc_id is not None
sender, client._session.dcs = await connect_sender(client._config, client._session.dcs, DataCenter(id=dc_id))
sender, client._session.dcs = await connect_sender(
client._config, client._session.dcs, DataCenter(id=dc_id)
)
async with client._sender_lock:
client._sender = sender
@ -173,7 +177,9 @@ async def interactive_login(
user = await self.check_password(user_or_token, password)
else:
while True:
print("Please enter your password (prompt is hidden; type and press enter)")
print(
"Please enter your password (prompt is hidden; type and press enter)"
)
password = getpass.getpass(": ")
try:
user = await self.check_password(user_or_token, password)
@ -202,9 +208,13 @@ async def get_password_information(client: Client) -> PasswordToken:
return PasswordToken._new(result)
async def check_password(self: Client, token: PasswordToken, password: str | bytes) -> User:
async def check_password(
self: Client, token: PasswordToken, password: str | bytes
) -> User:
algo = token._password.current_algo
if not isinstance(algo, types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow):
if not isinstance(
algo, types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow
):
raise RuntimeError("unrecognised 2FA algorithm")
if not two_factor_auth.check_p_and_g(algo.p, algo.g):

View File

@ -38,7 +38,11 @@ class InlineResults(metaclass=NoPublicConstructor):
result = await self._client(
functions.messages.get_inline_bot_results(
bot=self._bot,
peer=(self._peer._to_input_peer() if self._peer else types.InputPeerEmpty()),
peer=(
self._peer._to_input_peer()
if self._peer
else types.InputPeerEmpty()
),
geo_point=None,
query=self._query,
offset=self._offset,
@ -47,8 +51,12 @@ class InlineResults(metaclass=NoPublicConstructor):
assert isinstance(result, types.messages.BotResults)
self._offset = result.next_offset
for r in reversed(result.results):
assert isinstance(r, (types.BotInlineMediaResult, types.BotInlineResult))
self._buffer.append(InlineResult._create(self._client, result, r, self._peer))
assert isinstance(
r, (types.BotInlineMediaResult, types.BotInlineResult)
)
self._buffer.append(
InlineResult._create(self._client, result, r, self._peer)
)
if not self._buffer:
self._offset = None

View File

@ -53,7 +53,9 @@ class ParticipantList(AsyncList[Participant]):
seen_count = len(self._seen)
for p in chanp.participants:
part = Participant._from_raw_channel(self._client, self._peer, p, chat_map)
part = Participant._from_raw_channel(
self._client, self._peer, p, chat_map
)
pid = part._peer_id()
if pid not in self._seen:
self._seen.add(pid)
@ -64,7 +66,9 @@ class ParticipantList(AsyncList[Participant]):
self._done = len(self._seen) == seen_count
else:
chatp = await self._client(functions.messages.get_full_chat(chat_id=self._peer._to_input_chat()))
chatp = await self._client(
functions.messages.get_full_chat(chat_id=self._peer._to_input_chat())
)
assert isinstance(chatp, types.messages.ChatFull)
assert isinstance(chatp.full_chat, types.ChatFull)
@ -83,14 +87,17 @@ class ParticipantList(AsyncList[Participant]):
)
elif isinstance(participants, types.ChatParticipants):
self._buffer.extend(
Participant._from_raw_chat(self._client, self._peer, p, chat_map) for p in participants.participants
Participant._from_raw_chat(self._client, self._peer, p, chat_map)
for p in participants.participants
)
self._total = len(self._buffer)
self._done = True
def get_participants(self: Client, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[Participant]:
def get_participants(
self: Client, chat: Group | Channel | GroupRef | ChannelRef, /
) -> AsyncList[Participant]:
return ParticipantList(self, chat._ref)
@ -130,7 +137,9 @@ class RecentActionList(AsyncList[RecentAction]):
self._offset = min(e.id for e in self._buffer)
def get_admin_log(self: Client, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[RecentAction]:
def get_admin_log(
self: Client, chat: Group | Channel | GroupRef | ChannelRef, /
) -> AsyncList[RecentAction]:
return RecentActionList(self, chat._ref)
@ -165,7 +174,11 @@ class ProfilePhotoList(AsyncList[File]):
else:
raise RuntimeError("unexpected case")
self._buffer.extend(filter(None, (File._try_from_raw_photo(self._client, p) for p in photos)))
self._buffer.extend(
filter(
None, (File._try_from_raw_photo(self._client, p) for p in photos)
)
)
def get_profile_photos(self: Client, peer: Peer | PeerRef, /) -> AsyncList[File]:
@ -243,7 +256,11 @@ async def set_chat_default_restrictions(
*,
until: Optional[datetime.datetime] = None,
) -> None:
banned_rights = ChatRestriction._set_to_raw(set(restrictions), int(until.timestamp()) if until else 0x7FFFFFFF)
await self(
functions.messages.edit_chat_default_banned_rights(peer=chat._ref._to_input_peer(), banned_rights=banned_rights)
banned_rights = ChatRestriction._set_to_raw(
set(restrictions), int(until.timestamp()) if until else 0x7FFFFFFF
)
await self(
functions.messages.edit_chat_default_banned_rights(
peer=chat._ref._to_input_peer(), banned_rights=banned_rights
)
)

View File

@ -237,7 +237,9 @@ class Client:
lang_code=lang_code or "en",
catch_up=catch_up or False,
datacenter=datacenter,
flood_sleep_threshold=(60 if flood_sleep_threshold is None else flood_sleep_threshold),
flood_sleep_threshold=(
60 if flood_sleep_threshold is None else flood_sleep_threshold
),
update_queue_limit=update_queue_limit,
base_logger=base_logger,
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
@ -248,8 +250,8 @@ class Client:
self._message_box = MessageBox(base_logger=base_logger)
self._chat_hashes = ChatHashCache(None)
self._last_update_limit_warn: Optional[float] = None
self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = asyncio.Queue(
maxsize=self._config.update_queue_limit or 0
self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = (
asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
)
self._dispatcher: Optional[asyncio.Task[None]] = None
self._handlers: dict[
@ -409,7 +411,9 @@ class Client:
"""
await delete_dialog(self, dialog)
async def delete_messages(self, chat: Peer | PeerRef, /, message_ids: list[int], *, revoke: bool = True) -> int:
async def delete_messages(
self, chat: Peer | PeerRef, /, message_ids: list[int], *, revoke: bool = True
) -> int:
"""
Delete messages.
@ -649,7 +653,9 @@ class Client:
"""
return await forward_messages(self, target, message_ids, source)
def get_admin_log(self, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[RecentAction]:
def get_admin_log(
self, chat: Group | Channel | GroupRef | ChannelRef, /
) -> AsyncList[RecentAction]:
"""
Get the recent actions from the administrator's log.
@ -751,7 +757,9 @@ class Client:
"""
return get_file_bytes(self, media)
def get_handler_filter(self, handler: Callable[[Event], Awaitable[Any]], /) -> Optional[FilterType]:
def get_handler_filter(
self, handler: Callable[[Event], Awaitable[Any]], /
) -> Optional[FilterType]:
"""
Get the filter associated to the given event handler.
@ -853,9 +861,13 @@ class Client:
async for message in reversed(client.get_messages(chat)):
print(message.sender.name, ':', message.markdown_text)
"""
return get_messages(self, chat, limit, offset_id=offset_id, offset_date=offset_date)
return get_messages(
self, chat, limit, offset_id=offset_id, offset_date=offset_date
)
def get_messages_with_ids(self, chat: Peer | PeerRef, /, message_ids: list[int]) -> AsyncList[Message]:
def get_messages_with_ids(
self, chat: Peer | PeerRef, /, message_ids: list[int]
) -> AsyncList[Message]:
"""
Get the full message objects from the corresponding message identifiers.
@ -879,7 +891,9 @@ class Client:
"""
return get_messages_with_ids(self, chat, message_ids)
def get_participants(self, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[Participant]:
def get_participants(
self, chat: Group | Channel | GroupRef | ChannelRef, /
) -> AsyncList[Participant]:
"""
Get the participants in a group or channel, along with their permissions.
@ -967,7 +981,9 @@ class Client:
"""
return await inline_query(self, bot, query, peer=peer)
async def interactive_login(self, phone_or_token: Optional[str] = None, *, password: Optional[str] = None) -> User:
async def interactive_login(
self, phone_or_token: Optional[str] = None, *, password: Optional[str] = None
) -> User:
"""
Begin an interactive login if needed.
If the account was already logged-in, this method simply returns :term:`yourself`.
@ -1019,7 +1035,9 @@ class Client:
def on(
self, event_cls: Type[Event], /, filter: Optional[FilterType] = None
) -> Callable[[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]]:
) -> Callable[
[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]
]:
"""
Register the decorated function to be invoked when the provided event type occurs.
@ -1107,7 +1125,9 @@ class Client:
"""
return prepare_album(self)
async def read_message(self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
async def read_message(
self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
) -> None:
"""
Mark messages as read.
@ -1138,7 +1158,9 @@ class Client:
"""
await read_message(self, chat, message_id)
def remove_event_handler(self, handler: Callable[[Event], Awaitable[Any]], /) -> None:
def remove_event_handler(
self, handler: Callable[[Event], Awaitable[Any]], /
) -> None:
"""
Remove the handler as a function to be called when events occur.
This is simply the opposite of :meth:`add_event_handler`.
@ -1302,7 +1324,9 @@ class Client:
async for message in client.search_all_messages(query='hello'):
print(message.text)
"""
return search_all_messages(self, limit, query=query, offset_id=offset_id, offset_date=offset_date)
return search_all_messages(
self, limit, query=query, offset_id=offset_id, offset_date=offset_date
)
def search_messages(
self,
@ -1348,7 +1372,9 @@ class Client:
async for message in client.search_messages(chat, query='hello'):
print(message.text)
"""
return search_messages(self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date)
return search_messages(
self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date
)
async def send_audio(
self,
@ -1949,7 +1975,9 @@ class Client:
:meth:`telethon.types.Participant.set_restrictions`
"""
await set_participant_restrictions(self, chat, participant, restrictions, until=until)
await set_participant_restrictions(
self, chat, participant, restrictions, until=until
)
async def sign_in(self, token: LoginToken, code: str) -> User | PasswordToken:
"""
@ -1996,7 +2024,9 @@ class Client:
"""
await sign_out(self)
async def unpin_message(self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
async def unpin_message(
self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
) -> None:
"""
Unpin one or all messages from the top.

View File

@ -51,7 +51,9 @@ class DialogList(AsyncList[Dialog]):
chat_map = build_chat_map(self._client, result.users, result.chats)
msg_map = build_msg_map(self._client, result.messages, chat_map)
self._buffer.extend(Dialog._from_raw(self._client, d, chat_map, msg_map) for d in result.dialogs)
self._buffer.extend(
Dialog._from_raw(self._client, d, chat_map, msg_map) for d in result.dialogs
)
def get_dialogs(self: Client) -> AsyncList[Dialog]:
@ -128,7 +130,9 @@ async def edit_draft(
reply_to: Optional[int] = None,
) -> Draft:
peer = peer._ref
message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False)
message, entities = parse_message(
text=text, markdown=markdown, html=html, allow_empty=False
)
result = await self(
functions.messages.save_draft(

View File

@ -189,11 +189,15 @@ async def send_file(
reply_to: Optional[int] = None,
keyboard: Optional[KeyboardType] = None,
) -> Message:
message, entities = parse_message(text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True)
message, entities = parse_message(
text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True
)
# Re-send existing file.
if isinstance(file, File):
return await do_send_file(self, chat, file._input_media, message, entities, reply_to, keyboard)
return await do_send_file(
self, chat, file._input_media, message, entities, reply_to, keyboard
)
# URLs are handled early as they can't use any other attributes either.
input_media: abcs.InputMedia
@ -208,10 +212,16 @@ async def send_file(
else:
as_photo = False
if as_photo:
input_media = types.InputMediaPhotoExternal(spoiler=False, url=file, ttl_seconds=None)
input_media = types.InputMediaPhotoExternal(
spoiler=False, url=file, ttl_seconds=None
)
else:
input_media = types.InputMediaDocumentExternal(spoiler=False, url=file, ttl_seconds=None)
return await do_send_file(self, chat, input_media, message, entities, reply_to, keyboard)
input_media = types.InputMediaDocumentExternal(
spoiler=False, url=file, ttl_seconds=None
)
return await do_send_file(
self, chat, input_media, message, entities, reply_to, keyboard
)
input_file, name = await upload(self, file, size, name)
@ -275,7 +285,9 @@ async def send_file(
ttl_seconds=None,
)
return await do_send_file(self, chat, input_media, message, entities, reply_to, keyboard)
return await do_send_file(
self, chat, input_media, message, entities, reply_to, keyboard
)
async def do_send_file(
@ -297,7 +309,11 @@ async def do_send_file(
noforwards=False,
update_stickersets_order=False,
peer=chat._ref._to_input_peer(),
reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None),
reply_to=(
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
if reply_to
else None
),
media=input_media,
message=message,
random_id=random_id,
@ -379,7 +395,11 @@ async def do_upload(
)
)
else:
await client(functions.upload.save_file_part(file_id=file_id, file_part=part, bytes=to_store))
await client(
functions.upload.save_file_part(
file_id=file_id, file_part=part, bytes=to_store
)
)
hash_md5.update(to_store)
buffer.clear()
@ -438,7 +458,9 @@ def get_file_bytes(self: Client, media: File, /) -> AsyncList[bytes]:
return FileBytesList(self, media)
async def download(self: Client, media: File, /, file: str | Path | OutFileLike) -> None:
async def download(
self: Client, media: File, /, file: str | Path | OutFileLike
) -> None:
fd = OutWrapper(file)
try:
async for chunk in get_file_bytes(self, media):

View File

@ -46,7 +46,9 @@ async def send_message(
update_stickersets_order=False,
peer=chat._ref._to_input_peer(),
reply_to=(
types.InputReplyToMessage(reply_to_msg_id=text.replied_message_id, top_msg_id=None)
types.InputReplyToMessage(
reply_to_msg_id=text.replied_message_id, top_msg_id=None
)
if text.replied_message_id
else None
),
@ -58,7 +60,9 @@ async def send_message(
send_as=None,
)
else:
message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False)
message, entities = parse_message(
text=text, markdown=markdown, html=html, allow_empty=False
)
request = functions.messages.send_message(
no_webpage=not link_preview,
silent=False,
@ -67,7 +71,11 @@ async def send_message(
noforwards=False,
update_stickersets_order=False,
peer=chat._ref._to_input_peer(),
reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None),
reply_to=(
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
if reply_to
else None
),
message=message,
random_id=random_id,
reply_markup=keyboard._raw if keyboard else None,
@ -83,7 +91,11 @@ async def send_message(
{},
out=result.out,
id=result.id,
from_id=(types.PeerUser(user_id=self._session.user.id) if self._session.user else None),
from_id=(
types.PeerUser(user_id=self._session.user.id)
if self._session.user
else None
),
peer_id=chat._ref._to_peer(),
reply_to=(
types.MessageReplyHeader(
@ -118,7 +130,9 @@ async def edit_message(
link_preview: bool = False,
keyboard: Optional[KeyboardType] = None,
) -> Message:
message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False)
message, entities = parse_message(
text=text, markdown=markdown, html=html, allow_empty=False
)
return self._build_message_map(
await self(
functions.messages.edit_message(
@ -146,9 +160,15 @@ async def delete_messages(
) -> int:
peer = chat._ref
if isinstance(peer, ChannelRef):
affected = await self(functions.channels.delete_messages(channel=peer._to_input_channel(), id=message_ids))
affected = await self(
functions.channels.delete_messages(
channel=peer._to_input_channel(), id=message_ids
)
)
else:
affected = await self(functions.messages.delete_messages(revoke=revoke, id=message_ids))
affected = await self(
functions.messages.delete_messages(revoke=revoke, id=message_ids)
)
assert isinstance(affected, types.messages.AffectedMessages)
return affected.pts_count
@ -185,7 +205,9 @@ class MessageList(AsyncList[Message]):
super().__init__()
self._reversed = False
def _extend_buffer(self, client: Client, messages: abcs.messages.Messages) -> dict[int, Peer]:
def _extend_buffer(
self, client: Client, messages: abcs.messages.Messages
) -> dict[int, Peer]:
if isinstance(messages, types.messages.MessagesNotModified):
self._total = messages.count
return {}
@ -193,7 +215,9 @@ class MessageList(AsyncList[Message]):
if isinstance(messages, types.messages.Messages):
self._total = len(messages.messages)
self._done = True
elif isinstance(messages, (types.messages.MessagesSlice, types.messages.ChannelMessages)):
elif isinstance(
messages, (types.messages.MessagesSlice, types.messages.ChannelMessages)
):
self._total = messages.count
else:
raise RuntimeError("unexpected case")
@ -201,7 +225,9 @@ class MessageList(AsyncList[Message]):
chat_map = build_chat_map(client, messages.users, messages.chats)
self._buffer.extend(
Message._from_raw(client, m, chat_map)
for m in (reversed(messages.messages) if self._reversed else messages.messages)
for m in (
reversed(messages.messages) if self._reversed else messages.messages
)
)
return chat_map
@ -209,7 +235,11 @@ class MessageList(AsyncList[Message]):
self,
) -> types.Message | types.MessageService | types.MessageEmpty:
return next(
(m._raw for m in reversed(self._buffer) if not isinstance(m._raw, types.MessageEmpty)),
(
m._raw
for m in reversed(self._buffer)
if not isinstance(m._raw, types.MessageEmpty)
),
types.MessageEmpty(id=0, peer_id=None),
)
@ -305,10 +335,14 @@ class CherryPickedList(MessageList):
if isinstance(self._peer, ChannelRef):
result = await self._client(
functions.channels.get_messages(channel=self._peer._to_input_channel(), id=self._ids[:100])
functions.channels.get_messages(
channel=self._peer._to_input_channel(), id=self._ids[:100]
)
)
else:
result = await self._client(functions.messages.get_messages(id=self._ids[:100]))
result = await self._client(
functions.messages.get_messages(id=self._ids[:100])
)
self._extend_buffer(self._client, result)
self._ids = self._ids[100:]
@ -458,7 +492,9 @@ def search_all_messages(
)
async def pin_message(self: Client, chat: Peer | PeerRef, /, message_id: int) -> Message:
async def pin_message(
self: Client, chat: Peer | PeerRef, /, message_id: int
) -> Message:
return self._build_message_map(
await self(
functions.messages.update_pinned_message(
@ -473,7 +509,9 @@ async def pin_message(self: Client, chat: Peer | PeerRef, /, message_id: int) ->
).get_single()
async def unpin_message(self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
async def unpin_message(
self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
) -> None:
if message_id == "all":
await self(
functions.messages.unpin_all_messages(
@ -493,15 +531,25 @@ async def unpin_message(self: Client, chat: Peer | PeerRef, /, message_id: int |
)
async def read_message(self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
async def read_message(
self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
) -> None:
if message_id == "all":
message_id = 0
peer = chat._ref
if isinstance(peer, ChannelRef):
await self(functions.channels.read_history(channel=peer._to_input_channel(), max_id=message_id))
await self(
functions.channels.read_history(
channel=peer._to_input_channel(), max_id=message_id
)
)
else:
await self(functions.messages.read_history(peer=peer._ref._to_input_peer(), max_id=message_id))
await self(
functions.messages.read_history(
peer=peer._ref._to_input_peer(), max_id=message_id
)
)
class MessageMap:
@ -536,7 +584,9 @@ class MessageMap:
def _empty(self, id: int = 0) -> Message:
return Message._from_raw(
self._client,
types.MessageEmpty(id=id, peer_id=self._peer._to_peer() if self._peer else None),
types.MessageEmpty(
id=id, peer_id=self._peer._to_peer() if self._peer else None
),
{},
)

View File

@ -82,9 +82,16 @@ async def connect_sender(
# Only the ID of the input DC may be known.
# Find the corresponding address and authentication key if needed.
addr = dc.ipv4_addr or next(
d.ipv4_addr for d in itertools.chain(known_dcs, KNOWN_DCS) if d.id == dc.id and d.ipv4_addr
d.ipv4_addr
for d in itertools.chain(known_dcs, KNOWN_DCS)
if d.id == dc.id and d.ipv4_addr
)
auth = (
None
if force_auth_gen
else dc.auth
or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
)
auth = None if force_auth_gen else dc.auth or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
sender = await do_connect_sender(
Full(),
@ -115,7 +122,9 @@ async def connect_sender(
)
except BadStatus as e:
if e.status == 404 and auth:
dc = DataCenter(id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None)
dc = DataCenter(
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None
)
config.base_logger.warning(
"datacenter could not find stored auth; will retry generating a new one: %s",
dc,
@ -165,8 +174,12 @@ async def connect(self: Client) -> None:
if session := await self._storage.load():
self._session = session
datacenter = self._config.datacenter or DataCenter(id=self._session.user.dc if self._session.user else DEFAULT_DC)
self._sender, self._session.dcs = await connect_sender(self._config, self._session.dcs, datacenter)
datacenter = self._config.datacenter or DataCenter(
id=self._session.user.dc if self._session.user else DEFAULT_DC
)
self._sender, self._session.dcs = await connect_sender(
self._config, self._session.dcs, datacenter
)
if self._message_box.is_empty() and self._session.user:
try:
@ -180,7 +193,9 @@ async def connect(self: Client) -> None:
me = await self.get_me()
assert me is not None
self._chat_hashes.set_self_user(me.id, me.bot)
self._session.user = SessionUser(id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username)
self._session.user = SessionUser(
id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username
)
self._dispatcher = asyncio.create_task(dispatcher(self))
@ -199,14 +214,18 @@ async def disconnect(self: Client) -> None:
except asyncio.CancelledError:
pass
except Exception:
self._config.base_logger.exception("unhandled exception when cancelling dispatcher; this is a bug")
self._config.base_logger.exception(
"unhandled exception when cancelling dispatcher; this is a bug"
)
finally:
self._dispatcher = None
try:
await sender.disconnect()
except Exception:
self._config.base_logger.exception("unhandled exception during disconnect; this is a bug")
self._config.base_logger.exception(
"unhandled exception during disconnect; this is a bug"
)
try:
if self._session.user:

View File

@ -40,7 +40,9 @@ def add_event_handler(
self._handlers.setdefault(event_cls, []).append((handler, filter))
def remove_event_handler(self: Client, handler: Callable[[Event], Awaitable[Any]], /) -> None:
def remove_event_handler(
self: Client, handler: Callable[[Event], Awaitable[Any]], /
) -> None:
for event_cls, handlers in tuple(self._handlers.items()):
for i in reversed(range(len(handlers))):
if handlers[i][0] == handler:
@ -49,7 +51,9 @@ def remove_event_handler(self: Client, handler: Callable[[Event], Awaitable[Any]
del self._handlers[event_cls]
def get_handler_filter(self: Client, handler: Callable[[Event], Awaitable[Any]], /) -> Optional[FilterType]:
def get_handler_filter(
self: Client, handler: Callable[[Event], Awaitable[Any]], /
) -> Optional[FilterType]:
for handlers in self._handlers.values():
for h, f in handlers:
if h == handler:
@ -80,7 +84,9 @@ def process_socket_updates(client: Client, all_updates: list[abcs.Updates]) -> N
return
try:
result, users, chats = client._message_box.process_updates(updates, client._chat_hashes)
result, users, chats = client._message_box.process_updates(
updates, client._chat_hashes
)
except Gap:
return
@ -101,7 +107,8 @@ def extend_update_queue(
except asyncio.QueueFull:
now = asyncio.get_running_loop().time()
if client._last_update_limit_warn is None or (
now - client._last_update_limit_warn > UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN
now - client._last_update_limit_warn
> UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN
):
client._config.base_logger.warning(
"updates are being dropped because limit=%d has been reached",

View File

@ -53,11 +53,15 @@ def resolved_peer_to_chat(client: Client, resolved: abcs.contacts.ResolvedPeer)
async def resolve_phone(self: Client, phone: str, /) -> Peer:
return resolved_peer_to_chat(self, await self(functions.contacts.resolve_phone(phone=phone)))
return resolved_peer_to_chat(
self, await self(functions.contacts.resolve_phone(phone=phone))
)
async def resolve_username(self: Client, username: str, /) -> Peer:
return resolved_peer_to_chat(self, await self(functions.contacts.resolve_username(username=username)))
return resolved_peer_to_chat(
self, await self(functions.contacts.resolve_username(username=username))
)
async def resolve_peers(self: Client, peers: Sequence[Peer | PeerRef], /) -> list[Peer]:
@ -95,4 +99,8 @@ async def resolve_peers(self: Client, peers: Sequence[Peer | PeerRef], /) -> lis
chats.extend(ret_chats.chats)
chat_map = build_chat_map(self, users, chats)
return [chat_map.get(ref.identifier) or expand_peer(self, ref._to_peer(), broadcast=None) for ref in refs]
return [
chat_map.get(ref.identifier)
or expand_peer(self, ref._to_peer(), broadcast=None)
for ref in refs
]

View File

@ -34,13 +34,17 @@ def from_name(name: str, *, _cache: dict[str, Type[RpcError]] = {}) -> Type[RpcE
return _cache[name]
def adapt_rpc(error: RpcError, *, _cache: dict[tuple[int, str], Type[RpcError]] = {}) -> RpcError:
def adapt_rpc(
error: RpcError, *, _cache: dict[tuple[int, str], Type[RpcError]] = {}
) -> RpcError:
code = canonicalize_code(error.code)
name = canonicalize_name(error.name)
tup = code, name
if tup not in _cache:
_cache[tup] = type(pretty_name(name), (from_code(code), from_name(name)), {})
return _cache[tup](code=error.code, name=error.name, value=error.value, caused_by=error._caused_by)
return _cache[tup](
code=error.code, name=error.name, value=error.value, caused_by=error._caused_by
)
class ErrorFactory:
@ -51,7 +55,9 @@ class ErrorFactory:
return from_code(int(m[1]))
else:
adapted = adapt_user_name(name)
if pretty_name(canonicalize_name(adapted)) != name or re.match(r"[A-Z]{2}", name):
if pretty_name(canonicalize_name(adapted)) != name or re.match(
r"[A-Z]{2}", name
):
raise AttributeError(f"error subclass names must be CamelCase: {name}")
return from_name(adapted)

View File

@ -24,7 +24,9 @@ class Event(abc.ABC, metaclass=NoPublicConstructor):
@classmethod
@abc.abstractmethod
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
def _try_from_update(
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
pass
@ -50,7 +52,9 @@ class Raw(Event):
self._chat_map = chat_map
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
def _try_from_update(
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
return cls._create(client, update, chat_map)

View File

@ -65,7 +65,9 @@ class Any(Combinable):
__slots__ = ("_filters",)
def __init__(self, filter1: FilterType, filter2: FilterType, *filters: FilterType) -> None:
def __init__(
self, filter1: FilterType, filter2: FilterType, *filters: FilterType
) -> None:
self._filters = (filter1, filter2, *filters)
@property
@ -109,7 +111,9 @@ class All(Combinable):
__slots__ = ("_filters",)
def __init__(self, filter1: FilterType, filter2: FilterType, *filters: FilterType) -> None:
def __init__(
self, filter1: FilterType, filter2: FilterType, *filters: FilterType
) -> None:
self._filters = (filter1, filter2, *filters)
@property

View File

@ -24,11 +24,15 @@ class NewMessage(Event, Message):
"""
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
def _try_from_update(
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
if isinstance(update, (types.UpdateNewMessage, types.UpdateNewChannelMessage)):
if isinstance(update.message, types.Message):
return cls._from_raw(client, update.message, chat_map)
elif isinstance(update, (types.UpdateShortMessage, types.UpdateShortChatMessage)):
elif isinstance(
update, (types.UpdateShortMessage, types.UpdateShortChatMessage)
):
raise RuntimeError("should have been handled by adaptor")
return None
@ -42,8 +46,12 @@ class MessageEdited(Event, Message):
"""
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
if isinstance(update, (types.UpdateEditMessage, types.UpdateEditChannelMessage)):
def _try_from_update(
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
if isinstance(
update, (types.UpdateEditMessage, types.UpdateEditChannelMessage)
):
return cls._from_raw(client, update.message, chat_map)
else:
return None
@ -66,7 +74,9 @@ class MessageDeleted(Event):
self._channel_id = channel_id
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
def _try_from_update(
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
if isinstance(update, types.UpdateDeleteMessages):
return cls._create(update.messages, None)
elif isinstance(update, types.UpdateDeleteChannelMessages):
@ -112,7 +122,9 @@ class MessageRead(Event):
self._chat_map = chat_map
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
def _try_from_update(
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
if isinstance(
update,
(
@ -127,7 +139,9 @@ class MessageRead(Event):
return None
def _peer(self) -> abcs.Peer:
if isinstance(self._raw, (types.UpdateReadHistoryInbox, types.UpdateReadHistoryOutbox)):
if isinstance(
self._raw, (types.UpdateReadHistoryInbox, types.UpdateReadHistoryOutbox)
):
return self._raw.peer
else:
return types.PeerChannel(channel_id=self._raw.channel_id)
@ -140,7 +154,9 @@ class MessageRead(Event):
peer = self._peer()
pid = peer_id(peer)
if pid not in self._chat_map:
self._chat_map[pid] = expand_peer(self._client, peer, broadcast=getattr(self._raw, "post", None))
self._chat_map[pid] = expand_peer(
self._client, peer, broadcast=getattr(self._raw, "post", None)
)
return self._chat_map[pid]
@property

View File

@ -31,7 +31,9 @@ class ButtonCallback(Event):
self._chat_map = chat_map
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
def _try_from_update(
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
if isinstance(update, types.UpdateBotCallbackQuery) and update.data is not None:
return cls._create(client, update, chat_map)
else:
@ -81,7 +83,11 @@ class ButtonCallback(Event):
chat = self._chat_map.get(pid) or PeerRef._empty_from_peer(self._raw.peer)
lst = CherryPickedList(self._client, chat._ref, [])
lst._ids.append(types.InputMessageCallbackQuery(id=self._raw.msg_id, query_id=self._raw.query_id))
lst._ids.append(
types.InputMessageCallbackQuery(
id=self._raw.msg_id, query_id=self._raw.query_id
)
)
message = (await lst)[0]
@ -99,7 +105,9 @@ class InlineQuery(Event):
self._raw = update
@classmethod
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
def _try_from_update(
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
) -> Optional[Self]:
if isinstance(update, types.UpdateBotInlineQuery):
return cls._create(update)
else:

View File

@ -133,7 +133,9 @@ def parse(html: str) -> tuple[str, list[MessageEntity]]:
return del_surrogate(parser.text), parser.entities
ENTITY_TO_FORMATTER: dict[Type[MessageEntity], tuple[str, str] | Callable[[Any, str], tuple[str, str]]] = {
ENTITY_TO_FORMATTER: dict[
Type[MessageEntity], tuple[str, str] | Callable[[Any, str], tuple[str, str]]
] = {
MessageEntityBold: ("<strong>", "</strong>"),
MessageEntityItalic: ("<em>", "</em>"),
MessageEntityCode: ("<code>", "</code>"),
@ -194,7 +196,12 @@ def unparse(text: str, entities: Iterable[MessageEntity]) -> str:
while within_surrogate(text, at):
at += 1
text = text[:at] + what + escape(text[at:next_escape_bound]) + text[next_escape_bound:]
text = (
text[:at]
+ what
+ escape(text[at:next_escape_bound])
+ text[next_escape_bound:]
)
next_escape_bound = at
text = escape(text[:next_escape_bound]) + text[next_escape_bound:]

View File

@ -82,7 +82,9 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]:
else:
for entity in reversed(entities):
if isinstance(entity, ty):
setattr(entity, "length", len(message) - getattr(entity, "offset", 0))
setattr(
entity, "length", len(message) - getattr(entity, "offset", 0)
)
break
parsed = MARKDOWN.parse(add_surrogate(message.strip()))
@ -101,15 +103,25 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]:
if token.type in ("blockquote_close", "blockquote_open"):
push(MessageEntityBlockquote)
elif token.type == "code_block":
entities.append(MessageEntityPre(offset=len(message), length=len(token.content), language=""))
entities.append(
MessageEntityPre(
offset=len(message), length=len(token.content), language=""
)
)
message += token.content
elif token.type == "code_inline":
entities.append(MessageEntityCode(offset=len(message), length=len(token.content)))
entities.append(
MessageEntityCode(offset=len(message), length=len(token.content))
)
message += token.content
elif token.type in ("em_close", "em_open"):
push(MessageEntityItalic)
elif token.type == "fence":
entities.append(MessageEntityPre(offset=len(message), length=len(token.content), language=token.info))
entities.append(
MessageEntityPre(
offset=len(message), length=len(token.content), language=token.info
)
)
message += token.content[:-1] # remove a single trailing newline
elif token.type == "hardbreak":
message += "\n"
@ -118,7 +130,9 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]:
elif token.type == "hr":
message += "\u2015\n\n"
elif token.type in ("link_close", "link_open"):
if token.markup != "autolink": # telegram already picks up on these automatically
if (
token.markup != "autolink"
): # telegram already picks up on these automatically
push(MessageEntityTextUrl, url=token.attrs.get("href"))
elif token.type in ("s_close", "s_open"):
push(MessageEntityStrike)

View File

@ -6,7 +6,11 @@ def add_surrogate(text: str) -> str:
return "".join(
# SMP -> Surrogate Pairs (Telegram offsets are calculated with these).
# See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more.
("".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le"))) if (0x10000 <= ord(x) <= 0x10FFFF) else x)
(
"".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
if (0x10000 <= ord(x) <= 0x10FFFF)
else x
)
for x in text
)

View File

@ -52,12 +52,20 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
input_media: abcs.InputMedia
if try_get_url_path(file) is not None:
assert isinstance(file, str)
input_media = types.InputMediaPhotoExternal(spoiler=False, url=file, ttl_seconds=None)
input_media = types.InputMediaPhotoExternal(
spoiler=False, url=file, ttl_seconds=None
)
else:
input_file, _ = await self._client._upload(file, size, "a.jpg")
input_media = types.InputMediaUploadedPhoto(spoiler=False, file=input_file, stickers=None, ttl_seconds=None)
input_media = types.InputMediaUploadedPhoto(
spoiler=False, file=input_file, stickers=None, ttl_seconds=None
)
media = await self._client(functions.messages.upload_media(peer=types.InputPeerSelf(), media=input_media))
media = await self._client(
functions.messages.upload_media(
peer=types.InputPeerSelf(), media=input_media
)
)
assert isinstance(media, types.MessageMediaPhoto)
assert isinstance(media.photo, types.Photo)
input_media = types.InputMediaPhoto(
@ -69,7 +77,9 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
),
ttl_seconds=media.ttl_seconds,
)
message, entities = parse_message(text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True)
message, entities = parse_message(
text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True
)
self._medias.append(
types.InputSingleMedia(
media=input_media,
@ -122,7 +132,9 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
input_media: abcs.InputMedia
if try_get_url_path(file) is not None:
assert isinstance(file, str)
input_media = types.InputMediaDocumentExternal(spoiler=False, url=file, ttl_seconds=None)
input_media = types.InputMediaDocumentExternal(
spoiler=False, url=file, ttl_seconds=None
)
else:
input_file, name = await self._client._upload(file, size, name)
if mime_type is None:
@ -156,7 +168,11 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
ttl_seconds=None,
)
media = await self._client(functions.messages.upload_media(peer=types.InputPeerEmpty(), media=input_media))
media = await self._client(
functions.messages.upload_media(
peer=types.InputPeerEmpty(), media=input_media
)
)
assert isinstance(media, types.MessageMediaDocument)
assert isinstance(media.document, types.Document)
input_media = types.InputMediaDocument(
@ -169,7 +185,9 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
ttl_seconds=media.ttl_seconds,
query=None,
)
message, entities = parse_message(text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True)
message, entities = parse_message(
text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True
)
self._medias.append(
types.InputSingleMedia(
media=input_media,
@ -179,7 +197,9 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
)
)
async def send(self, peer: Peer | PeerRef, *, reply_to: Optional[int] = None) -> list[Message]:
async def send(
self, peer: Peer | PeerRef, *, reply_to: Optional[int] = None
) -> list[Message]:
"""
Send the album.
@ -205,7 +225,11 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
update_stickersets_order=False,
peer=peer._ref._to_input_peer(),
reply_to=(
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None
types.InputReplyToMessage(
reply_to_msg_id=reply_to, top_msg_id=None
)
if reply_to
else None
),
multi_media=self._medias,
schedule_date=None,

View File

@ -50,7 +50,9 @@ class Button(abc.ABC):
def __init__(self, text: str) -> None:
if self.__class__ == Button:
raise TypeError(f"Can't instantiate abstract class {self.__class__.__name__}")
raise TypeError(
f"Can't instantiate abstract class {self.__class__.__name__}"
)
self._raw: RawButtonType = types.KeyboardButton(text=text)
self._msg: Optional[weakref.ReferenceType[Message]] = None

View File

@ -29,6 +29,8 @@ class InlineButton(Button, abc.ABC):
def __init__(self, text: str) -> None:
if self.__class__ == InlineButton:
raise TypeError(f"Can't instantiate abstract class {self.__class__.__name__}")
raise TypeError(
f"Can't instantiate abstract class {self.__class__.__name__}"
)
else:
super().__init__(text)

View File

@ -14,7 +14,9 @@ class SwitchInline(InlineButton):
def __init__(self, text: str, query: Optional[str] = None) -> None:
super().__init__(text)
self._raw = types.KeyboardButtonSwitchInline(same_peer=False, text=text, query=query or "", peer_types=None)
self._raw = types.KeyboardButtonSwitchInline(
same_peer=False, text=text, query=query or "", peer_types=None
)
@property
def query(self) -> str:

View File

@ -111,7 +111,9 @@ class ChatRestriction(Enum):
return set(filter(None, iter(restrictions)))
@classmethod
def _set_to_raw(cls, restrictions: set[ChatRestriction], until_date: int) -> types.ChatBannedRights:
def _set_to_raw(
cls, restrictions: set[ChatRestriction], until_date: int
) -> types.ChatBannedRights:
return types.ChatBannedRights(
view_messages=cls.VIEW_MESSAGES in restrictions,
send_messages=cls.SEND_MESSAGES in restrictions,

View File

@ -90,6 +90,9 @@ class Dialog(metaclass=NoPublicConstructor):
if isinstance(self._raw, types.Dialog):
return self._raw.unread_count
elif isinstance(self._raw, types.DialogPeerFolder):
return self._raw.unread_unmuted_messages_count + self._raw.unread_muted_messages_count
return (
self._raw.unread_unmuted_messages_count
+ self._raw.unread_muted_messages_count
)
else:
raise RuntimeError("unexpected case")

View File

@ -37,7 +37,9 @@ class Draft(metaclass=NoPublicConstructor):
self._chat_map = chat_map
@classmethod
def _from_raw_update(cls, client: Client, draft: types.UpdateDraftMessage, chat_map: dict[int, Peer]) -> Self:
def _from_raw_update(
cls, client: Client, draft: types.UpdateDraftMessage, chat_map: dict[int, Peer]
) -> Self:
return cls._create(client, draft.peer, draft.top_msg_id, draft.draft, chat_map)
@classmethod
@ -58,7 +60,9 @@ class Draft(metaclass=NoPublicConstructor):
This is also the chat where the message will be sent to by :meth:`send`.
"""
return self._chat_map.get(peer_id(self._peer)) or expand_peer(self._client, self._peer, broadcast=None)
return self._chat_map.get(peer_id(self._peer)) or expand_peer(
self._client, self._peer, broadcast=None
)
@property
def link_preview(self) -> bool:
@ -87,7 +91,9 @@ class Draft(metaclass=NoPublicConstructor):
The :attr:`~Message.text_html` of the message that will be sent.
"""
if text := getattr(self._raw, "message", None):
return generate_html_message(text, getattr(self._raw, "entities", None) or [])
return generate_html_message(
text, getattr(self._raw, "entities", None) or []
)
else:
return None
@ -97,7 +103,9 @@ class Draft(metaclass=NoPublicConstructor):
The :attr:`~Message.text_markdown` of the message that will be sent.
"""
if text := getattr(self._raw, "message", None):
return generate_markdown_message(text, getattr(self._raw, "entities", None) or [])
return generate_markdown_message(
text, getattr(self._raw, "entities", None) or []
)
else:
return None
@ -107,7 +115,11 @@ class Draft(metaclass=NoPublicConstructor):
The date when the draft was last updated.
"""
date = getattr(self._raw, "date", None)
return datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) if date is not None else None
return (
datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc)
if date is not None
else None
)
async def edit(
self,
@ -180,7 +192,11 @@ class Draft(metaclass=NoPublicConstructor):
noforwards=False,
update_stickersets_order=False,
peer=self._peer_ref()._to_input_peer(),
reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None),
reply_to=(
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
if reply_to
else None
),
message=message,
random_id=random_id,
reply_markup=None,
@ -195,7 +211,11 @@ class Draft(metaclass=NoPublicConstructor):
{},
out=result.out,
id=result.id,
from_id=(types.PeerUser(user_id=self._client._session.user.id) if self._client._session.user else None),
from_id=(
types.PeerUser(user_id=self._client._session.user.id)
if self._client._session.user
else None
),
peer_id=self._peer_ref()._to_peer(),
reply_to=(
types.MessageReplyHeader(
@ -215,7 +235,9 @@ class Draft(metaclass=NoPublicConstructor):
ttl_period=result.ttl_period,
)
else:
return self._client._build_message_map(result, self._peer_ref()).with_random_id(random_id)
return self._client._build_message_map(
result, self._peer_ref()
).with_random_id(random_id)
async def delete(self) -> None:
"""

View File

@ -29,7 +29,11 @@ def photo_size_byte_count(size: abcs.PhotoSize) -> int:
elif isinstance(size, types.PhotoSizeProgressive):
return max(size.sizes)
elif isinstance(size, types.PhotoStrippedSize):
return len(stripped_size_header) + (len(size.bytes) - 3) + len(stripped_size_footer)
return (
len(stripped_size_header)
+ (len(size.bytes) - 3)
+ len(stripped_size_footer)
)
else:
raise RuntimeError("unexpected case")
@ -176,7 +180,9 @@ class File(metaclass=NoPublicConstructor):
self._client = client
@classmethod
def _try_from_raw_message_media(cls, client: Client, raw: abcs.MessageMedia) -> Optional[Self]:
def _try_from_raw_message_media(
cls, client: Client, raw: abcs.MessageMedia
) -> Optional[Self]:
if isinstance(raw, types.MessageMediaDocument):
if raw.document:
return cls._try_from_raw_document(
@ -198,9 +204,13 @@ class File(metaclass=NoPublicConstructor):
elif isinstance(raw, types.MessageMediaWebPage):
if isinstance(raw.webpage, types.WebPage):
if raw.webpage.document:
return cls._try_from_raw_document(client, raw.webpage.document, orig_raw=raw)
return cls._try_from_raw_document(
client, raw.webpage.document, orig_raw=raw
)
if raw.webpage.photo:
return cls._try_from_raw_photo(client, raw.webpage.photo, orig_raw=raw)
return cls._try_from_raw_photo(
client, raw.webpage.photo, orig_raw=raw
)
return None
@ -219,13 +229,21 @@ class File(metaclass=NoPublicConstructor):
attributes=raw.attributes,
size=raw.size,
name=next(
(a.file_name for a in raw.attributes if isinstance(a, types.DocumentAttributeFilename)),
(
a.file_name
for a in raw.attributes
if isinstance(a, types.DocumentAttributeFilename)
),
"",
),
mime=raw.mime_type,
photo=False,
muted=next(
(a.nosound for a in raw.attributes if isinstance(a, types.DocumentAttributeVideo)),
(
a.nosound
for a in raw.attributes
if isinstance(a, types.DocumentAttributeVideo)
),
False,
),
input_media=types.InputMediaDocument(
@ -343,7 +361,9 @@ class File(metaclass=NoPublicConstructor):
return dim.w
for attr in self._attributes:
if isinstance(attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)):
if isinstance(
attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)
):
return attr.w
return None
@ -357,7 +377,9 @@ class File(metaclass=NoPublicConstructor):
return dim.h
for attr in self._attributes:
if isinstance(attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)):
if isinstance(
attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)
):
return attr.h
return None
@ -388,7 +410,9 @@ class File(metaclass=NoPublicConstructor):
id=self._input_media.id.id,
access_hash=self._input_media.id.access_hash,
file_reference=self._input_media.id.file_reference,
thumb_size=(self._thumb.type if isinstance(self._thumb, thumb_types) else ""),
thumb_size=(
self._thumb.type if isinstance(self._thumb, thumb_types) else ""
),
)
elif isinstance(self._input_media, types.InputMediaPhoto):
assert isinstance(self._input_media.id, types.InputPhoto)

View File

@ -12,11 +12,16 @@ def _build_keyboard_rows(
) -> list[abcs.KeyboardButtonRow]:
# list[button] -> list[list[button]]
# This does allow for "invalid" inputs (mixing lists and non-lists), but that's acceptable.
buttons_lists_iter = [button if isinstance(button, list) else [button] for button in (btns or [])]
buttons_lists_iter = [
button if isinstance(button, list) else [button] for button in (btns or [])
]
# Remove empty rows (also making it easy to check if all-empty).
buttons_lists = [bs for bs in buttons_lists_iter if bs]
return [types.KeyboardButtonRow(buttons=[btn._raw for btn in btns]) for btns in buttons_lists]
return [
types.KeyboardButtonRow(buttons=[btn._raw for btn in btns])
for btns in buttons_lists
]
class Keyboard:
@ -44,7 +49,9 @@ class Keyboard:
class InlineKeyboard:
__slots__ = ("_raw",)
def __init__(self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]]) -> None:
def __init__(
self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]]
) -> None:
self._raw = types.ReplyInlineMarkup(rows=_build_keyboard_rows(buttons))

View File

@ -34,7 +34,11 @@ def generate_random_id() -> int:
def adapt_date(date: Optional[int]) -> Optional[datetime.datetime]:
return datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) if date is not None else None
return (
datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc)
if date is not None
else None
)
class Message(metaclass=NoPublicConstructor):
@ -55,14 +59,20 @@ class Message(metaclass=NoPublicConstructor):
print('Found empty message with ID', message.id)
"""
def __init__(self, client: Client, message: abcs.Message, chat_map: dict[int, Peer]) -> None:
assert isinstance(message, (types.Message, types.MessageService, types.MessageEmpty))
def __init__(
self, client: Client, message: abcs.Message, chat_map: dict[int, Peer]
) -> None:
assert isinstance(
message, (types.Message, types.MessageService, types.MessageEmpty)
)
self._client = client
self._raw = message
self._chat_map = chat_map
@classmethod
def _from_raw(cls, client: Client, message: abcs.Message, chat_map: dict[int, Peer]) -> Self:
def _from_raw(
cls, client: Client, message: abcs.Message, chat_map: dict[int, Peer]
) -> Self:
return cls._create(client, message, chat_map)
@classmethod
@ -148,7 +158,9 @@ class Message(metaclass=NoPublicConstructor):
See :ref:`formatting` to learn the HTML elements used.
"""
if text := getattr(self._raw, "message", None):
return generate_html_message(text, getattr(self._raw, "entities", None) or [])
return generate_html_message(
text, getattr(self._raw, "entities", None) or []
)
else:
return None
@ -160,7 +172,9 @@ class Message(metaclass=NoPublicConstructor):
See :ref:`formatting` to learn the formatting characters used.
"""
if text := getattr(self._raw, "message", None):
return generate_markdown_message(text, getattr(self._raw, "entities", None) or [])
return generate_markdown_message(
text, getattr(self._raw, "entities", None) or []
)
else:
return None
@ -179,7 +193,9 @@ class Message(metaclass=NoPublicConstructor):
peer = self._raw.peer_id or types.PeerUser(user_id=0)
pid = peer_id(peer)
if pid not in self._chat_map:
self._chat_map[pid] = expand_peer(self._client, peer, broadcast=getattr(self._raw, "post", None))
self._chat_map[pid] = expand_peer(
self._client, peer, broadcast=getattr(self._raw, "post", None)
)
return self._chat_map[pid]
@property
@ -223,7 +239,14 @@ class Message(metaclass=NoPublicConstructor):
This can also be used as a way to check that the message media is an audio.
"""
audio = self._file()
return audio if audio and any(isinstance(a, types.DocumentAttributeAudio) for a in audio._attributes) else None
return (
audio
if audio
and any(
isinstance(a, types.DocumentAttributeAudio) for a in audio._attributes
)
else None
)
@property
def video(self) -> Optional[File]:
@ -233,7 +256,14 @@ class Message(metaclass=NoPublicConstructor):
This can also be used as a way to check that the message media is a video.
"""
audio = self._file()
return audio if audio and any(isinstance(a, types.DocumentAttributeVideo) for a in audio._attributes) else None
return (
audio
if audio
and any(
isinstance(a, types.DocumentAttributeVideo) for a in audio._attributes
)
else None
)
@property
def file(self) -> Optional[File]:
@ -447,7 +477,10 @@ class Message(metaclass=NoPublicConstructor):
return None
return [
[create_button(self, button) for button in cast(types.KeyboardButtonRow, row).buttons]
[
create_button(self, button)
for button in cast(types.KeyboardButtonRow, row).buttons
]
for row in markup.rows
]
@ -473,8 +506,13 @@ class Message(metaclass=NoPublicConstructor):
return not isinstance(self._raw, types.MessageEmpty)
def build_msg_map(client: Client, messages: Sequence[abcs.Message], chat_map: dict[int, Peer]) -> dict[int, Message]:
return {msg.id: msg for msg in (Message._from_raw(client, m, chat_map) for m in messages)}
def build_msg_map(
client: Client, messages: Sequence[abcs.Message], chat_map: dict[int, Peer]
) -> dict[int, Message]:
return {
msg.id: msg
for msg in (Message._from_raw(client, m, chat_map) for m in messages)
}
def parse_message(

View File

@ -16,16 +16,23 @@ class Final(abc.ABCMeta):
cls_namespace: dict[str, object],
) -> "Final":
# Allow subclassing while within telethon._impl (or other package names).
allowed_base = Final.__module__[: Final.__module__.find(".", Final.__module__.find(".") + 1)]
allowed_base = Final.__module__[
: Final.__module__.find(".", Final.__module__.find(".") + 1)
]
for base in bases:
if isinstance(base, Final) and not base.__module__.startswith(allowed_base):
raise TypeError(f"{base.__module__}.{base.__qualname__} does not support" " subclassing")
raise TypeError(
f"{base.__module__}.{base.__qualname__} does not support"
" subclassing"
)
return super().__new__(cls, name, bases, cls_namespace)
class NoPublicConstructor(Final):
def __call__(cls, *args: Any, **kwds: Any) -> Any:
raise TypeError(f"{cls.__module__}.{cls.__qualname__} has no public constructor")
raise TypeError(
f"{cls.__module__}.{cls.__qualname__} has no public constructor"
)
@property
def _create(cls: Type[T]) -> Type[T]:

View File

@ -157,16 +157,22 @@ class Participant(metaclass=NoPublicConstructor):
"""
:data:`True` if the participant is the creator of the chat.
"""
return isinstance(self._raw, (types.ChannelParticipantCreator, types.ChatParticipantCreator))
return isinstance(
self._raw, (types.ChannelParticipantCreator, types.ChatParticipantCreator)
)
@property
def admin_rights(self) -> Optional[set[AdminRight]]:
"""
The set of administrator rights this participant has been granted, if they are an administrator.
"""
if isinstance(self._raw, (types.ChannelParticipantCreator, types.ChannelParticipantAdmin)):
if isinstance(
self._raw, (types.ChannelParticipantCreator, types.ChannelParticipantAdmin)
):
return AdminRight._from_raw(self._raw.admin_rights)
elif isinstance(self._raw, (types.ChatParticipantCreator, types.ChatParticipantAdmin)):
elif isinstance(
self._raw, (types.ChatParticipantCreator, types.ChatParticipantAdmin)
):
return AdminRight._chat_rights()
else:
return None
@ -188,9 +194,13 @@ class Participant(metaclass=NoPublicConstructor):
participant = self.user or self.banned or self.left
assert participant
if isinstance(participant, User):
await self._client.set_participant_admin_rights(self._chat, participant, rights)
await self._client.set_participant_admin_rights(
self._chat, participant, rights
)
else:
raise TypeError(f"participant of type {participant.__class__.__name__} cannot be made admin")
raise TypeError(
f"participant of type {participant.__class__.__name__} cannot be made admin"
)
async def set_restrictions(
self,
@ -203,4 +213,6 @@ class Participant(metaclass=NoPublicConstructor):
"""
participant = self.user or self.banned or self.left
assert participant
await self._client.set_participant_restrictions(self._chat, participant, restrictions, until=until)
await self._client.set_participant_restrictions(
self._chat, participant, restrictions, until=until
)

View File

@ -15,7 +15,9 @@ if TYPE_CHECKING:
from ...client.client import Client
def build_chat_map(client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]) -> dict[int, Peer]:
def build_chat_map(
client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]
) -> dict[int, Peer]:
users_iter = (User._from_raw(u) for u in users)
chats_iter = (
(
@ -43,7 +45,9 @@ def build_chat_map(client: Client, users: Sequence[abcs.User], chats: Sequence[a
for x in v:
print(x, file=sys.stderr)
raise RuntimeError(f"chat identifier collision: {k}; please report this")
raise RuntimeError(
f"chat identifier collision: {k}; please report this"
)
return result
@ -77,7 +81,11 @@ def expand_peer(client: Client, peer: abcs.Peer, *, broadcast: Optional[bool]) -
until_date=None,
)
return Channel._from_raw(channel) if broadcast else Group._from_raw(client, channel)
return (
Channel._from_raw(channel)
if broadcast
else Group._from_raw(client, channel)
)
else:
raise RuntimeError("unexpected case")

View File

@ -24,7 +24,13 @@ class Group(Peer, metaclass=NoPublicConstructor):
def __init__(
self,
client: Client,
chat: (types.ChatEmpty | types.Chat | types.ChatForbidden | types.Channel | types.ChannelForbidden),
chat: (
types.ChatEmpty
| types.Chat
| types.ChatForbidden
| types.Channel
| types.ChannelForbidden
),
) -> None:
self._client = client
self._raw = chat
@ -90,4 +96,6 @@ class Group(Peer, metaclass=NoPublicConstructor):
"""
Alias for :meth:`telethon.Client.set_chat_default_restrictions`.
"""
await self._client.set_chat_default_restrictions(self, restrictions, until=until)
await self._client.set_chat_default_restrictions(
self, restrictions, until=until
)

View File

@ -1,10 +1,16 @@
try:
import cryptg # type: ignore [import-untyped]
def ige_encrypt(plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: # noqa: F811
return cryptg.encrypt_ige(bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv)
def ige_encrypt(
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes: # noqa: F811
return cryptg.encrypt_ige(
bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv
)
def ige_decrypt(ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: # noqa: F811
def ige_decrypt(
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes: # noqa: F811
return cryptg.decrypt_ige(
bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext,
key,
@ -14,7 +20,9 @@ try:
except ImportError:
import pyaes # type: ignore [import-untyped]
def ige_encrypt(plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes:
def ige_encrypt(
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes:
assert len(plaintext) % 16 == 0
assert len(iv) == 32
@ -27,7 +35,10 @@ except ImportError:
for block_offset in range(0, len(plaintext), 16):
plaintext_block = plaintext[block_offset : block_offset + 16]
ciphertext_block = bytes(
a ^ b for a, b in zip(aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2)
a ^ b
for a, b in zip(
aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2
)
)
iv1 = ciphertext_block
iv2 = plaintext_block
@ -36,7 +47,9 @@ except ImportError:
return bytes(ciphertext)
def ige_decrypt(ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes:
def ige_decrypt(
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes:
assert len(ciphertext) % 16 == 0
assert len(iv) == 32
@ -49,7 +62,10 @@ except ImportError:
for block_offset in range(0, len(ciphertext), 16):
ciphertext_block = ciphertext[block_offset : block_offset + 16]
plaintext_block = bytes(
a ^ b for a, b in zip(aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1)
a ^ b
for a, b in zip(
aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1
)
)
iv1 = ciphertext_block
iv2 = plaintext_block

View File

@ -20,4 +20,8 @@ class AuthKey:
return self.data
def calc_new_nonce_hash(self, new_nonce: int, number: int) -> int:
return int.from_bytes(sha1(new_nonce.to_bytes(32) + number.to_bytes(1) + self.aux_hash).digest()[4:])
return int.from_bytes(
sha1(new_nonce.to_bytes(32) + number.to_bytes(1) + self.aux_hash).digest()[
4:
]
)

View File

@ -19,7 +19,9 @@ class CalcKey(NamedTuple):
# https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
def calc_key(auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side) -> CalcKey:
def calc_key(
auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side
) -> CalcKey:
x = int(side)
# sha256_a = SHA256 (msg_key + substr (auth_key, x, 36))
@ -41,8 +43,12 @@ def determine_padding_v2_length(length: int) -> int:
return 16 + (16 - (length % 16))
def _do_encrypt_data_v2(plaintext: bytes, auth_key: AuthKey, random_padding: bytes) -> bytes:
padded_plaintext = plaintext + random_padding[: determine_padding_v2_length(len(plaintext))]
def _do_encrypt_data_v2(
plaintext: bytes, auth_key: AuthKey, random_padding: bytes
) -> bytes:
padded_plaintext = (
plaintext + random_padding[: determine_padding_v2_length(len(plaintext))]
)
side = Side.CLIENT
x = int(side)
@ -64,7 +70,9 @@ def encrypt_data_v2(plaintext: bytes, auth_key: AuthKey) -> bytes:
return _do_encrypt_data_v2(plaintext, auth_key, random_padding)
def decrypt_data_v2(ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey) -> bytes:
def decrypt_data_v2(
ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey
) -> bytes:
side = Side.SERVER
x = int(side)

View File

@ -34,13 +34,17 @@ def encrypt_hashed(data: bytes, key: PublicKey, random_bytes: bytes) -> bytes:
temp_key = random_bytes[192 + 32 * attempt : 192 + 32 * attempt + 32]
# data_with_hash := data_pad_reversed + SHA256(temp_key + data_with_padding); -- after this assignment, data_with_hash is exactly 224 bytes long.
data_with_hash = data_pad_reversed + sha256(temp_key + data_with_padding).digest()
data_with_hash = (
data_pad_reversed + sha256(temp_key + data_with_padding).digest()
)
# aes_encrypted := AES256_IGE(data_with_hash, temp_key, 0); -- AES256-IGE encryption with zero IV.
aes_encrypted = ige_encrypt(data_with_hash, temp_key, bytes(32))
# temp_key_xor := temp_key XOR SHA256(aes_encrypted); -- adjusted key, 32 bytes
temp_key_xor = bytes(a ^ b for a, b in zip(temp_key, sha256(aes_encrypted).digest()))
temp_key_xor = bytes(
a ^ b for a, b in zip(temp_key, sha256(aes_encrypted).digest())
)
# key_aes_encrypted := temp_key_xor + aes_encrypted; -- exactly 256 bytes (2048 bits) long
key_aes_encrypted = temp_key_xor + aes_encrypted
@ -83,4 +87,6 @@ j4WcDuXc2CTHgH8gFTNhp/Y8/SpDOhvn9QIDAQAB
)
RSA_KEYS = {compute_fingerprint(key): key for key in (PRODUCTION_RSA_KEY, TESTMODE_RSA_KEY)}
RSA_KEYS = {
compute_fingerprint(key): key for key in (PRODUCTION_RSA_KEY, TESTMODE_RSA_KEY)
}

View File

@ -20,7 +20,9 @@ def h(*data: bytes | bytearray | memoryview) -> bytes:
# SH(data, salt) := H(salt | data | salt)
def sh(data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview) -> bytes:
def sh(
data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview
) -> bytes:
return h(salt, data, salt)

View File

@ -108,9 +108,13 @@ def _do_step2(data: Step1, response: bytes, random_bytes: bytes) -> tuple[bytes,
)
try:
fingerprint = next(fp for fp in res_pq.server_public_key_fingerprints if fp in RSA_KEYS)
fingerprint = next(
fp for fp in res_pq.server_public_key_fingerprints if fp in RSA_KEYS
)
except StopIteration:
raise ValueError(f"unknown fingerprints: {res_pq.server_public_key_fingerprints}")
raise ValueError(
f"unknown fingerprints: {res_pq.server_public_key_fingerprints}"
)
key = RSA_KEYS[fingerprint]
ciphertext = encrypt_hashed(pq_inner_data, key, random_bytes)
@ -129,7 +133,9 @@ def step2(data: Step1, response: bytes) -> tuple[bytes, Step2]:
return _do_step2(data, response, os.urandom(288))
def _do_step3(data: Step2, response: bytes, random_bytes: bytes, now: int) -> tuple[bytes, Step3]:
def _do_step3(
data: Step2, response: bytes, random_bytes: bytes, now: int
) -> tuple[bytes, Step3]:
assert len(random_bytes) == 272
nonce = data.nonce
@ -152,7 +158,9 @@ def _do_step3(data: Step2, response: bytes, random_bytes: bytes, now: int) -> tu
check_server_nonce(server_dh_params.server_nonce, server_nonce)
if len(server_dh_params.encrypted_answer) % 16 != 0:
raise ValueError(f"encrypted response not padded with size: {len(server_dh_params.encrypted_answer)}")
raise ValueError(
f"encrypted response not padded with size: {len(server_dh_params.encrypted_answer)}"
)
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
assert isinstance(server_dh_params.encrypted_answer, bytes)
@ -164,7 +172,9 @@ def _do_step3(data: Step2, response: bytes, random_bytes: bytes, now: int) -> tu
server_dh_inner = AbcServerDhInnerData._read_from(plain_text_reader)
assert isinstance(server_dh_inner, ServerDhInnerData)
expected_answer_hash = sha1(plain_text_answer[20 : 20 + plain_text_reader._pos]).digest()
expected_answer_hash = sha1(
plain_text_answer[20 : 20 + plain_text_reader._pos]
).digest()
if got_answer_hash != expected_answer_hash:
raise ValueError("invalid answer hash")
@ -203,11 +213,15 @@ def _do_step3(data: Step2, response: bytes, random_bytes: bytes, now: int) -> tu
)
client_dh_inner_hashed = sha1(client_dh_inner).digest() + client_dh_inner
client_dh_inner_hashed += random_bytes[: (16 - (len(client_dh_inner_hashed) % 16)) % 16]
client_dh_inner_hashed += random_bytes[
: (16 - (len(client_dh_inner_hashed) % 16)) % 16
]
client_dh_encrypted = encrypt_ige(client_dh_inner_hashed, key, iv)
return set_client_dh_params(nonce=nonce, server_nonce=server_nonce, encrypted_data=client_dh_encrypted), Step3(
return set_client_dh_params(
nonce=nonce, server_nonce=server_nonce, encrypted_data=client_dh_encrypted
), Step3(
nonce=nonce,
server_nonce=server_nonce,
new_nonce=new_nonce,
@ -263,7 +277,10 @@ def create_key(data: Step3, response: bytes) -> CreatedKey:
first_salt = struct.unpack(
"<q",
bytes(a ^ b for a, b in zip(new_nonce.to_bytes(32)[:8], server_nonce.to_bytes(16)[:8])),
bytes(
a ^ b
for a, b in zip(new_nonce.to_bytes(32)[:8], server_nonce.to_bytes(16)[:8])
),
)[0]
if dh_gen.nonce_number == 1:

View File

@ -107,7 +107,9 @@ class Encrypted(Mtp):
) -> None:
self._auth_key = auth_key
self._time_offset: int = time_offset or 0
self._salts: list[FutureSalt] = [FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)]
self._salts: list[FutureSalt] = [
FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)
]
self._start_salt_time: Optional[tuple[int, float]] = None
self._compression_threshold = compression_threshold
self._deserialization: list[Deserialization] = []
@ -201,7 +203,9 @@ class Encrypted(Mtp):
if self._msg_count == 1:
del self._buffer[:CONTAINER_HEADER_LEN]
self._buffer[:HEADER_LEN] = struct.pack("<qq", self._get_current_salt(), self._client_id)
self._buffer[:HEADER_LEN] = struct.pack(
"<qq", self._get_current_salt(), self._client_id
)
if self._msg_count != 1:
self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
@ -279,7 +283,11 @@ class Encrypted(Mtp):
if isinstance(bad_msg, BadServerSalt):
self._salts.clear()
self._salts.append(FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt))
self._salts.append(
FutureSalt(
valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt
)
)
self._salt_request_msg_id = None
elif bad_msg.error_code in (16, 17):
self._correct_time_offset(message.msg_id)
@ -316,7 +324,9 @@ class Encrypted(Mtp):
# Response to internal request, do not propagate.
self._salt_request_msg_id = None
else:
self._deserialization.append(RpcResult(MsgId(salts.req_msg_id), message.body))
self._deserialization.append(
RpcResult(MsgId(salts.req_msg_id), message.body)
)
self._start_salt_time = (salts.now, self._adjusted_now())
self._salts = list(salts.salts)
@ -336,7 +346,11 @@ class Encrypted(Mtp):
def _handle_new_session_created(self, message: Message) -> None:
new_session = NewSessionCreated.from_bytes(message.body)
self._salts.clear()
self._salts.append(FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=new_session.server_salt))
self._salts.append(
FutureSalt(
valid_since=0, valid_until=0x7FFFFFFF, salt=new_session.server_salt
)
)
def _handle_container(self, message: Message) -> None:
container = MsgContainer.from_bytes(message.body)
@ -362,7 +376,11 @@ class Encrypted(Mtp):
self._deserialization.append(Update(message.body))
def _try_request_salts(self) -> None:
if len(self._salts) == 1 and self._salt_request_msg_id is None and self._get_current_salt() != 0:
if (
len(self._salts) == 1
and self._salt_request_msg_id is None
and self._get_current_salt() != 0
):
# If salts are requested in a container leading to bad_msg,
# the bad_msg_id will refer to the container, not the salts request.
#
@ -370,7 +388,9 @@ class Encrypted(Mtp):
# This would break, because we couldn't identify the response.
#
# So salts are only requested once we have a valid salt to reduce the chances of this happening.
self._salt_request_msg_id = self._serialize_msg(bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True)
self._salt_request_msg_id = self._serialize_msg(
bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True
)
def push(self, request: bytes) -> Optional[MsgId]:
if self._start_salt_time and len(self._salts) >= 2:
@ -415,7 +435,9 @@ class Encrypted(Mtp):
return MsgId(self._last_msg_id), encrypt_data_v2(result, self._auth_key)
def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]:
def deserialize(
self, payload: bytes | bytearray | memoryview
) -> list[Deserialization]:
check_message_buffer(payload)
plaintext = decrypt_data_v2(payload, self._auth_key)

View File

@ -31,7 +31,9 @@ class Plain(Mtp):
self._buffer.clear()
return MsgId(0), result
def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]:
def deserialize(
self, payload: bytes | bytearray | memoryview
) -> list[Deserialization]:
check_message_buffer(payload)
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
@ -46,6 +48,8 @@ class Plain(Mtp):
raise ValueError(f"bad length: expected >= 0, got: {length}")
if 20 + length > len(payload):
raise ValueError(f"message too short, expected: {20 + length}, got {len(payload)}")
raise ValueError(
f"message too short, expected: {20 + length}, got {len(payload)}"
)
return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))]

View File

@ -116,7 +116,11 @@ class RpcError(ValueError):
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return self._code == other._code and self._name == other._name and self._value == other._value
return (
self._code == other._code
and self._name == other._name
and self._value == other._value
)
# https://core.telegram.org/mtproto/service_messages_about_messages
@ -152,7 +156,9 @@ class BadMessage(ValueError):
self.msg_id = msg_id
self._code = code
self._caused_by = caused_by
self.severity = logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR
self.severity = (
logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR
)
@property
def code(self) -> int:
@ -195,7 +201,9 @@ class Mtp(ABC):
"""
@abstractmethod
def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]:
def deserialize(
self, payload: bytes | bytearray | memoryview
) -> list[Deserialization]:
"""
Deserialize incoming buffer payload.
"""

View File

@ -42,7 +42,10 @@ class Intermediate(Transport):
raise MissingBytes(expected=length, got=len(input))
if length <= 4:
if length >= 4 and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0:
if (
length >= 4
and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0
):
raise BadStatus(status=-status)
raise ValueError(f"bad length, expected > 0, got: {length}")

View File

@ -12,7 +12,9 @@ MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
def check_message_buffer(message: bytes | bytearray | memoryview) -> None:
if len(message) < 20:
raise ValueError(f"server payload is too small to be a valid message: {message.hex()}")
raise ValueError(
f"server payload is too small to be a valid message: {message.hex()}"
)
# https://core.telegram.org/mtproto/description#content-related-message

View File

@ -313,7 +313,13 @@ class Sender:
def _on_ping_timeout(self) -> None:
ping_id = generate_random_id()
self._enqueue_body(bytes(ping_delay_disconnect(ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT)))
self._enqueue_body(
bytes(
ping_delay_disconnect(
ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT
)
)
)
self._next_ping = asyncio.get_running_loop().time() + PING_DELAY
def _process_mtp_buffer(self, updates: list[Updates]) -> None:
@ -329,7 +335,9 @@ class Sender:
else:
self._process_bad_message(result)
def _process_update(self, updates: list[Updates], update: bytes | bytearray | memoryview) -> None:
def _process_update(
self, updates: list[Updates], update: bytes | bytearray | memoryview
) -> None:
try:
updates.append(Updates.from_bytes(update))
except ValueError:
@ -433,7 +441,9 @@ class Sender:
req.state.msg_id == msg_id or req.state.container_msg_id == msg_id
):
raise RuntimeError("got response for unsent request")
elif isinstance(req.state, Sent) and (req.state.msg_id == msg_id or req.state.container_msg_id == msg_id):
elif isinstance(req.state, Sent) and (
req.state.msg_id == msg_id or req.state.container_msg_id == msg_id
):
yield self._requests.pop(i)
@property

View File

@ -65,7 +65,9 @@ class ChatHashCache:
return self._has_peer(peer.peer)
elif isinstance(peer, types.NotifyForumTopic):
return self._has_peer(peer.peer)
elif isinstance(peer, (types.NotifyUsers, types.NotifyChats, types.NotifyBroadcasts)):
elif isinstance(
peer, (types.NotifyUsers, types.NotifyChats, types.NotifyBroadcasts)
):
return True
else:
raise RuntimeError("unexpected case")
@ -118,7 +120,9 @@ class ChatHashCache:
elif isinstance(participant, types.ChannelParticipantAdmin):
return (
self._has(participant.user_id)
and (participant.inviter_id is None or self._has(participant.inviter_id))
and (
participant.inviter_id is None or self._has(participant.inviter_id)
)
and self._has(participant.promoted_by)
)
elif isinstance(participant, types.ChannelParticipantBanned):

View File

@ -43,8 +43,12 @@ class PeerRef(abc.ABC):
__slots__ = ("identifier", "authorization")
def __init__(self, identifier: PeerIdentifier, authorization: PeerAuth = None) -> None:
assert identifier >= 0, "PeerRef identifiers must be positive; see the documentation for Peers"
def __init__(
self, identifier: PeerIdentifier, authorization: PeerAuth = None
) -> None:
assert (
identifier >= 0
), "PeerRef identifiers must be positive; see the documentation for Peers"
self.identifier = identifier
self.authorization = authorization
@ -79,7 +83,9 @@ class PeerRef(abc.ABC):
authorization: Optional[int] = None
else:
try:
(authorization,) = struct.unpack("!q", base64.urlsafe_b64decode(auth.encode("ascii") + b"="))
(authorization,) = struct.unpack(
"!q", base64.urlsafe_b64decode(auth.encode("ascii") + b"=")
)
except Exception:
raise ValueError(f"invalid PeerRef string: {string!r}")
@ -131,14 +137,21 @@ class PeerRef(abc.ABC):
if self.authorization is None:
auth = "0"
else:
auth = base64.urlsafe_b64encode(struct.pack("!q", self.authorization)).decode("ascii").rstrip("=")
auth = (
base64.urlsafe_b64encode(struct.pack("!q", self.authorization))
.decode("ascii")
.rstrip("=")
)
return f"{self.identifier}.{auth}"
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return self.identifier == other.identifier and self.authorization == other.authorization
return (
self.identifier == other.identifier
and self.authorization == other.authorization
)
@property
def _ref(self) -> UserRef | GroupRef | ChannelRef:
@ -169,12 +182,16 @@ class UserRef(PeerRef):
def _to_input_peer(self) -> abcs.InputPeer:
if self.identifier == SELF_USER_SENTINEL_ID:
return types.InputPeerSelf()
return types.InputPeerUser(user_id=self.identifier, access_hash=self.authorization or 0)
return types.InputPeerUser(
user_id=self.identifier, access_hash=self.authorization or 0
)
def _to_input_user(self) -> abcs.InputUser:
if self.identifier == SELF_USER_SENTINEL_ID:
return types.InputUserSelf()
return types.InputUser(user_id=self.identifier, access_hash=self.authorization or 0)
return types.InputUser(
user_id=self.identifier, access_hash=self.authorization or 0
)
def __str__(self) -> str:
return f"{USER_PREFIX}{self._encode_str()}"
@ -235,10 +252,14 @@ class ChannelRef(PeerRef):
return types.PeerChannel(channel_id=self.identifier)
def _to_input_peer(self) -> abcs.InputPeer:
return types.InputPeerChannel(channel_id=self.identifier, access_hash=self.authorization or 0)
return types.InputPeerChannel(
channel_id=self.identifier, access_hash=self.authorization or 0
)
def _to_input_channel(self) -> types.InputChannel:
return types.InputChannel(channel_id=self.identifier, access_hash=self.authorization or 0)
return types.InputChannel(
channel_id=self.identifier, access_hash=self.authorization or 0
)
def __str__(self) -> str:
return f"{CHANNEL_PREFIX}{self._encode_str()}"

View File

@ -27,7 +27,9 @@ def update_short(short: types.UpdateShort) -> types.UpdatesCombined:
)
def update_short_message(short: types.UpdateShortMessage, self_id: int) -> types.UpdatesCombined:
def update_short_message(
short: types.UpdateShortMessage, self_id: int
) -> types.UpdatesCombined:
return update_short(
types.UpdateShort(
update=types.UpdateNewMessage(
@ -44,7 +46,9 @@ def update_short_message(short: types.UpdateShortMessage, self_id: int) -> types
noforwards=False,
reactions=None,
id=short.id,
from_id=types.PeerUser(user_id=self_id if short.out else short.user_id),
from_id=types.PeerUser(
user_id=self_id if short.out else short.user_id
),
peer_id=types.PeerChat(
chat_id=short.user_id,
),

View File

@ -38,7 +38,9 @@ class PtsInfo:
self.entry = entry
def __repr__(self) -> str:
return f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})"
return (
f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})"
)
class State:
@ -68,7 +70,9 @@ class PossibleGap:
self.updates = updates
def __repr__(self) -> str:
return f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})"
return (
f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})"
)
class PrematureEndReason(Enum):

View File

@ -91,9 +91,13 @@ class MessageBox:
self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline)
if state.qts != NO_SEQ:
self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline)
self.map.update((s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels)
self.map.update(
(s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels
)
self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc)
self.date = datetime.datetime.fromtimestamp(
state.date, tz=datetime.timezone.utc
)
self.seq = state.seq
self.possible_gaps.clear()
self.getting_diff_for.clear()
@ -132,18 +136,28 @@ class MessageBox:
default_deadline = next_updates_deadline()
if self.possible_gaps:
deadline = min(default_deadline, *(gap.deadline for gap in self.possible_gaps.values()))
deadline = min(
default_deadline, *(gap.deadline for gap in self.possible_gaps.values())
)
elif self.next_deadline in self.map:
deadline = min(default_deadline, self.map[self.next_deadline].deadline)
else:
deadline = default_deadline
if now >= deadline:
self.getting_diff_for.update(entry for entry, gap in self.possible_gaps.items() if now >= gap.deadline)
self.getting_diff_for.update(entry for entry, state in self.map.items() if now >= state.deadline)
self.getting_diff_for.update(
entry
for entry, gap in self.possible_gaps.items()
if now >= gap.deadline
)
self.getting_diff_for.update(
entry for entry, state in self.map.items() if now >= state.deadline
)
if __debug__:
self._trace("deadlines met, now getting diff for: %r", self.getting_diff_for)
self._trace(
"deadlines met, now getting diff for: %r", self.getting_diff_for
)
for entry in self.getting_diff_for:
self.possible_gaps.pop(entry, None)
@ -157,12 +171,19 @@ class MessageBox:
entry: Entry = ENTRY_ACCOUNT # for pyright to know it's not unbound
for entry in entries:
if entry not in self.map:
raise RuntimeError("Called reset_deadline on an entry for which we do not have state")
raise RuntimeError(
"Called reset_deadline on an entry for which we do not have state"
)
self.map[entry].deadline = deadline
if self.next_deadline in entries:
self.next_deadline = min(self.map.items(), key=lambda entry_state: entry_state[1].deadline)[0]
elif self.next_deadline in self.map and deadline < self.map[self.next_deadline].deadline:
self.next_deadline = min(
self.map.items(), key=lambda entry_state: entry_state[1].deadline
)[0]
elif (
self.next_deadline in self.map
and deadline < self.map[self.next_deadline].deadline
):
self.next_deadline = entry
def reset_channel_deadline(self, channel_id: int, timeout: Optional[float]) -> None:
@ -179,7 +200,9 @@ class MessageBox:
assert isinstance(state, types.updates.State)
self.map[ENTRY_ACCOUNT] = State(state.pts, deadline)
self.map[ENTRY_SECRET] = State(state.qts, deadline)
self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc)
self.date = datetime.datetime.fromtimestamp(
state.date, tz=datetime.timezone.utc
)
self.seq = state.seq
def try_set_channel_state(self, id: int, pts: int) -> None:
@ -192,11 +215,15 @@ class MessageBox:
def try_begin_get_diff(self, entry: Entry, reason: str) -> None:
if entry not in self.map:
if entry in self.possible_gaps:
raise RuntimeError("Should not have a possible_gap for an entry not in the state map")
raise RuntimeError(
"Should not have a possible_gap for an entry not in the state map"
)
return
if __debug__:
self._trace("marking entry=%r as needing difference because: %s", entry, reason)
self._trace(
"marking entry=%r as needing difference because: %s", entry, reason
)
self.getting_diff_for.add(entry)
self.possible_gaps.pop(entry, None)
@ -204,10 +231,14 @@ class MessageBox:
try:
self.getting_diff_for.remove(entry)
except KeyError:
raise RuntimeError("Called end_get_diff on an entry which was not getting diff for")
raise RuntimeError(
"Called end_get_diff on an entry which was not getting diff for"
)
self.reset_deadlines({entry}, next_updates_deadline())
assert entry not in self.possible_gaps, "gaps shouldn't be created while getting difference"
assert (
entry not in self.possible_gaps
), "gaps shouldn't be created while getting difference"
def ensure_known_peer_hashes(
self,
@ -215,7 +246,10 @@ class MessageBox:
chat_hashes: ChatHashCache,
) -> None:
if not chat_hashes.extend_from_updates(updates):
can_recover = not isinstance(updates, types.UpdateShort) or pts_info_from_update(updates.update) is not None
can_recover = (
not isinstance(updates, types.UpdateShort)
or pts_info_from_update(updates.update) is not None
)
if can_recover:
self.try_begin_get_diff(ENTRY_ACCOUNT, "missing hash")
raise Gap
@ -241,7 +275,9 @@ class MessageBox:
if combined.seq_start != NO_SEQ:
if self.seq + 1 > combined.seq_start:
if __debug__:
self._trace("skipping updates as they should have already been handled")
self._trace(
"skipping updates as they should have already been handled"
)
return result, combined.users, combined.chats
elif self.seq + 1 < combined.seq_start:
self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap")
@ -269,13 +305,17 @@ class MessageBox:
if __debug__:
self._trace("updating seq as local pts was updated too")
if combined.date != NO_DATE:
self.date = datetime.datetime.fromtimestamp(combined.date, tz=datetime.timezone.utc)
self.date = datetime.datetime.fromtimestamp(
combined.date, tz=datetime.timezone.utc
)
if combined.seq != NO_SEQ:
self.seq = combined.seq
if self.possible_gaps:
if __debug__:
self._trace("trying to re-apply count=%r possible gaps", len(self.possible_gaps))
self._trace(
"trying to re-apply count=%r possible gaps", len(self.possible_gaps)
)
for key in list(self.possible_gaps.keys()):
self.possible_gaps[key].updates.sort(key=update_sort_key)
@ -292,7 +332,9 @@ class MessageBox:
applied,
)
self.possible_gaps = {entry: gap for entry, gap in self.possible_gaps.items() if gap.updates}
self.possible_gaps = {
entry: gap for entry, gap in self.possible_gaps.items() if gap.updates
}
return result, combined.users, combined.chats
@ -342,7 +384,8 @@ class MessageBox:
)
if pts.entry not in self.possible_gaps:
self.possible_gaps[pts.entry] = PossibleGap(
deadline=asyncio.get_running_loop().time() + POSSIBLE_GAP_TIMEOUT,
deadline=asyncio.get_running_loop().time()
+ POSSIBLE_GAP_TIMEOUT,
updates=[],
)
@ -370,14 +413,20 @@ class MessageBox:
for entry in (ENTRY_ACCOUNT, ENTRY_SECRET):
if entry in self.getting_diff_for:
if entry not in self.map:
raise RuntimeError("Should not try to get difference for an entry without known state")
raise RuntimeError(
"Should not try to get difference for an entry without known state"
)
gd = functions.updates.get_difference(
pts=self.map[ENTRY_ACCOUNT].pts,
pts_limit=None,
pts_total_limit=None,
date=int(self.date.timestamp()),
qts=(self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else NO_SEQ),
qts=(
self.map[ENTRY_SECRET].pts
if ENTRY_SECRET in self.map
else NO_SEQ
),
qts_limit=None,
)
if __debug__:
@ -398,7 +447,9 @@ class MessageBox:
result: tuple[list[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]
if isinstance(diff, types.updates.DifferenceEmpty):
finish = True
self.date = datetime.datetime.fromtimestamp(diff.date, tz=datetime.timezone.utc)
self.date = datetime.datetime.fromtimestamp(
diff.date, tz=datetime.timezone.utc
)
self.seq = diff.seq
result = [], [], []
elif isinstance(diff, types.updates.Difference):
@ -451,7 +502,9 @@ class MessageBox:
assert isinstance(state, types.updates.State)
self.map[ENTRY_ACCOUNT].pts = state.pts
self.map[ENTRY_SECRET].pts = state.qts
self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc)
self.date = datetime.datetime.fromtimestamp(
state.date, tz=datetime.timezone.utc
)
self.seq = state.seq
updates, users, chats = self.process_updates(
@ -507,13 +560,19 @@ class MessageBox:
channel=channel,
filter=types.ChannelMessagesFilterEmpty(),
pts=state.pts,
limit=(BOT_CHANNEL_DIFF_LIMIT if chat_hashes.is_self_bot else USER_CHANNEL_DIFF_LIMIT),
limit=(
BOT_CHANNEL_DIFF_LIMIT
if chat_hashes.is_self_bot
else USER_CHANNEL_DIFF_LIMIT
),
)
if __debug__:
self._trace("requesting channel difference: %s", gd)
return gd
else:
raise RuntimeError("should not try to get difference for an entry without known state")
raise RuntimeError(
"should not try to get difference for an entry without known state"
)
else:
self.end_get_diff(entry)
self.map.pop(entry, None)
@ -579,7 +638,9 @@ class MessageBox:
else:
raise RuntimeError("unexpected case")
def end_channel_difference(self, channel_id: int, reason: PrematureEndReason) -> None:
def end_channel_difference(
self, channel_id: int, reason: PrematureEndReason
) -> None:
entry: Entry = channel_id
if __debug__:
self._trace("ending channel=%r difference: %s", entry, reason)

View File

@ -38,7 +38,9 @@ class SqliteSession(Storage):
if version == 7:
session = self._load_v7(c)
else:
raise ValueError("only migration from sqlite session format 7 supported")
raise ValueError(
"only migration from sqlite session format 7 supported"
)
self._reset(c)
self._get_or_init_version(c)
@ -103,7 +105,11 @@ class SqliteSession(Storage):
DataCenter(id=id, ipv4_addr=ipv4_addr, ipv6_addr=ipv6_addr, auth=auth)
for (id, ipv4_addr, ipv6_addr, auth) in datacenter
],
user=(User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3]) if user else None),
user=(
User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3])
if user
else None
),
state=(
UpdateState(
pts=state[0],
@ -160,7 +166,9 @@ class SqliteSession(Storage):
@staticmethod
def _get_or_init_version(c: sqlite3.Cursor) -> int:
c.execute("select name from sqlite_master where type='table' and name='version'")
c.execute(
"select name from sqlite_master where type='table' and name='version'"
)
if c.fetchone():
c.execute("select version from version")
tup = c.fetchone()

View File

@ -23,7 +23,9 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
from ..mtproto.layer import TYPE_MAPPING as MTPROTO_TYPES
if API_TYPES.keys() & MTPROTO_TYPES.keys():
raise RuntimeError("generated api and mtproto schemas cannot have colliding constructor identifiers")
raise RuntimeError(
"generated api and mtproto schemas cannot have colliding constructor identifiers"
)
ALL_TYPES = API_TYPES | MTPROTO_TYPES
# Signatures don't fully match, but this is a private method
@ -37,7 +39,9 @@ class Reader:
__slots__ = ("_view", "_pos", "_len")
def __init__(self, buffer: "Buffer") -> None:
self._view = memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
self._view = (
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
)
self._pos = 0
self._len = len(self._view)

View File

@ -14,7 +14,9 @@ def _bootstrap_get_deserializer(
from ..mtproto.layer import RESPONSE_MAPPING as MTPROTO_DESER
if API_DESER.keys() & MTPROTO_DESER.keys():
raise RuntimeError("generated api and mtproto schemas cannot have colliding constructor identifiers")
raise RuntimeError(
"generated api and mtproto schemas cannot have colliding constructor identifiers"
)
ALL_DESER = API_DESER | MTPROTO_DESER
Request._get_deserializer = ALL_DESER.get # type: ignore [assignment]

View File

@ -49,7 +49,9 @@ class Serializable(abc.ABC):
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return all(getattr(self, attr) == getattr(other, attr) for attr in self.__slots__)
return all(
getattr(self, attr) == getattr(other, attr) for attr in self.__slots__
)
def serialize_bytes_to(buffer: bytearray, data: bytes | bytearray | memoryview) -> None:

View File

@ -26,16 +26,25 @@ def test_auth_key_id() -> None:
def test_calc_new_nonce_hash1() -> None:
auth_key = get_auth_key()
new_nonce = get_new_nonce()
assert auth_key.calc_new_nonce_hash(new_nonce, 1) == 258944117842285651226187582903746985063
assert (
auth_key.calc_new_nonce_hash(new_nonce, 1)
== 258944117842285651226187582903746985063
)
def test_calc_new_nonce_hash2() -> None:
auth_key = get_auth_key()
new_nonce = get_new_nonce()
assert auth_key.calc_new_nonce_hash(new_nonce, 2) == 324588944215647649895949797213421233055
assert (
auth_key.calc_new_nonce_hash(new_nonce, 2)
== 324588944215647649895949797213421233055
)
def test_calc_new_nonce_hash3() -> None:
auth_key = get_auth_key()
new_nonce = get_new_nonce()
assert auth_key.calc_new_nonce_hash(new_nonce, 3) == 100989356540453064705070297823778556733
assert (
auth_key.calc_new_nonce_hash(new_nonce, 3)
== 100989356540453064705070297823778556733
)

View File

@ -66,8 +66,14 @@ def test_key_from_nonce() -> None:
new_nonce = int.from_bytes(bytes(range(32)))
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
assert key == b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6'
assert iv == b"Z\x84\x10\x8e\x98\x05el\xe8d\x07\x0e\x16nb\x18\xf6x>\x85\x11G\x1aZ\xb7\x80,\xf2\x00\x01\x02\x03"
assert (
key
== b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6'
)
assert (
iv
== b"Z\x84\x10\x8e\x98\x05el\xe8d\x07\x0e\x16nb\x18\xf6x>\x85\x11G\x1aZ\xb7\x80,\xf2\x00\x01\x02\x03"
)
def test_verify_ige_encryption() -> None:
@ -82,7 +88,5 @@ def test_verify_ige_decryption() -> None:
ciphertext = get_test_aes_key_or_iv()
key = get_test_aes_key_or_iv()
iv = get_test_aes_key_or_iv()
expected = (
b"\xe5wz\xfa\xcd{,\x16\xf7\xac@\xca\xe6\x1e\xf6\x03\xfe\xe6\t\x8f\xb8\xa8\x86\n\xb9\xeeg,\xd7\xe5\xba\xcc"
)
expected = b"\xe5wz\xfa\xcd{,\x16\xf7\xac@\xca\xe6\x1e\xf6\x03\xfe\xe6\t\x8f\xb8\xa8\x86\n\xb9\xeeg,\xd7\xe5\xba\xcc"
assert decrypt_ige(ciphertext, key, iv) == expected

View File

@ -32,7 +32,10 @@ def test_parse_all_entities_markdown() -> None:
markdown = "Some **bold** (__strong__), *italics* (_cursive_), inline `code`, a\n```rust\npre\n```\nblock, a [link](https://example.com), and [mentions](tg://user?id=12345678)"
text, entities = parse_markdown_message(markdown)
assert text == "Some bold (strong), italics (cursive), inline code, a\npre\nblock, a link, and mentions"
assert (
text
== "Some bold (strong), italics (cursive), inline code, a\npre\nblock, a link, and mentions"
)
assert entities == [
types.MessageEntityBold(offset=5, length=4),
types.MessageEntityBold(offset=11, length=6),
@ -89,7 +92,10 @@ def test_parse_emoji_html() -> None:
def test_parse_all_entities_html() -> None:
html = 'Some <b>bold</b> (<strong>strong</strong>), <i>italics</i> (<em>cursive</em>), inline <code>code</code>, a <pre>pre</pre> block, a <a href="https://example.com">link</a>, <details>spoilers</details> and <a href="tg://user?id=12345678">mentions</a>'
text, entities = parse_html_message(html)
assert text == "Some bold (strong), italics (cursive), inline code, a pre block, a link, spoilers and mentions"
assert (
text
== "Some bold (strong), italics (cursive), inline code, a pre block, a link, spoilers and mentions"
)
assert entities == [
types.MessageEntityBold(offset=5, length=4),
types.MessageEntityBold(offset=11, length=6),

View File

@ -37,9 +37,6 @@ build-backend = "setuptools.build_meta"
[tool.setuptools.dynamic]
version = {attr = "telethon_generator.version.__version__"}
[tool.ruff]
line-length = 120
[tool.ruff.lint]
select = ["F", "E", "W", "I"]
ignore = [

View File

@ -17,7 +17,9 @@ from .serde.deserialization import (
from .serde.serialization import generate_function, generate_write
def generate_init(writer: SourceWriter, namespaces: set[str], classes: set[str]) -> None:
def generate_init(
writer: SourceWriter, namespaces: set[str], classes: set[str]
) -> None:
sorted_cls = list(sorted(classes))
sorted_ns = list(sorted(namespaces))
@ -91,7 +93,9 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
writer = fs.open(type_path)
if type_path not in fs:
writer.write("# pyright: reportUnusedImport=false, reportConstantRedefinition=false")
writer.write(
"# pyright: reportUnusedImport=false, reportConstantRedefinition=false"
)
writer.write("import struct")
writer.write("from typing import Optional, Self, Sequence")
writer.write("from .. import abcs")
@ -102,7 +106,9 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
generated_type_names.add(f"{ns}{to_class_name(typedef.name)}")
# class Type(BaseType)
writer.write(f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):")
writer.write(
f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):"
)
# __slots__ = ('params', ...)
slots = " ".join(f"'{p.name}'," for p in property_params)
@ -115,7 +121,9 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
# def __init__()
if property_params:
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params)
params = "".join(
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
)
writer.write(f" def __init__(_s, *{params}) -> None:")
for p in property_params:
writer.write(f" _s.{p.name} = {p.name}")
@ -143,7 +151,9 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
raise ValueError("nested function-namespaces are not supported")
elif len(functiondef.namespace) == 1:
function_namespaces.add(functiondef.namespace[0])
function_path = (Path("functions") / functiondef.namespace[0]).with_suffix(".py")
function_path = (Path("functions") / functiondef.namespace[0]).with_suffix(
".py"
)
else:
function_def_names.add(to_method_name(functiondef.name))
function_path = Path("functions/_nons.py")
@ -163,14 +173,18 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params)
star = "*" if params else ""
return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None))
writer.write(f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:")
writer.write(
f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:"
)
writer.indent(2)
generate_function(writer, functiondef)
writer.dedent(2)
generate_init(fs.open(Path("abcs/__init__.py")), abc_namespaces, abc_class_names)
generate_init(fs.open(Path("types/__init__.py")), type_namespaces, type_class_names)
generate_init(fs.open(Path("functions/__init__.py")), function_namespaces, function_def_names)
generate_init(
fs.open(Path("functions/__init__.py")), function_namespaces, function_def_names
)
writer = fs.open(Path("layer.py"))
writer.write("# pyright: reportUnusedImport=false")
@ -180,12 +194,16 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
)
writer.write("from typing import cast, Type")
writer.write(f"LAYER = {tl.layer!r}")
writer.write("TYPE_MAPPING = {t.constructor_id(): t for t in cast(tuple[Type[Serializable]], (")
writer.write(
"TYPE_MAPPING = {t.constructor_id(): t for t in cast(tuple[Type[Serializable]], ("
)
for name in sorted(generated_type_names):
writer.write(f" types.{name},")
writer.write("))}")
writer.write("RESPONSE_MAPPING = {")
for functiondef in tl.functiondefs:
writer.write(f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)},")
writer.write(
f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)},"
)
writer.write("}")
writer.write("__all__ = ['LAYER', 'TYPE_MAPPING', 'RESPONSE_MAPPING']")

View File

@ -50,7 +50,9 @@ _TRIVIAL_STRUCT_MAP = {"int": "i", "long": "q", "double": "d", "Bool": "I"}
def trivial_struct_fmt(ty: BaseParameter) -> str:
try:
return _TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I"
return (
_TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I"
)
except KeyError:
raise ValueError("input param was not trivial")

View File

@ -38,7 +38,9 @@ def reader_read_fmt(ty: Type, constructor_id: int) -> tuple[str, Optional[str]]:
return f"reader.read_serializable({inner_type_fmt(ty)})", "type-abstract"
def generate_normal_param_read(writer: SourceWriter, name: str, param: NormalParameter, constructor_id: int) -> None:
def generate_normal_param_read(
writer: SourceWriter, name: str, param: NormalParameter, constructor_id: int
) -> None:
flag_check = f"_{param.flag.name} & {1 << param.flag.index}" if param.flag else None
if param.ty.name == "true":
if not flag_check:
@ -53,7 +55,9 @@ def generate_normal_param_read(writer: SourceWriter, name: str, param: NormalPar
if param.ty.generic_arg:
if param.ty.name not in ("Vector", "vector"):
raise ValueError("generic_arg deserialization for non-vectors is not supported")
raise ValueError(
"generic_arg deserialization for non-vectors is not supported"
)
if param.ty.bare:
writer.write("__len = reader.read_fmt('<i', 4)[0]")
@ -66,12 +70,18 @@ def generate_normal_param_read(writer: SourceWriter, name: str, param: NormalPar
if is_trivial(generic):
fmt = trivial_struct_fmt(generic)
size = struct.calcsize(f"<{fmt}")
writer.write(f"_{name} = [*reader.read_fmt(f'<{{__len}}{fmt}', __len * {size})]")
writer.write(
f"_{name} = [*reader.read_fmt(f'<{{__len}}{fmt}', __len * {size})]"
)
if param.ty.generic_arg.name == "Bool":
writer.write(f"assert all(__x in (0xbc799737, 0x0x997275b5) for __x in _{name})")
writer.write(
f"assert all(__x in (0xbc799737, 0x0x997275b5) for __x in _{name})"
)
writer.write(f"_{name} = [_{name} == 0x997275b5]")
else:
fmt_read, type_ignore = reader_read_fmt(param.ty.generic_arg, constructor_id)
fmt_read, type_ignore = reader_read_fmt(
param.ty.generic_arg, constructor_id
)
comment = f" # type: ignore [{type_ignore}]" if type_ignore else ""
writer.write(f"_{name} = [{fmt_read} for _ in range(__len)]{comment}")
else:
@ -117,7 +127,9 @@ def param_value_fmt(param: Parameter) -> str:
def function_deserializer_fmt(defn: Definition) -> str:
if defn.ty.generic_arg:
if defn.ty.name != ("Vector"):
raise ValueError("generic_arg return for non-boxed-vectors is not supported")
raise ValueError(
"generic_arg return for non-boxed-vectors is not supported"
)
elif defn.ty.generic_ref:
raise ValueError("return for generic refs inside vector is not supported")
elif is_trivial(NormalParameter(ty=defn.ty.generic_arg, flag=None)):
@ -126,9 +138,13 @@ def function_deserializer_fmt(defn: Definition) -> str:
elif defn.ty.generic_arg.name == "long":
return "deserialize_i64_list"
else:
raise ValueError(f"return for trivial arg {defn.ty.generic_arg} is not supported")
raise ValueError(
f"return for trivial arg {defn.ty.generic_arg} is not supported"
)
elif defn.ty.generic_arg.bare:
raise ValueError("return for non-boxed serializables inside a vector is not supported")
raise ValueError(
"return for non-boxed serializables inside a vector is not supported"
)
else:
return f"list_deserializer({inner_type_fmt(defn.ty.generic_arg)})"
elif defn.ty.generic_ref:

View File

@ -15,11 +15,15 @@ def param_value_expr(param: Parameter) -> str:
return f"{pre}{mid}{suf}"
def generate_buffer_append(writer: SourceWriter, buffer: str, name: str, ty: Type) -> None:
def generate_buffer_append(
writer: SourceWriter, buffer: str, name: str, ty: Type
) -> None:
if is_trivial(NormalParameter(ty=ty, flag=None)):
fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None))
if ty.name == "Bool":
writer.write(f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))")
writer.write(
f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))"
)
else:
writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})")
elif ty.generic_ref or ty.name == "Object":
@ -54,7 +58,9 @@ def generate_normal_param_write(
if param.ty.generic_arg:
if param.ty.name not in ("Vector", "vector"):
raise ValueError("generic_arg deserialization for non-vectors is not supported")
raise ValueError(
"generic_arg deserialization for non-vectors is not supported"
)
if param.ty.bare:
writer.write(f"{buffer} += struct.pack('<i', len({name}))")
@ -70,7 +76,9 @@ def generate_normal_param_write(
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *(0x997275b5 if {tmp} else 0xbc799737 for {tmp} in {name}))"
)
else:
writer.write(f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *{name})")
writer.write(
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *{name})"
)
else:
tmp = next(tmp_names)
writer.write(f"for {tmp} in {name}:")
@ -102,7 +110,9 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None:
else f"(0 if self.{p.name} is None else {1 << p.ty.flag.index})"
)
for p in defn.params
if isinstance(p.ty, NormalParameter) and p.ty.flag and p.ty.flag.name == param.name
if isinstance(p.ty, NormalParameter)
and p.ty.flag
and p.ty.flag.name == param.name
)
writer.write(f"_{param.name} = {flags or 0}")
@ -113,7 +123,9 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None:
for param in iter:
if not isinstance(param.ty, NormalParameter):
raise RuntimeError("FlagsParameter should be considered trivial")
generate_normal_param_write(writer, tmp_names, "buffer", f"self.{param.name}", param.ty)
generate_normal_param_write(
writer, tmp_names, "buffer", f"self.{param.name}", param.ty
)
def generate_function(writer: SourceWriter, defn: Definition) -> None:
@ -136,7 +148,9 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None:
else f"(0 if {p.name} is None else {1 << p.ty.flag.index})"
)
for p in defn.params
if isinstance(p.ty, NormalParameter) and p.ty.flag and p.ty.flag.name == param.name
if isinstance(p.ty, NormalParameter)
and p.ty.flag
and p.ty.flag.name == param.name
)
writer.write(f"{param.name} = {flags or 0}")
@ -147,5 +161,7 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None:
for param in iter:
if not isinstance(param.ty, NormalParameter):
raise RuntimeError("FlagsParameter should be considered trivial")
generate_normal_param_write(writer, tmp_names, "_buffer", param.name, param.ty)
generate_normal_param_write(
writer, tmp_names, "_buffer", param.name, param.ty
)
writer.write("return Request(b'' + _buffer)")

View File

@ -35,4 +35,6 @@ def load_tl_file(path: str | Path) -> ParsedTl:
else:
functiondefs.append(definition)
return ParsedTl(layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs))
return ParsedTl(
layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)
)

View File

@ -14,7 +14,9 @@ def gen_py_code(
functiondefs: Optional[list[Definition]] = None,
) -> str:
fs = FakeFs()
generate(fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or []))
generate(
fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or [])
)
generated = bytearray()
for path, data in fs._files.items():
if path.stem not in ("__init__", "layer"):
@ -25,7 +27,9 @@ def gen_py_code(
def test_generic_functions_use_bytes_parameters() -> None:
definitions = get_definitions("invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;")
definitions = get_definitions(
"invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;"
)
result = gen_py_code(functiondefs=definitions)
assert "invoke_with_layer" in result
assert "query: _bytes" in result

View File

@ -55,14 +55,18 @@ def test_valid_param() -> None:
assert Parameter.from_str("foo:!bar") == Parameter(
name="foo",
ty=NormalParameter(
ty=Type(namespace=[], name="bar", bare=True, generic_ref=True, generic_arg=None),
ty=Type(
namespace=[], name="bar", bare=True, generic_ref=True, generic_arg=None
),
flag=None,
),
)
assert Parameter.from_str("foo:bar.1?baz") == Parameter(
name="foo",
ty=NormalParameter(
ty=Type(namespace=[], name="baz", bare=True, generic_ref=False, generic_arg=None),
ty=Type(
namespace=[], name="baz", bare=True, generic_ref=False, generic_arg=None
),
flag=Flag(
name="bar",
index=1,

View File

@ -12,7 +12,9 @@ def test_empty_simple() -> None:
def test_simple() -> None:
assert Type.from_str("foo") == Type(namespace=[], name="foo", bare=True, generic_ref=False, generic_arg=None)
assert Type.from_str("foo") == Type(
namespace=[], name="foo", bare=True, generic_ref=False, generic_arg=None
)
@mark.parametrize("ty", [".", "..", ".foo", "foo.", "foo..foo", ".foo."])