mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-08-05 20:50:22 +03:00
Upgrade ruff and mypy version, format files
This commit is contained in:
parent
918f719ab2
commit
b25ec41a1f
|
@ -15,9 +15,7 @@ def make_link_node(rawtext, app, name, options):
|
|||
base += "/"
|
||||
|
||||
set_classes(options)
|
||||
node = nodes.reference(
|
||||
rawtext, utils.unescape(name), refuri="{}?q={}".format(base, name), **options
|
||||
)
|
||||
node = nodes.reference(rawtext, utils.unescape(name), refuri="{}?q={}".format(base, name), **options)
|
||||
return node
|
||||
|
||||
|
||||
|
|
|
@ -28,10 +28,8 @@ dynamic = ["version"]
|
|||
[project.optional-dependencies]
|
||||
cryptg = ["cryptg~=0.4"]
|
||||
dev = [
|
||||
"isort~=5.12",
|
||||
"black~=23.3.0",
|
||||
"mypy~=1.3",
|
||||
"ruff~=0.0.292",
|
||||
"mypy~=1.11.2",
|
||||
"ruff~=0.6.8",
|
||||
"pytest~=7.3",
|
||||
"pytest-asyncio~=0.21",
|
||||
]
|
||||
|
@ -55,6 +53,10 @@ backend-path = ["build_backend"]
|
|||
version = {attr = "telethon.version.__version__"}
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["F", "E", "W", "I"]
|
||||
ignore = [
|
||||
"E501", # formatter takes care of lines that are too long besides documentation
|
||||
]
|
||||
|
|
|
@ -31,9 +31,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
|
|||
assert isinstance(auth, types.auth.Authorization)
|
||||
assert isinstance(auth.user, types.User)
|
||||
user = User._from_raw(auth.user)
|
||||
client._session.user = SessionUser(
|
||||
id=user.id, dc=client._sender.dc_id, bot=user.bot, username=user.username
|
||||
)
|
||||
client._session.user = SessionUser(id=user.id, dc=client._sender.dc_id, bot=user.bot, username=user.username)
|
||||
|
||||
client._chat_hashes.set_self_user(user.id, user.bot)
|
||||
|
||||
|
@ -56,9 +54,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
|
|||
|
||||
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
|
||||
assert dc_id is not None
|
||||
sender, client._session.dcs = await connect_sender(
|
||||
client._config, client._session.dcs, DataCenter(id=dc_id)
|
||||
)
|
||||
sender, client._session.dcs = await connect_sender(client._config, client._session.dcs, DataCenter(id=dc_id))
|
||||
async with client._sender_lock:
|
||||
client._sender = sender
|
||||
|
||||
|
@ -177,9 +173,7 @@ async def interactive_login(
|
|||
user = await self.check_password(user_or_token, password)
|
||||
else:
|
||||
while True:
|
||||
print(
|
||||
"Please enter your password (prompt is hidden; type and press enter)"
|
||||
)
|
||||
print("Please enter your password (prompt is hidden; type and press enter)")
|
||||
password = getpass.getpass(": ")
|
||||
try:
|
||||
user = await self.check_password(user_or_token, password)
|
||||
|
@ -208,13 +202,9 @@ async def get_password_information(client: Client) -> PasswordToken:
|
|||
return PasswordToken._new(result)
|
||||
|
||||
|
||||
async def check_password(
|
||||
self: Client, token: PasswordToken, password: str | bytes
|
||||
) -> User:
|
||||
async def check_password(self: Client, token: PasswordToken, password: str | bytes) -> User:
|
||||
algo = token._password.current_algo
|
||||
if not isinstance(
|
||||
algo, types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow
|
||||
):
|
||||
if not isinstance(algo, types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow):
|
||||
raise RuntimeError("unrecognised 2FA algorithm")
|
||||
|
||||
if not two_factor_auth.check_p_and_g(algo.p, algo.g):
|
||||
|
|
|
@ -38,11 +38,7 @@ class InlineResults(metaclass=NoPublicConstructor):
|
|||
result = await self._client(
|
||||
functions.messages.get_inline_bot_results(
|
||||
bot=self._bot,
|
||||
peer=(
|
||||
self._peer._to_input_peer()
|
||||
if self._peer
|
||||
else types.InputPeerEmpty()
|
||||
),
|
||||
peer=(self._peer._to_input_peer() if self._peer else types.InputPeerEmpty()),
|
||||
geo_point=None,
|
||||
query=self._query,
|
||||
offset=self._offset,
|
||||
|
@ -51,12 +47,8 @@ class InlineResults(metaclass=NoPublicConstructor):
|
|||
assert isinstance(result, types.messages.BotResults)
|
||||
self._offset = result.next_offset
|
||||
for r in reversed(result.results):
|
||||
assert isinstance(
|
||||
r, (types.BotInlineMediaResult, types.BotInlineResult)
|
||||
)
|
||||
self._buffer.append(
|
||||
InlineResult._create(self._client, result, r, self._peer)
|
||||
)
|
||||
assert isinstance(r, (types.BotInlineMediaResult, types.BotInlineResult))
|
||||
self._buffer.append(InlineResult._create(self._client, result, r, self._peer))
|
||||
|
||||
if not self._buffer:
|
||||
self._offset = None
|
||||
|
|
|
@ -53,9 +53,7 @@ class ParticipantList(AsyncList[Participant]):
|
|||
|
||||
seen_count = len(self._seen)
|
||||
for p in chanp.participants:
|
||||
part = Participant._from_raw_channel(
|
||||
self._client, self._peer, p, chat_map
|
||||
)
|
||||
part = Participant._from_raw_channel(self._client, self._peer, p, chat_map)
|
||||
pid = part._peer_id()
|
||||
if pid not in self._seen:
|
||||
self._seen.add(pid)
|
||||
|
@ -66,9 +64,7 @@ class ParticipantList(AsyncList[Participant]):
|
|||
self._done = len(self._seen) == seen_count
|
||||
|
||||
else:
|
||||
chatp = await self._client(
|
||||
functions.messages.get_full_chat(chat_id=self._peer._to_input_chat())
|
||||
)
|
||||
chatp = await self._client(functions.messages.get_full_chat(chat_id=self._peer._to_input_chat()))
|
||||
assert isinstance(chatp, types.messages.ChatFull)
|
||||
assert isinstance(chatp.full_chat, types.ChatFull)
|
||||
|
||||
|
@ -87,17 +83,14 @@ class ParticipantList(AsyncList[Participant]):
|
|||
)
|
||||
elif isinstance(participants, types.ChatParticipants):
|
||||
self._buffer.extend(
|
||||
Participant._from_raw_chat(self._client, self._peer, p, chat_map)
|
||||
for p in participants.participants
|
||||
Participant._from_raw_chat(self._client, self._peer, p, chat_map) for p in participants.participants
|
||||
)
|
||||
|
||||
self._total = len(self._buffer)
|
||||
self._done = True
|
||||
|
||||
|
||||
def get_participants(
|
||||
self: Client, chat: Group | Channel | GroupRef | ChannelRef, /
|
||||
) -> AsyncList[Participant]:
|
||||
def get_participants(self: Client, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[Participant]:
|
||||
return ParticipantList(self, chat._ref)
|
||||
|
||||
|
||||
|
@ -137,9 +130,7 @@ class RecentActionList(AsyncList[RecentAction]):
|
|||
self._offset = min(e.id for e in self._buffer)
|
||||
|
||||
|
||||
def get_admin_log(
|
||||
self: Client, chat: Group | Channel | GroupRef | ChannelRef, /
|
||||
) -> AsyncList[RecentAction]:
|
||||
def get_admin_log(self: Client, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[RecentAction]:
|
||||
return RecentActionList(self, chat._ref)
|
||||
|
||||
|
||||
|
@ -174,11 +165,7 @@ class ProfilePhotoList(AsyncList[File]):
|
|||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
||||
self._buffer.extend(
|
||||
filter(
|
||||
None, (File._try_from_raw_photo(self._client, p) for p in photos)
|
||||
)
|
||||
)
|
||||
self._buffer.extend(filter(None, (File._try_from_raw_photo(self._client, p) for p in photos)))
|
||||
|
||||
|
||||
def get_profile_photos(self: Client, peer: Peer | PeerRef, /) -> AsyncList[File]:
|
||||
|
@ -256,11 +243,7 @@ async def set_chat_default_restrictions(
|
|||
*,
|
||||
until: Optional[datetime.datetime] = None,
|
||||
) -> None:
|
||||
banned_rights = ChatRestriction._set_to_raw(
|
||||
set(restrictions), int(until.timestamp()) if until else 0x7FFFFFFF
|
||||
)
|
||||
banned_rights = ChatRestriction._set_to_raw(set(restrictions), int(until.timestamp()) if until else 0x7FFFFFFF)
|
||||
await self(
|
||||
functions.messages.edit_chat_default_banned_rights(
|
||||
peer=chat._ref._to_input_peer(), banned_rights=banned_rights
|
||||
)
|
||||
functions.messages.edit_chat_default_banned_rights(peer=chat._ref._to_input_peer(), banned_rights=banned_rights)
|
||||
)
|
||||
|
|
|
@ -237,9 +237,7 @@ class Client:
|
|||
lang_code=lang_code or "en",
|
||||
catch_up=catch_up or False,
|
||||
datacenter=datacenter,
|
||||
flood_sleep_threshold=(
|
||||
60 if flood_sleep_threshold is None else flood_sleep_threshold
|
||||
),
|
||||
flood_sleep_threshold=(60 if flood_sleep_threshold is None else flood_sleep_threshold),
|
||||
update_queue_limit=update_queue_limit,
|
||||
base_logger=base_logger,
|
||||
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
|
||||
|
@ -250,8 +248,8 @@ class Client:
|
|||
self._message_box = MessageBox(base_logger=base_logger)
|
||||
self._chat_hashes = ChatHashCache(None)
|
||||
self._last_update_limit_warn: Optional[float] = None
|
||||
self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = (
|
||||
asyncio.Queue(maxsize=self._config.update_queue_limit or 0)
|
||||
self._updates: asyncio.Queue[tuple[abcs.Update, dict[int, Peer]]] = asyncio.Queue(
|
||||
maxsize=self._config.update_queue_limit or 0
|
||||
)
|
||||
self._dispatcher: Optional[asyncio.Task[None]] = None
|
||||
self._handlers: dict[
|
||||
|
@ -411,9 +409,7 @@ class Client:
|
|||
"""
|
||||
await delete_dialog(self, dialog)
|
||||
|
||||
async def delete_messages(
|
||||
self, chat: Peer | PeerRef, /, message_ids: list[int], *, revoke: bool = True
|
||||
) -> int:
|
||||
async def delete_messages(self, chat: Peer | PeerRef, /, message_ids: list[int], *, revoke: bool = True) -> int:
|
||||
"""
|
||||
Delete messages.
|
||||
|
||||
|
@ -653,9 +649,7 @@ class Client:
|
|||
"""
|
||||
return await forward_messages(self, target, message_ids, source)
|
||||
|
||||
def get_admin_log(
|
||||
self, chat: Group | Channel | GroupRef | ChannelRef, /
|
||||
) -> AsyncList[RecentAction]:
|
||||
def get_admin_log(self, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[RecentAction]:
|
||||
"""
|
||||
Get the recent actions from the administrator's log.
|
||||
|
||||
|
@ -757,9 +751,7 @@ class Client:
|
|||
"""
|
||||
return get_file_bytes(self, media)
|
||||
|
||||
def get_handler_filter(
|
||||
self, handler: Callable[[Event], Awaitable[Any]], /
|
||||
) -> Optional[FilterType]:
|
||||
def get_handler_filter(self, handler: Callable[[Event], Awaitable[Any]], /) -> Optional[FilterType]:
|
||||
"""
|
||||
Get the filter associated to the given event handler.
|
||||
|
||||
|
@ -861,13 +853,9 @@ class Client:
|
|||
async for message in reversed(client.get_messages(chat)):
|
||||
print(message.sender.name, ':', message.markdown_text)
|
||||
"""
|
||||
return get_messages(
|
||||
self, chat, limit, offset_id=offset_id, offset_date=offset_date
|
||||
)
|
||||
return get_messages(self, chat, limit, offset_id=offset_id, offset_date=offset_date)
|
||||
|
||||
def get_messages_with_ids(
|
||||
self, chat: Peer | PeerRef, /, message_ids: list[int]
|
||||
) -> AsyncList[Message]:
|
||||
def get_messages_with_ids(self, chat: Peer | PeerRef, /, message_ids: list[int]) -> AsyncList[Message]:
|
||||
"""
|
||||
Get the full message objects from the corresponding message identifiers.
|
||||
|
||||
|
@ -891,9 +879,7 @@ class Client:
|
|||
"""
|
||||
return get_messages_with_ids(self, chat, message_ids)
|
||||
|
||||
def get_participants(
|
||||
self, chat: Group | Channel | GroupRef | ChannelRef, /
|
||||
) -> AsyncList[Participant]:
|
||||
def get_participants(self, chat: Group | Channel | GroupRef | ChannelRef, /) -> AsyncList[Participant]:
|
||||
"""
|
||||
Get the participants in a group or channel, along with their permissions.
|
||||
|
||||
|
@ -981,9 +967,7 @@ class Client:
|
|||
"""
|
||||
return await inline_query(self, bot, query, peer=peer)
|
||||
|
||||
async def interactive_login(
|
||||
self, phone_or_token: Optional[str] = None, *, password: Optional[str] = None
|
||||
) -> User:
|
||||
async def interactive_login(self, phone_or_token: Optional[str] = None, *, password: Optional[str] = None) -> User:
|
||||
"""
|
||||
Begin an interactive login if needed.
|
||||
If the account was already logged-in, this method simply returns :term:`yourself`.
|
||||
|
@ -1035,9 +1019,7 @@ class Client:
|
|||
|
||||
def on(
|
||||
self, event_cls: Type[Event], /, filter: Optional[FilterType] = None
|
||||
) -> Callable[
|
||||
[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]
|
||||
]:
|
||||
) -> Callable[[Callable[[Event], Awaitable[Any]]], Callable[[Event], Awaitable[Any]]]:
|
||||
"""
|
||||
Register the decorated function to be invoked when the provided event type occurs.
|
||||
|
||||
|
@ -1125,9 +1107,7 @@ class Client:
|
|||
"""
|
||||
return prepare_album(self)
|
||||
|
||||
async def read_message(
|
||||
self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
|
||||
) -> None:
|
||||
async def read_message(self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
|
||||
"""
|
||||
Mark messages as read.
|
||||
|
||||
|
@ -1158,9 +1138,7 @@ class Client:
|
|||
"""
|
||||
await read_message(self, chat, message_id)
|
||||
|
||||
def remove_event_handler(
|
||||
self, handler: Callable[[Event], Awaitable[Any]], /
|
||||
) -> None:
|
||||
def remove_event_handler(self, handler: Callable[[Event], Awaitable[Any]], /) -> None:
|
||||
"""
|
||||
Remove the handler as a function to be called when events occur.
|
||||
This is simply the opposite of :meth:`add_event_handler`.
|
||||
|
@ -1324,9 +1302,7 @@ class Client:
|
|||
async for message in client.search_all_messages(query='hello'):
|
||||
print(message.text)
|
||||
"""
|
||||
return search_all_messages(
|
||||
self, limit, query=query, offset_id=offset_id, offset_date=offset_date
|
||||
)
|
||||
return search_all_messages(self, limit, query=query, offset_id=offset_id, offset_date=offset_date)
|
||||
|
||||
def search_messages(
|
||||
self,
|
||||
|
@ -1372,9 +1348,7 @@ class Client:
|
|||
async for message in client.search_messages(chat, query='hello'):
|
||||
print(message.text)
|
||||
"""
|
||||
return search_messages(
|
||||
self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date
|
||||
)
|
||||
return search_messages(self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date)
|
||||
|
||||
async def send_audio(
|
||||
self,
|
||||
|
@ -1975,9 +1949,7 @@ class Client:
|
|||
|
||||
:meth:`telethon.types.Participant.set_restrictions`
|
||||
"""
|
||||
await set_participant_restrictions(
|
||||
self, chat, participant, restrictions, until=until
|
||||
)
|
||||
await set_participant_restrictions(self, chat, participant, restrictions, until=until)
|
||||
|
||||
async def sign_in(self, token: LoginToken, code: str) -> User | PasswordToken:
|
||||
"""
|
||||
|
@ -2024,9 +1996,7 @@ class Client:
|
|||
"""
|
||||
await sign_out(self)
|
||||
|
||||
async def unpin_message(
|
||||
self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
|
||||
) -> None:
|
||||
async def unpin_message(self, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
|
||||
"""
|
||||
Unpin one or all messages from the top.
|
||||
|
||||
|
|
|
@ -51,9 +51,7 @@ class DialogList(AsyncList[Dialog]):
|
|||
chat_map = build_chat_map(self._client, result.users, result.chats)
|
||||
msg_map = build_msg_map(self._client, result.messages, chat_map)
|
||||
|
||||
self._buffer.extend(
|
||||
Dialog._from_raw(self._client, d, chat_map, msg_map) for d in result.dialogs
|
||||
)
|
||||
self._buffer.extend(Dialog._from_raw(self._client, d, chat_map, msg_map) for d in result.dialogs)
|
||||
|
||||
|
||||
def get_dialogs(self: Client) -> AsyncList[Dialog]:
|
||||
|
@ -130,9 +128,7 @@ async def edit_draft(
|
|||
reply_to: Optional[int] = None,
|
||||
) -> Draft:
|
||||
peer = peer._ref
|
||||
message, entities = parse_message(
|
||||
text=text, markdown=markdown, html=html, allow_empty=False
|
||||
)
|
||||
message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False)
|
||||
|
||||
result = await self(
|
||||
functions.messages.save_draft(
|
||||
|
|
|
@ -189,15 +189,11 @@ async def send_file(
|
|||
reply_to: Optional[int] = None,
|
||||
keyboard: Optional[KeyboardType] = None,
|
||||
) -> Message:
|
||||
message, entities = parse_message(
|
||||
text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True
|
||||
)
|
||||
message, entities = parse_message(text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True)
|
||||
|
||||
# Re-send existing file.
|
||||
if isinstance(file, File):
|
||||
return await do_send_file(
|
||||
self, chat, file._input_media, message, entities, reply_to, keyboard
|
||||
)
|
||||
return await do_send_file(self, chat, file._input_media, message, entities, reply_to, keyboard)
|
||||
|
||||
# URLs are handled early as they can't use any other attributes either.
|
||||
input_media: abcs.InputMedia
|
||||
|
@ -212,16 +208,10 @@ async def send_file(
|
|||
else:
|
||||
as_photo = False
|
||||
if as_photo:
|
||||
input_media = types.InputMediaPhotoExternal(
|
||||
spoiler=False, url=file, ttl_seconds=None
|
||||
)
|
||||
input_media = types.InputMediaPhotoExternal(spoiler=False, url=file, ttl_seconds=None)
|
||||
else:
|
||||
input_media = types.InputMediaDocumentExternal(
|
||||
spoiler=False, url=file, ttl_seconds=None
|
||||
)
|
||||
return await do_send_file(
|
||||
self, chat, input_media, message, entities, reply_to, keyboard
|
||||
)
|
||||
input_media = types.InputMediaDocumentExternal(spoiler=False, url=file, ttl_seconds=None)
|
||||
return await do_send_file(self, chat, input_media, message, entities, reply_to, keyboard)
|
||||
|
||||
input_file, name = await upload(self, file, size, name)
|
||||
|
||||
|
@ -285,9 +275,7 @@ async def send_file(
|
|||
ttl_seconds=None,
|
||||
)
|
||||
|
||||
return await do_send_file(
|
||||
self, chat, input_media, message, entities, reply_to, keyboard
|
||||
)
|
||||
return await do_send_file(self, chat, input_media, message, entities, reply_to, keyboard)
|
||||
|
||||
|
||||
async def do_send_file(
|
||||
|
@ -309,11 +297,7 @@ async def do_send_file(
|
|||
noforwards=False,
|
||||
update_stickersets_order=False,
|
||||
peer=chat._ref._to_input_peer(),
|
||||
reply_to=(
|
||||
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
|
||||
if reply_to
|
||||
else None
|
||||
),
|
||||
reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None),
|
||||
media=input_media,
|
||||
message=message,
|
||||
random_id=random_id,
|
||||
|
@ -395,11 +379,7 @@ async def do_upload(
|
|||
)
|
||||
)
|
||||
else:
|
||||
await client(
|
||||
functions.upload.save_file_part(
|
||||
file_id=file_id, file_part=part, bytes=to_store
|
||||
)
|
||||
)
|
||||
await client(functions.upload.save_file_part(file_id=file_id, file_part=part, bytes=to_store))
|
||||
hash_md5.update(to_store)
|
||||
|
||||
buffer.clear()
|
||||
|
@ -458,9 +438,7 @@ def get_file_bytes(self: Client, media: File, /) -> AsyncList[bytes]:
|
|||
return FileBytesList(self, media)
|
||||
|
||||
|
||||
async def download(
|
||||
self: Client, media: File, /, file: str | Path | OutFileLike
|
||||
) -> None:
|
||||
async def download(self: Client, media: File, /, file: str | Path | OutFileLike) -> None:
|
||||
fd = OutWrapper(file)
|
||||
try:
|
||||
async for chunk in get_file_bytes(self, media):
|
||||
|
|
|
@ -46,9 +46,7 @@ async def send_message(
|
|||
update_stickersets_order=False,
|
||||
peer=chat._ref._to_input_peer(),
|
||||
reply_to=(
|
||||
types.InputReplyToMessage(
|
||||
reply_to_msg_id=text.replied_message_id, top_msg_id=None
|
||||
)
|
||||
types.InputReplyToMessage(reply_to_msg_id=text.replied_message_id, top_msg_id=None)
|
||||
if text.replied_message_id
|
||||
else None
|
||||
),
|
||||
|
@ -60,9 +58,7 @@ async def send_message(
|
|||
send_as=None,
|
||||
)
|
||||
else:
|
||||
message, entities = parse_message(
|
||||
text=text, markdown=markdown, html=html, allow_empty=False
|
||||
)
|
||||
message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False)
|
||||
request = functions.messages.send_message(
|
||||
no_webpage=not link_preview,
|
||||
silent=False,
|
||||
|
@ -71,11 +67,7 @@ async def send_message(
|
|||
noforwards=False,
|
||||
update_stickersets_order=False,
|
||||
peer=chat._ref._to_input_peer(),
|
||||
reply_to=(
|
||||
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
|
||||
if reply_to
|
||||
else None
|
||||
),
|
||||
reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None),
|
||||
message=message,
|
||||
random_id=random_id,
|
||||
reply_markup=keyboard._raw if keyboard else None,
|
||||
|
@ -91,11 +83,7 @@ async def send_message(
|
|||
{},
|
||||
out=result.out,
|
||||
id=result.id,
|
||||
from_id=(
|
||||
types.PeerUser(user_id=self._session.user.id)
|
||||
if self._session.user
|
||||
else None
|
||||
),
|
||||
from_id=(types.PeerUser(user_id=self._session.user.id) if self._session.user else None),
|
||||
peer_id=chat._ref._to_peer(),
|
||||
reply_to=(
|
||||
types.MessageReplyHeader(
|
||||
|
@ -130,9 +118,7 @@ async def edit_message(
|
|||
link_preview: bool = False,
|
||||
keyboard: Optional[KeyboardType] = None,
|
||||
) -> Message:
|
||||
message, entities = parse_message(
|
||||
text=text, markdown=markdown, html=html, allow_empty=False
|
||||
)
|
||||
message, entities = parse_message(text=text, markdown=markdown, html=html, allow_empty=False)
|
||||
return self._build_message_map(
|
||||
await self(
|
||||
functions.messages.edit_message(
|
||||
|
@ -160,15 +146,9 @@ async def delete_messages(
|
|||
) -> int:
|
||||
peer = chat._ref
|
||||
if isinstance(peer, ChannelRef):
|
||||
affected = await self(
|
||||
functions.channels.delete_messages(
|
||||
channel=peer._to_input_channel(), id=message_ids
|
||||
)
|
||||
)
|
||||
affected = await self(functions.channels.delete_messages(channel=peer._to_input_channel(), id=message_ids))
|
||||
else:
|
||||
affected = await self(
|
||||
functions.messages.delete_messages(revoke=revoke, id=message_ids)
|
||||
)
|
||||
affected = await self(functions.messages.delete_messages(revoke=revoke, id=message_ids))
|
||||
assert isinstance(affected, types.messages.AffectedMessages)
|
||||
return affected.pts_count
|
||||
|
||||
|
@ -205,9 +185,7 @@ class MessageList(AsyncList[Message]):
|
|||
super().__init__()
|
||||
self._reversed = False
|
||||
|
||||
def _extend_buffer(
|
||||
self, client: Client, messages: abcs.messages.Messages
|
||||
) -> dict[int, Peer]:
|
||||
def _extend_buffer(self, client: Client, messages: abcs.messages.Messages) -> dict[int, Peer]:
|
||||
if isinstance(messages, types.messages.MessagesNotModified):
|
||||
self._total = messages.count
|
||||
return {}
|
||||
|
@ -215,9 +193,7 @@ class MessageList(AsyncList[Message]):
|
|||
if isinstance(messages, types.messages.Messages):
|
||||
self._total = len(messages.messages)
|
||||
self._done = True
|
||||
elif isinstance(
|
||||
messages, (types.messages.MessagesSlice, types.messages.ChannelMessages)
|
||||
):
|
||||
elif isinstance(messages, (types.messages.MessagesSlice, types.messages.ChannelMessages)):
|
||||
self._total = messages.count
|
||||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
@ -225,9 +201,7 @@ class MessageList(AsyncList[Message]):
|
|||
chat_map = build_chat_map(client, messages.users, messages.chats)
|
||||
self._buffer.extend(
|
||||
Message._from_raw(client, m, chat_map)
|
||||
for m in (
|
||||
reversed(messages.messages) if self._reversed else messages.messages
|
||||
)
|
||||
for m in (reversed(messages.messages) if self._reversed else messages.messages)
|
||||
)
|
||||
return chat_map
|
||||
|
||||
|
@ -235,11 +209,7 @@ class MessageList(AsyncList[Message]):
|
|||
self,
|
||||
) -> types.Message | types.MessageService | types.MessageEmpty:
|
||||
return next(
|
||||
(
|
||||
m._raw
|
||||
for m in reversed(self._buffer)
|
||||
if not isinstance(m._raw, types.MessageEmpty)
|
||||
),
|
||||
(m._raw for m in reversed(self._buffer) if not isinstance(m._raw, types.MessageEmpty)),
|
||||
types.MessageEmpty(id=0, peer_id=None),
|
||||
)
|
||||
|
||||
|
@ -335,14 +305,10 @@ class CherryPickedList(MessageList):
|
|||
|
||||
if isinstance(self._peer, ChannelRef):
|
||||
result = await self._client(
|
||||
functions.channels.get_messages(
|
||||
channel=self._peer._to_input_channel(), id=self._ids[:100]
|
||||
)
|
||||
functions.channels.get_messages(channel=self._peer._to_input_channel(), id=self._ids[:100])
|
||||
)
|
||||
else:
|
||||
result = await self._client(
|
||||
functions.messages.get_messages(id=self._ids[:100])
|
||||
)
|
||||
result = await self._client(functions.messages.get_messages(id=self._ids[:100]))
|
||||
|
||||
self._extend_buffer(self._client, result)
|
||||
self._ids = self._ids[100:]
|
||||
|
@ -492,9 +458,7 @@ def search_all_messages(
|
|||
)
|
||||
|
||||
|
||||
async def pin_message(
|
||||
self: Client, chat: Peer | PeerRef, /, message_id: int
|
||||
) -> Message:
|
||||
async def pin_message(self: Client, chat: Peer | PeerRef, /, message_id: int) -> Message:
|
||||
return self._build_message_map(
|
||||
await self(
|
||||
functions.messages.update_pinned_message(
|
||||
|
@ -509,9 +473,7 @@ async def pin_message(
|
|||
).get_single()
|
||||
|
||||
|
||||
async def unpin_message(
|
||||
self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
|
||||
) -> None:
|
||||
async def unpin_message(self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
|
||||
if message_id == "all":
|
||||
await self(
|
||||
functions.messages.unpin_all_messages(
|
||||
|
@ -531,25 +493,15 @@ async def unpin_message(
|
|||
)
|
||||
|
||||
|
||||
async def read_message(
|
||||
self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]
|
||||
) -> None:
|
||||
async def read_message(self: Client, chat: Peer | PeerRef, /, message_id: int | Literal["all"]) -> None:
|
||||
if message_id == "all":
|
||||
message_id = 0
|
||||
|
||||
peer = chat._ref
|
||||
if isinstance(peer, ChannelRef):
|
||||
await self(
|
||||
functions.channels.read_history(
|
||||
channel=peer._to_input_channel(), max_id=message_id
|
||||
)
|
||||
)
|
||||
await self(functions.channels.read_history(channel=peer._to_input_channel(), max_id=message_id))
|
||||
else:
|
||||
await self(
|
||||
functions.messages.read_history(
|
||||
peer=peer._ref._to_input_peer(), max_id=message_id
|
||||
)
|
||||
)
|
||||
await self(functions.messages.read_history(peer=peer._ref._to_input_peer(), max_id=message_id))
|
||||
|
||||
|
||||
class MessageMap:
|
||||
|
@ -584,9 +536,7 @@ class MessageMap:
|
|||
def _empty(self, id: int = 0) -> Message:
|
||||
return Message._from_raw(
|
||||
self._client,
|
||||
types.MessageEmpty(
|
||||
id=id, peer_id=self._peer._to_peer() if self._peer else None
|
||||
),
|
||||
types.MessageEmpty(id=id, peer_id=self._peer._to_peer() if self._peer else None),
|
||||
{},
|
||||
)
|
||||
|
||||
|
|
|
@ -82,16 +82,9 @@ async def connect_sender(
|
|||
# Only the ID of the input DC may be known.
|
||||
# Find the corresponding address and authentication key if needed.
|
||||
addr = dc.ipv4_addr or next(
|
||||
d.ipv4_addr
|
||||
for d in itertools.chain(known_dcs, KNOWN_DCS)
|
||||
if d.id == dc.id and d.ipv4_addr
|
||||
)
|
||||
auth = (
|
||||
None
|
||||
if force_auth_gen
|
||||
else dc.auth
|
||||
or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
|
||||
d.ipv4_addr for d in itertools.chain(known_dcs, KNOWN_DCS) if d.id == dc.id and d.ipv4_addr
|
||||
)
|
||||
auth = None if force_auth_gen else dc.auth or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
|
||||
|
||||
sender = await do_connect_sender(
|
||||
Full(),
|
||||
|
@ -122,9 +115,7 @@ async def connect_sender(
|
|||
)
|
||||
except BadStatus as e:
|
||||
if e.status == 404 and auth:
|
||||
dc = DataCenter(
|
||||
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None
|
||||
)
|
||||
dc = DataCenter(id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None)
|
||||
config.base_logger.warning(
|
||||
"datacenter could not find stored auth; will retry generating a new one: %s",
|
||||
dc,
|
||||
|
@ -174,12 +165,8 @@ async def connect(self: Client) -> None:
|
|||
if session := await self._storage.load():
|
||||
self._session = session
|
||||
|
||||
datacenter = self._config.datacenter or DataCenter(
|
||||
id=self._session.user.dc if self._session.user else DEFAULT_DC
|
||||
)
|
||||
self._sender, self._session.dcs = await connect_sender(
|
||||
self._config, self._session.dcs, datacenter
|
||||
)
|
||||
datacenter = self._config.datacenter or DataCenter(id=self._session.user.dc if self._session.user else DEFAULT_DC)
|
||||
self._sender, self._session.dcs = await connect_sender(self._config, self._session.dcs, datacenter)
|
||||
|
||||
if self._message_box.is_empty() and self._session.user:
|
||||
try:
|
||||
|
@ -193,9 +180,7 @@ async def connect(self: Client) -> None:
|
|||
me = await self.get_me()
|
||||
assert me is not None
|
||||
self._chat_hashes.set_self_user(me.id, me.bot)
|
||||
self._session.user = SessionUser(
|
||||
id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username
|
||||
)
|
||||
self._session.user = SessionUser(id=me.id, dc=self._sender.dc_id, bot=me.bot, username=me.username)
|
||||
|
||||
self._dispatcher = asyncio.create_task(dispatcher(self))
|
||||
|
||||
|
@ -214,18 +199,14 @@ async def disconnect(self: Client) -> None:
|
|||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
self._config.base_logger.exception(
|
||||
"unhandled exception when cancelling dispatcher; this is a bug"
|
||||
)
|
||||
self._config.base_logger.exception("unhandled exception when cancelling dispatcher; this is a bug")
|
||||
finally:
|
||||
self._dispatcher = None
|
||||
|
||||
try:
|
||||
await sender.disconnect()
|
||||
except Exception:
|
||||
self._config.base_logger.exception(
|
||||
"unhandled exception during disconnect; this is a bug"
|
||||
)
|
||||
self._config.base_logger.exception("unhandled exception during disconnect; this is a bug")
|
||||
|
||||
try:
|
||||
if self._session.user:
|
||||
|
|
|
@ -40,9 +40,7 @@ def add_event_handler(
|
|||
self._handlers.setdefault(event_cls, []).append((handler, filter))
|
||||
|
||||
|
||||
def remove_event_handler(
|
||||
self: Client, handler: Callable[[Event], Awaitable[Any]], /
|
||||
) -> None:
|
||||
def remove_event_handler(self: Client, handler: Callable[[Event], Awaitable[Any]], /) -> None:
|
||||
for event_cls, handlers in tuple(self._handlers.items()):
|
||||
for i in reversed(range(len(handlers))):
|
||||
if handlers[i][0] == handler:
|
||||
|
@ -51,9 +49,7 @@ def remove_event_handler(
|
|||
del self._handlers[event_cls]
|
||||
|
||||
|
||||
def get_handler_filter(
|
||||
self: Client, handler: Callable[[Event], Awaitable[Any]], /
|
||||
) -> Optional[FilterType]:
|
||||
def get_handler_filter(self: Client, handler: Callable[[Event], Awaitable[Any]], /) -> Optional[FilterType]:
|
||||
for handlers in self._handlers.values():
|
||||
for h, f in handlers:
|
||||
if h == handler:
|
||||
|
@ -84,9 +80,7 @@ def process_socket_updates(client: Client, all_updates: list[abcs.Updates]) -> N
|
|||
return
|
||||
|
||||
try:
|
||||
result, users, chats = client._message_box.process_updates(
|
||||
updates, client._chat_hashes
|
||||
)
|
||||
result, users, chats = client._message_box.process_updates(updates, client._chat_hashes)
|
||||
except Gap:
|
||||
return
|
||||
|
||||
|
@ -107,8 +101,7 @@ def extend_update_queue(
|
|||
except asyncio.QueueFull:
|
||||
now = asyncio.get_running_loop().time()
|
||||
if client._last_update_limit_warn is None or (
|
||||
now - client._last_update_limit_warn
|
||||
> UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN
|
||||
now - client._last_update_limit_warn > UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN
|
||||
):
|
||||
client._config.base_logger.warning(
|
||||
"updates are being dropped because limit=%d has been reached",
|
||||
|
|
|
@ -53,15 +53,11 @@ def resolved_peer_to_chat(client: Client, resolved: abcs.contacts.ResolvedPeer)
|
|||
|
||||
|
||||
async def resolve_phone(self: Client, phone: str, /) -> Peer:
|
||||
return resolved_peer_to_chat(
|
||||
self, await self(functions.contacts.resolve_phone(phone=phone))
|
||||
)
|
||||
return resolved_peer_to_chat(self, await self(functions.contacts.resolve_phone(phone=phone)))
|
||||
|
||||
|
||||
async def resolve_username(self: Client, username: str, /) -> Peer:
|
||||
return resolved_peer_to_chat(
|
||||
self, await self(functions.contacts.resolve_username(username=username))
|
||||
)
|
||||
return resolved_peer_to_chat(self, await self(functions.contacts.resolve_username(username=username)))
|
||||
|
||||
|
||||
async def resolve_peers(self: Client, peers: Sequence[Peer | PeerRef], /) -> list[Peer]:
|
||||
|
@ -99,8 +95,4 @@ async def resolve_peers(self: Client, peers: Sequence[Peer | PeerRef], /) -> lis
|
|||
chats.extend(ret_chats.chats)
|
||||
|
||||
chat_map = build_chat_map(self, users, chats)
|
||||
return [
|
||||
chat_map.get(ref.identifier)
|
||||
or expand_peer(self, ref._to_peer(), broadcast=None)
|
||||
for ref in refs
|
||||
]
|
||||
return [chat_map.get(ref.identifier) or expand_peer(self, ref._to_peer(), broadcast=None) for ref in refs]
|
||||
|
|
|
@ -34,17 +34,13 @@ def from_name(name: str, *, _cache: dict[str, Type[RpcError]] = {}) -> Type[RpcE
|
|||
return _cache[name]
|
||||
|
||||
|
||||
def adapt_rpc(
|
||||
error: RpcError, *, _cache: dict[tuple[int, str], Type[RpcError]] = {}
|
||||
) -> RpcError:
|
||||
def adapt_rpc(error: RpcError, *, _cache: dict[tuple[int, str], Type[RpcError]] = {}) -> RpcError:
|
||||
code = canonicalize_code(error.code)
|
||||
name = canonicalize_name(error.name)
|
||||
tup = code, name
|
||||
if tup not in _cache:
|
||||
_cache[tup] = type(pretty_name(name), (from_code(code), from_name(name)), {})
|
||||
return _cache[tup](
|
||||
code=error.code, name=error.name, value=error.value, caused_by=error._caused_by
|
||||
)
|
||||
return _cache[tup](code=error.code, name=error.name, value=error.value, caused_by=error._caused_by)
|
||||
|
||||
|
||||
class ErrorFactory:
|
||||
|
@ -55,9 +51,7 @@ class ErrorFactory:
|
|||
return from_code(int(m[1]))
|
||||
else:
|
||||
adapted = adapt_user_name(name)
|
||||
if pretty_name(canonicalize_name(adapted)) != name or re.match(
|
||||
r"[A-Z]{2}", name
|
||||
):
|
||||
if pretty_name(canonicalize_name(adapted)) != name or re.match(r"[A-Z]{2}", name):
|
||||
raise AttributeError(f"error subclass names must be CamelCase: {name}")
|
||||
return from_name(adapted)
|
||||
|
||||
|
|
|
@ -24,9 +24,7 @@ class Event(abc.ABC, metaclass=NoPublicConstructor):
|
|||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def _try_from_update(
|
||||
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
|
||||
) -> Optional[Self]:
|
||||
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -52,9 +50,7 @@ class Raw(Event):
|
|||
self._chat_map = chat_map
|
||||
|
||||
@classmethod
|
||||
def _try_from_update(
|
||||
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
|
||||
) -> Optional[Self]:
|
||||
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
|
||||
return cls._create(client, update, chat_map)
|
||||
|
||||
|
||||
|
|
|
@ -65,9 +65,7 @@ class Any(Combinable):
|
|||
|
||||
__slots__ = ("_filters",)
|
||||
|
||||
def __init__(
|
||||
self, filter1: FilterType, filter2: FilterType, *filters: FilterType
|
||||
) -> None:
|
||||
def __init__(self, filter1: FilterType, filter2: FilterType, *filters: FilterType) -> None:
|
||||
self._filters = (filter1, filter2, *filters)
|
||||
|
||||
@property
|
||||
|
@ -111,9 +109,7 @@ class All(Combinable):
|
|||
|
||||
__slots__ = ("_filters",)
|
||||
|
||||
def __init__(
|
||||
self, filter1: FilterType, filter2: FilterType, *filters: FilterType
|
||||
) -> None:
|
||||
def __init__(self, filter1: FilterType, filter2: FilterType, *filters: FilterType) -> None:
|
||||
self._filters = (filter1, filter2, *filters)
|
||||
|
||||
@property
|
||||
|
|
|
@ -24,15 +24,11 @@ class NewMessage(Event, Message):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _try_from_update(
|
||||
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
|
||||
) -> Optional[Self]:
|
||||
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
|
||||
if isinstance(update, (types.UpdateNewMessage, types.UpdateNewChannelMessage)):
|
||||
if isinstance(update.message, types.Message):
|
||||
return cls._from_raw(client, update.message, chat_map)
|
||||
elif isinstance(
|
||||
update, (types.UpdateShortMessage, types.UpdateShortChatMessage)
|
||||
):
|
||||
elif isinstance(update, (types.UpdateShortMessage, types.UpdateShortChatMessage)):
|
||||
raise RuntimeError("should have been handled by adaptor")
|
||||
|
||||
return None
|
||||
|
@ -46,12 +42,8 @@ class MessageEdited(Event, Message):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _try_from_update(
|
||||
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
|
||||
) -> Optional[Self]:
|
||||
if isinstance(
|
||||
update, (types.UpdateEditMessage, types.UpdateEditChannelMessage)
|
||||
):
|
||||
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
|
||||
if isinstance(update, (types.UpdateEditMessage, types.UpdateEditChannelMessage)):
|
||||
return cls._from_raw(client, update.message, chat_map)
|
||||
else:
|
||||
return None
|
||||
|
@ -74,9 +66,7 @@ class MessageDeleted(Event):
|
|||
self._channel_id = channel_id
|
||||
|
||||
@classmethod
|
||||
def _try_from_update(
|
||||
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
|
||||
) -> Optional[Self]:
|
||||
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
|
||||
if isinstance(update, types.UpdateDeleteMessages):
|
||||
return cls._create(update.messages, None)
|
||||
elif isinstance(update, types.UpdateDeleteChannelMessages):
|
||||
|
@ -122,9 +112,7 @@ class MessageRead(Event):
|
|||
self._chat_map = chat_map
|
||||
|
||||
@classmethod
|
||||
def _try_from_update(
|
||||
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
|
||||
) -> Optional[Self]:
|
||||
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
|
||||
if isinstance(
|
||||
update,
|
||||
(
|
||||
|
@ -139,9 +127,7 @@ class MessageRead(Event):
|
|||
return None
|
||||
|
||||
def _peer(self) -> abcs.Peer:
|
||||
if isinstance(
|
||||
self._raw, (types.UpdateReadHistoryInbox, types.UpdateReadHistoryOutbox)
|
||||
):
|
||||
if isinstance(self._raw, (types.UpdateReadHistoryInbox, types.UpdateReadHistoryOutbox)):
|
||||
return self._raw.peer
|
||||
else:
|
||||
return types.PeerChannel(channel_id=self._raw.channel_id)
|
||||
|
@ -154,9 +140,7 @@ class MessageRead(Event):
|
|||
peer = self._peer()
|
||||
pid = peer_id(peer)
|
||||
if pid not in self._chat_map:
|
||||
self._chat_map[pid] = expand_peer(
|
||||
self._client, peer, broadcast=getattr(self._raw, "post", None)
|
||||
)
|
||||
self._chat_map[pid] = expand_peer(self._client, peer, broadcast=getattr(self._raw, "post", None))
|
||||
return self._chat_map[pid]
|
||||
|
||||
@property
|
||||
|
|
|
@ -31,9 +31,7 @@ class ButtonCallback(Event):
|
|||
self._chat_map = chat_map
|
||||
|
||||
@classmethod
|
||||
def _try_from_update(
|
||||
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
|
||||
) -> Optional[Self]:
|
||||
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
|
||||
if isinstance(update, types.UpdateBotCallbackQuery) and update.data is not None:
|
||||
return cls._create(client, update, chat_map)
|
||||
else:
|
||||
|
@ -83,11 +81,7 @@ class ButtonCallback(Event):
|
|||
chat = self._chat_map.get(pid) or PeerRef._empty_from_peer(self._raw.peer)
|
||||
|
||||
lst = CherryPickedList(self._client, chat._ref, [])
|
||||
lst._ids.append(
|
||||
types.InputMessageCallbackQuery(
|
||||
id=self._raw.msg_id, query_id=self._raw.query_id
|
||||
)
|
||||
)
|
||||
lst._ids.append(types.InputMessageCallbackQuery(id=self._raw.msg_id, query_id=self._raw.query_id))
|
||||
|
||||
message = (await lst)[0]
|
||||
|
||||
|
@ -105,9 +99,7 @@ class InlineQuery(Event):
|
|||
self._raw = update
|
||||
|
||||
@classmethod
|
||||
def _try_from_update(
|
||||
cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]
|
||||
) -> Optional[Self]:
|
||||
def _try_from_update(cls, client: Client, update: abcs.Update, chat_map: dict[int, Peer]) -> Optional[Self]:
|
||||
if isinstance(update, types.UpdateBotInlineQuery):
|
||||
return cls._create(update)
|
||||
else:
|
||||
|
|
|
@ -133,9 +133,7 @@ def parse(html: str) -> tuple[str, list[MessageEntity]]:
|
|||
return del_surrogate(parser.text), parser.entities
|
||||
|
||||
|
||||
ENTITY_TO_FORMATTER: dict[
|
||||
Type[MessageEntity], tuple[str, str] | Callable[[Any, str], tuple[str, str]]
|
||||
] = {
|
||||
ENTITY_TO_FORMATTER: dict[Type[MessageEntity], tuple[str, str] | Callable[[Any, str], tuple[str, str]]] = {
|
||||
MessageEntityBold: ("<strong>", "</strong>"),
|
||||
MessageEntityItalic: ("<em>", "</em>"),
|
||||
MessageEntityCode: ("<code>", "</code>"),
|
||||
|
@ -196,12 +194,7 @@ def unparse(text: str, entities: Iterable[MessageEntity]) -> str:
|
|||
while within_surrogate(text, at):
|
||||
at += 1
|
||||
|
||||
text = (
|
||||
text[:at]
|
||||
+ what
|
||||
+ escape(text[at:next_escape_bound])
|
||||
+ text[next_escape_bound:]
|
||||
)
|
||||
text = text[:at] + what + escape(text[at:next_escape_bound]) + text[next_escape_bound:]
|
||||
next_escape_bound = at
|
||||
|
||||
text = escape(text[:next_escape_bound]) + text[next_escape_bound:]
|
||||
|
|
|
@ -82,9 +82,7 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]:
|
|||
else:
|
||||
for entity in reversed(entities):
|
||||
if isinstance(entity, ty):
|
||||
setattr(
|
||||
entity, "length", len(message) - getattr(entity, "offset", 0)
|
||||
)
|
||||
setattr(entity, "length", len(message) - getattr(entity, "offset", 0))
|
||||
break
|
||||
|
||||
parsed = MARKDOWN.parse(add_surrogate(message.strip()))
|
||||
|
@ -103,25 +101,15 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]:
|
|||
if token.type in ("blockquote_close", "blockquote_open"):
|
||||
push(MessageEntityBlockquote)
|
||||
elif token.type == "code_block":
|
||||
entities.append(
|
||||
MessageEntityPre(
|
||||
offset=len(message), length=len(token.content), language=""
|
||||
)
|
||||
)
|
||||
entities.append(MessageEntityPre(offset=len(message), length=len(token.content), language=""))
|
||||
message += token.content
|
||||
elif token.type == "code_inline":
|
||||
entities.append(
|
||||
MessageEntityCode(offset=len(message), length=len(token.content))
|
||||
)
|
||||
entities.append(MessageEntityCode(offset=len(message), length=len(token.content)))
|
||||
message += token.content
|
||||
elif token.type in ("em_close", "em_open"):
|
||||
push(MessageEntityItalic)
|
||||
elif token.type == "fence":
|
||||
entities.append(
|
||||
MessageEntityPre(
|
||||
offset=len(message), length=len(token.content), language=token.info
|
||||
)
|
||||
)
|
||||
entities.append(MessageEntityPre(offset=len(message), length=len(token.content), language=token.info))
|
||||
message += token.content[:-1] # remove a single trailing newline
|
||||
elif token.type == "hardbreak":
|
||||
message += "\n"
|
||||
|
@ -130,9 +118,7 @@ def parse(message: str) -> tuple[str, list[MessageEntity]]:
|
|||
elif token.type == "hr":
|
||||
message += "\u2015\n\n"
|
||||
elif token.type in ("link_close", "link_open"):
|
||||
if (
|
||||
token.markup != "autolink"
|
||||
): # telegram already picks up on these automatically
|
||||
if token.markup != "autolink": # telegram already picks up on these automatically
|
||||
push(MessageEntityTextUrl, url=token.attrs.get("href"))
|
||||
elif token.type in ("s_close", "s_open"):
|
||||
push(MessageEntityStrike)
|
||||
|
|
|
@ -6,11 +6,7 @@ def add_surrogate(text: str) -> str:
|
|||
return "".join(
|
||||
# SMP -> Surrogate Pairs (Telegram offsets are calculated with these).
|
||||
# See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more.
|
||||
(
|
||||
"".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
|
||||
if (0x10000 <= ord(x) <= 0x10FFFF)
|
||||
else x
|
||||
)
|
||||
("".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le"))) if (0x10000 <= ord(x) <= 0x10FFFF) else x)
|
||||
for x in text
|
||||
)
|
||||
|
||||
|
|
|
@ -52,20 +52,12 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
|
|||
input_media: abcs.InputMedia
|
||||
if try_get_url_path(file) is not None:
|
||||
assert isinstance(file, str)
|
||||
input_media = types.InputMediaPhotoExternal(
|
||||
spoiler=False, url=file, ttl_seconds=None
|
||||
)
|
||||
input_media = types.InputMediaPhotoExternal(spoiler=False, url=file, ttl_seconds=None)
|
||||
else:
|
||||
input_file, _ = await self._client._upload(file, size, "a.jpg")
|
||||
input_media = types.InputMediaUploadedPhoto(
|
||||
spoiler=False, file=input_file, stickers=None, ttl_seconds=None
|
||||
)
|
||||
input_media = types.InputMediaUploadedPhoto(spoiler=False, file=input_file, stickers=None, ttl_seconds=None)
|
||||
|
||||
media = await self._client(
|
||||
functions.messages.upload_media(
|
||||
peer=types.InputPeerSelf(), media=input_media
|
||||
)
|
||||
)
|
||||
media = await self._client(functions.messages.upload_media(peer=types.InputPeerSelf(), media=input_media))
|
||||
assert isinstance(media, types.MessageMediaPhoto)
|
||||
assert isinstance(media.photo, types.Photo)
|
||||
input_media = types.InputMediaPhoto(
|
||||
|
@ -77,9 +69,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
|
|||
),
|
||||
ttl_seconds=media.ttl_seconds,
|
||||
)
|
||||
message, entities = parse_message(
|
||||
text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True
|
||||
)
|
||||
message, entities = parse_message(text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True)
|
||||
self._medias.append(
|
||||
types.InputSingleMedia(
|
||||
media=input_media,
|
||||
|
@ -132,9 +122,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
|
|||
input_media: abcs.InputMedia
|
||||
if try_get_url_path(file) is not None:
|
||||
assert isinstance(file, str)
|
||||
input_media = types.InputMediaDocumentExternal(
|
||||
spoiler=False, url=file, ttl_seconds=None
|
||||
)
|
||||
input_media = types.InputMediaDocumentExternal(spoiler=False, url=file, ttl_seconds=None)
|
||||
else:
|
||||
input_file, name = await self._client._upload(file, size, name)
|
||||
if mime_type is None:
|
||||
|
@ -168,11 +156,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
|
|||
ttl_seconds=None,
|
||||
)
|
||||
|
||||
media = await self._client(
|
||||
functions.messages.upload_media(
|
||||
peer=types.InputPeerEmpty(), media=input_media
|
||||
)
|
||||
)
|
||||
media = await self._client(functions.messages.upload_media(peer=types.InputPeerEmpty(), media=input_media))
|
||||
assert isinstance(media, types.MessageMediaDocument)
|
||||
assert isinstance(media.document, types.Document)
|
||||
input_media = types.InputMediaDocument(
|
||||
|
@ -185,9 +169,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
|
|||
ttl_seconds=media.ttl_seconds,
|
||||
query=None,
|
||||
)
|
||||
message, entities = parse_message(
|
||||
text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True
|
||||
)
|
||||
message, entities = parse_message(text=caption, markdown=caption_markdown, html=caption_html, allow_empty=True)
|
||||
self._medias.append(
|
||||
types.InputSingleMedia(
|
||||
media=input_media,
|
||||
|
@ -197,9 +179,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
|
|||
)
|
||||
)
|
||||
|
||||
async def send(
|
||||
self, peer: Peer | PeerRef, *, reply_to: Optional[int] = None
|
||||
) -> list[Message]:
|
||||
async def send(self, peer: Peer | PeerRef, *, reply_to: Optional[int] = None) -> list[Message]:
|
||||
"""
|
||||
Send the album.
|
||||
|
||||
|
@ -225,11 +205,7 @@ class AlbumBuilder(metaclass=NoPublicConstructor):
|
|||
update_stickersets_order=False,
|
||||
peer=peer._ref._to_input_peer(),
|
||||
reply_to=(
|
||||
types.InputReplyToMessage(
|
||||
reply_to_msg_id=reply_to, top_msg_id=None
|
||||
)
|
||||
if reply_to
|
||||
else None
|
||||
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None
|
||||
),
|
||||
multi_media=self._medias,
|
||||
schedule_date=None,
|
||||
|
|
|
@ -50,9 +50,7 @@ class Button(abc.ABC):
|
|||
|
||||
def __init__(self, text: str) -> None:
|
||||
if self.__class__ == Button:
|
||||
raise TypeError(
|
||||
f"Can't instantiate abstract class {self.__class__.__name__}"
|
||||
)
|
||||
raise TypeError(f"Can't instantiate abstract class {self.__class__.__name__}")
|
||||
|
||||
self._raw: RawButtonType = types.KeyboardButton(text=text)
|
||||
self._msg: Optional[weakref.ReferenceType[Message]] = None
|
||||
|
|
|
@ -29,8 +29,6 @@ class InlineButton(Button, abc.ABC):
|
|||
|
||||
def __init__(self, text: str) -> None:
|
||||
if self.__class__ == InlineButton:
|
||||
raise TypeError(
|
||||
f"Can't instantiate abstract class {self.__class__.__name__}"
|
||||
)
|
||||
raise TypeError(f"Can't instantiate abstract class {self.__class__.__name__}")
|
||||
else:
|
||||
super().__init__(text)
|
||||
|
|
|
@ -14,9 +14,7 @@ class SwitchInline(InlineButton):
|
|||
|
||||
def __init__(self, text: str, query: Optional[str] = None) -> None:
|
||||
super().__init__(text)
|
||||
self._raw = types.KeyboardButtonSwitchInline(
|
||||
same_peer=False, text=text, query=query or "", peer_types=None
|
||||
)
|
||||
self._raw = types.KeyboardButtonSwitchInline(same_peer=False, text=text, query=query or "", peer_types=None)
|
||||
|
||||
@property
|
||||
def query(self) -> str:
|
||||
|
|
|
@ -111,9 +111,7 @@ class ChatRestriction(Enum):
|
|||
return set(filter(None, iter(restrictions)))
|
||||
|
||||
@classmethod
|
||||
def _set_to_raw(
|
||||
cls, restrictions: set[ChatRestriction], until_date: int
|
||||
) -> types.ChatBannedRights:
|
||||
def _set_to_raw(cls, restrictions: set[ChatRestriction], until_date: int) -> types.ChatBannedRights:
|
||||
return types.ChatBannedRights(
|
||||
view_messages=cls.VIEW_MESSAGES in restrictions,
|
||||
send_messages=cls.SEND_MESSAGES in restrictions,
|
||||
|
|
|
@ -90,9 +90,6 @@ class Dialog(metaclass=NoPublicConstructor):
|
|||
if isinstance(self._raw, types.Dialog):
|
||||
return self._raw.unread_count
|
||||
elif isinstance(self._raw, types.DialogPeerFolder):
|
||||
return (
|
||||
self._raw.unread_unmuted_messages_count
|
||||
+ self._raw.unread_muted_messages_count
|
||||
)
|
||||
return self._raw.unread_unmuted_messages_count + self._raw.unread_muted_messages_count
|
||||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
|
|
@ -37,9 +37,7 @@ class Draft(metaclass=NoPublicConstructor):
|
|||
self._chat_map = chat_map
|
||||
|
||||
@classmethod
|
||||
def _from_raw_update(
|
||||
cls, client: Client, draft: types.UpdateDraftMessage, chat_map: dict[int, Peer]
|
||||
) -> Self:
|
||||
def _from_raw_update(cls, client: Client, draft: types.UpdateDraftMessage, chat_map: dict[int, Peer]) -> Self:
|
||||
return cls._create(client, draft.peer, draft.top_msg_id, draft.draft, chat_map)
|
||||
|
||||
@classmethod
|
||||
|
@ -60,9 +58,7 @@ class Draft(metaclass=NoPublicConstructor):
|
|||
|
||||
This is also the chat where the message will be sent to by :meth:`send`.
|
||||
"""
|
||||
return self._chat_map.get(peer_id(self._peer)) or expand_peer(
|
||||
self._client, self._peer, broadcast=None
|
||||
)
|
||||
return self._chat_map.get(peer_id(self._peer)) or expand_peer(self._client, self._peer, broadcast=None)
|
||||
|
||||
@property
|
||||
def link_preview(self) -> bool:
|
||||
|
@ -91,9 +87,7 @@ class Draft(metaclass=NoPublicConstructor):
|
|||
The :attr:`~Message.text_html` of the message that will be sent.
|
||||
"""
|
||||
if text := getattr(self._raw, "message", None):
|
||||
return generate_html_message(
|
||||
text, getattr(self._raw, "entities", None) or []
|
||||
)
|
||||
return generate_html_message(text, getattr(self._raw, "entities", None) or [])
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -103,9 +97,7 @@ class Draft(metaclass=NoPublicConstructor):
|
|||
The :attr:`~Message.text_markdown` of the message that will be sent.
|
||||
"""
|
||||
if text := getattr(self._raw, "message", None):
|
||||
return generate_markdown_message(
|
||||
text, getattr(self._raw, "entities", None) or []
|
||||
)
|
||||
return generate_markdown_message(text, getattr(self._raw, "entities", None) or [])
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -115,11 +107,7 @@ class Draft(metaclass=NoPublicConstructor):
|
|||
The date when the draft was last updated.
|
||||
"""
|
||||
date = getattr(self._raw, "date", None)
|
||||
return (
|
||||
datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc)
|
||||
if date is not None
|
||||
else None
|
||||
)
|
||||
return datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) if date is not None else None
|
||||
|
||||
async def edit(
|
||||
self,
|
||||
|
@ -192,11 +180,7 @@ class Draft(metaclass=NoPublicConstructor):
|
|||
noforwards=False,
|
||||
update_stickersets_order=False,
|
||||
peer=self._peer_ref()._to_input_peer(),
|
||||
reply_to=(
|
||||
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
|
||||
if reply_to
|
||||
else None
|
||||
),
|
||||
reply_to=(types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None) if reply_to else None),
|
||||
message=message,
|
||||
random_id=random_id,
|
||||
reply_markup=None,
|
||||
|
@ -211,11 +195,7 @@ class Draft(metaclass=NoPublicConstructor):
|
|||
{},
|
||||
out=result.out,
|
||||
id=result.id,
|
||||
from_id=(
|
||||
types.PeerUser(user_id=self._client._session.user.id)
|
||||
if self._client._session.user
|
||||
else None
|
||||
),
|
||||
from_id=(types.PeerUser(user_id=self._client._session.user.id) if self._client._session.user else None),
|
||||
peer_id=self._peer_ref()._to_peer(),
|
||||
reply_to=(
|
||||
types.MessageReplyHeader(
|
||||
|
@ -235,9 +215,7 @@ class Draft(metaclass=NoPublicConstructor):
|
|||
ttl_period=result.ttl_period,
|
||||
)
|
||||
else:
|
||||
return self._client._build_message_map(
|
||||
result, self._peer_ref()
|
||||
).with_random_id(random_id)
|
||||
return self._client._build_message_map(result, self._peer_ref()).with_random_id(random_id)
|
||||
|
||||
async def delete(self) -> None:
|
||||
"""
|
||||
|
|
|
@ -29,11 +29,7 @@ def photo_size_byte_count(size: abcs.PhotoSize) -> int:
|
|||
elif isinstance(size, types.PhotoSizeProgressive):
|
||||
return max(size.sizes)
|
||||
elif isinstance(size, types.PhotoStrippedSize):
|
||||
return (
|
||||
len(stripped_size_header)
|
||||
+ (len(size.bytes) - 3)
|
||||
+ len(stripped_size_footer)
|
||||
)
|
||||
return len(stripped_size_header) + (len(size.bytes) - 3) + len(stripped_size_footer)
|
||||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
||||
|
@ -180,9 +176,7 @@ class File(metaclass=NoPublicConstructor):
|
|||
self._client = client
|
||||
|
||||
@classmethod
|
||||
def _try_from_raw_message_media(
|
||||
cls, client: Client, raw: abcs.MessageMedia
|
||||
) -> Optional[Self]:
|
||||
def _try_from_raw_message_media(cls, client: Client, raw: abcs.MessageMedia) -> Optional[Self]:
|
||||
if isinstance(raw, types.MessageMediaDocument):
|
||||
if raw.document:
|
||||
return cls._try_from_raw_document(
|
||||
|
@ -204,13 +198,9 @@ class File(metaclass=NoPublicConstructor):
|
|||
elif isinstance(raw, types.MessageMediaWebPage):
|
||||
if isinstance(raw.webpage, types.WebPage):
|
||||
if raw.webpage.document:
|
||||
return cls._try_from_raw_document(
|
||||
client, raw.webpage.document, orig_raw=raw
|
||||
)
|
||||
return cls._try_from_raw_document(client, raw.webpage.document, orig_raw=raw)
|
||||
if raw.webpage.photo:
|
||||
return cls._try_from_raw_photo(
|
||||
client, raw.webpage.photo, orig_raw=raw
|
||||
)
|
||||
return cls._try_from_raw_photo(client, raw.webpage.photo, orig_raw=raw)
|
||||
|
||||
return None
|
||||
|
||||
|
@ -229,21 +219,13 @@ class File(metaclass=NoPublicConstructor):
|
|||
attributes=raw.attributes,
|
||||
size=raw.size,
|
||||
name=next(
|
||||
(
|
||||
a.file_name
|
||||
for a in raw.attributes
|
||||
if isinstance(a, types.DocumentAttributeFilename)
|
||||
),
|
||||
(a.file_name for a in raw.attributes if isinstance(a, types.DocumentAttributeFilename)),
|
||||
"",
|
||||
),
|
||||
mime=raw.mime_type,
|
||||
photo=False,
|
||||
muted=next(
|
||||
(
|
||||
a.nosound
|
||||
for a in raw.attributes
|
||||
if isinstance(a, types.DocumentAttributeVideo)
|
||||
),
|
||||
(a.nosound for a in raw.attributes if isinstance(a, types.DocumentAttributeVideo)),
|
||||
False,
|
||||
),
|
||||
input_media=types.InputMediaDocument(
|
||||
|
@ -361,9 +343,7 @@ class File(metaclass=NoPublicConstructor):
|
|||
return dim.w
|
||||
|
||||
for attr in self._attributes:
|
||||
if isinstance(
|
||||
attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)
|
||||
):
|
||||
if isinstance(attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)):
|
||||
return attr.w
|
||||
|
||||
return None
|
||||
|
@ -377,9 +357,7 @@ class File(metaclass=NoPublicConstructor):
|
|||
return dim.h
|
||||
|
||||
for attr in self._attributes:
|
||||
if isinstance(
|
||||
attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)
|
||||
):
|
||||
if isinstance(attr, (types.DocumentAttributeImageSize, types.DocumentAttributeVideo)):
|
||||
return attr.h
|
||||
|
||||
return None
|
||||
|
@ -410,9 +388,7 @@ class File(metaclass=NoPublicConstructor):
|
|||
id=self._input_media.id.id,
|
||||
access_hash=self._input_media.id.access_hash,
|
||||
file_reference=self._input_media.id.file_reference,
|
||||
thumb_size=(
|
||||
self._thumb.type if isinstance(self._thumb, thumb_types) else ""
|
||||
),
|
||||
thumb_size=(self._thumb.type if isinstance(self._thumb, thumb_types) else ""),
|
||||
)
|
||||
elif isinstance(self._input_media, types.InputMediaPhoto):
|
||||
assert isinstance(self._input_media.id, types.InputPhoto)
|
||||
|
|
|
@ -12,16 +12,11 @@ def _build_keyboard_rows(
|
|||
) -> list[abcs.KeyboardButtonRow]:
|
||||
# list[button] -> list[list[button]]
|
||||
# This does allow for "invalid" inputs (mixing lists and non-lists), but that's acceptable.
|
||||
buttons_lists_iter = [
|
||||
button if isinstance(button, list) else [button] for button in (btns or [])
|
||||
]
|
||||
buttons_lists_iter = [button if isinstance(button, list) else [button] for button in (btns or [])]
|
||||
# Remove empty rows (also making it easy to check if all-empty).
|
||||
buttons_lists = [bs for bs in buttons_lists_iter if bs]
|
||||
|
||||
return [
|
||||
types.KeyboardButtonRow(buttons=[btn._raw for btn in btns])
|
||||
for btns in buttons_lists
|
||||
]
|
||||
return [types.KeyboardButtonRow(buttons=[btn._raw for btn in btns]) for btns in buttons_lists]
|
||||
|
||||
|
||||
class Keyboard:
|
||||
|
@ -49,9 +44,7 @@ class Keyboard:
|
|||
class InlineKeyboard:
|
||||
__slots__ = ("_raw",)
|
||||
|
||||
def __init__(
|
||||
self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]]
|
||||
) -> None:
|
||||
def __init__(self, buttons: list[AnyInlineButton] | list[list[AnyInlineButton]]) -> None:
|
||||
self._raw = types.ReplyInlineMarkup(rows=_build_keyboard_rows(buttons))
|
||||
|
||||
|
||||
|
|
|
@ -34,11 +34,7 @@ def generate_random_id() -> int:
|
|||
|
||||
|
||||
def adapt_date(date: Optional[int]) -> Optional[datetime.datetime]:
|
||||
return (
|
||||
datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc)
|
||||
if date is not None
|
||||
else None
|
||||
)
|
||||
return datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) if date is not None else None
|
||||
|
||||
|
||||
class Message(metaclass=NoPublicConstructor):
|
||||
|
@ -59,20 +55,14 @@ class Message(metaclass=NoPublicConstructor):
|
|||
print('Found empty message with ID', message.id)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, client: Client, message: abcs.Message, chat_map: dict[int, Peer]
|
||||
) -> None:
|
||||
assert isinstance(
|
||||
message, (types.Message, types.MessageService, types.MessageEmpty)
|
||||
)
|
||||
def __init__(self, client: Client, message: abcs.Message, chat_map: dict[int, Peer]) -> None:
|
||||
assert isinstance(message, (types.Message, types.MessageService, types.MessageEmpty))
|
||||
self._client = client
|
||||
self._raw = message
|
||||
self._chat_map = chat_map
|
||||
|
||||
@classmethod
|
||||
def _from_raw(
|
||||
cls, client: Client, message: abcs.Message, chat_map: dict[int, Peer]
|
||||
) -> Self:
|
||||
def _from_raw(cls, client: Client, message: abcs.Message, chat_map: dict[int, Peer]) -> Self:
|
||||
return cls._create(client, message, chat_map)
|
||||
|
||||
@classmethod
|
||||
|
@ -158,9 +148,7 @@ class Message(metaclass=NoPublicConstructor):
|
|||
See :ref:`formatting` to learn the HTML elements used.
|
||||
"""
|
||||
if text := getattr(self._raw, "message", None):
|
||||
return generate_html_message(
|
||||
text, getattr(self._raw, "entities", None) or []
|
||||
)
|
||||
return generate_html_message(text, getattr(self._raw, "entities", None) or [])
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -172,9 +160,7 @@ class Message(metaclass=NoPublicConstructor):
|
|||
See :ref:`formatting` to learn the formatting characters used.
|
||||
"""
|
||||
if text := getattr(self._raw, "message", None):
|
||||
return generate_markdown_message(
|
||||
text, getattr(self._raw, "entities", None) or []
|
||||
)
|
||||
return generate_markdown_message(text, getattr(self._raw, "entities", None) or [])
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -193,9 +179,7 @@ class Message(metaclass=NoPublicConstructor):
|
|||
peer = self._raw.peer_id or types.PeerUser(user_id=0)
|
||||
pid = peer_id(peer)
|
||||
if pid not in self._chat_map:
|
||||
self._chat_map[pid] = expand_peer(
|
||||
self._client, peer, broadcast=getattr(self._raw, "post", None)
|
||||
)
|
||||
self._chat_map[pid] = expand_peer(self._client, peer, broadcast=getattr(self._raw, "post", None))
|
||||
return self._chat_map[pid]
|
||||
|
||||
@property
|
||||
|
@ -239,14 +223,7 @@ class Message(metaclass=NoPublicConstructor):
|
|||
This can also be used as a way to check that the message media is an audio.
|
||||
"""
|
||||
audio = self._file()
|
||||
return (
|
||||
audio
|
||||
if audio
|
||||
and any(
|
||||
isinstance(a, types.DocumentAttributeAudio) for a in audio._attributes
|
||||
)
|
||||
else None
|
||||
)
|
||||
return audio if audio and any(isinstance(a, types.DocumentAttributeAudio) for a in audio._attributes) else None
|
||||
|
||||
@property
|
||||
def video(self) -> Optional[File]:
|
||||
|
@ -256,14 +233,7 @@ class Message(metaclass=NoPublicConstructor):
|
|||
This can also be used as a way to check that the message media is a video.
|
||||
"""
|
||||
audio = self._file()
|
||||
return (
|
||||
audio
|
||||
if audio
|
||||
and any(
|
||||
isinstance(a, types.DocumentAttributeVideo) for a in audio._attributes
|
||||
)
|
||||
else None
|
||||
)
|
||||
return audio if audio and any(isinstance(a, types.DocumentAttributeVideo) for a in audio._attributes) else None
|
||||
|
||||
@property
|
||||
def file(self) -> Optional[File]:
|
||||
|
@ -477,10 +447,7 @@ class Message(metaclass=NoPublicConstructor):
|
|||
return None
|
||||
|
||||
return [
|
||||
[
|
||||
create_button(self, button)
|
||||
for button in cast(types.KeyboardButtonRow, row).buttons
|
||||
]
|
||||
[create_button(self, button) for button in cast(types.KeyboardButtonRow, row).buttons]
|
||||
for row in markup.rows
|
||||
]
|
||||
|
||||
|
@ -506,13 +473,8 @@ class Message(metaclass=NoPublicConstructor):
|
|||
return not isinstance(self._raw, types.MessageEmpty)
|
||||
|
||||
|
||||
def build_msg_map(
|
||||
client: Client, messages: Sequence[abcs.Message], chat_map: dict[int, Peer]
|
||||
) -> dict[int, Message]:
|
||||
return {
|
||||
msg.id: msg
|
||||
for msg in (Message._from_raw(client, m, chat_map) for m in messages)
|
||||
}
|
||||
def build_msg_map(client: Client, messages: Sequence[abcs.Message], chat_map: dict[int, Peer]) -> dict[int, Message]:
|
||||
return {msg.id: msg for msg in (Message._from_raw(client, m, chat_map) for m in messages)}
|
||||
|
||||
|
||||
def parse_message(
|
||||
|
|
|
@ -16,23 +16,16 @@ class Final(abc.ABCMeta):
|
|||
cls_namespace: dict[str, object],
|
||||
) -> "Final":
|
||||
# Allow subclassing while within telethon._impl (or other package names).
|
||||
allowed_base = Final.__module__[
|
||||
: Final.__module__.find(".", Final.__module__.find(".") + 1)
|
||||
]
|
||||
allowed_base = Final.__module__[: Final.__module__.find(".", Final.__module__.find(".") + 1)]
|
||||
for base in bases:
|
||||
if isinstance(base, Final) and not base.__module__.startswith(allowed_base):
|
||||
raise TypeError(
|
||||
f"{base.__module__}.{base.__qualname__} does not support"
|
||||
" subclassing"
|
||||
)
|
||||
raise TypeError(f"{base.__module__}.{base.__qualname__} does not support" " subclassing")
|
||||
return super().__new__(cls, name, bases, cls_namespace)
|
||||
|
||||
|
||||
class NoPublicConstructor(Final):
|
||||
def __call__(cls, *args: Any, **kwds: Any) -> Any:
|
||||
raise TypeError(
|
||||
f"{cls.__module__}.{cls.__qualname__} has no public constructor"
|
||||
)
|
||||
raise TypeError(f"{cls.__module__}.{cls.__qualname__} has no public constructor")
|
||||
|
||||
@property
|
||||
def _create(cls: Type[T]) -> Type[T]:
|
||||
|
|
|
@ -157,22 +157,16 @@ class Participant(metaclass=NoPublicConstructor):
|
|||
"""
|
||||
:data:`True` if the participant is the creator of the chat.
|
||||
"""
|
||||
return isinstance(
|
||||
self._raw, (types.ChannelParticipantCreator, types.ChatParticipantCreator)
|
||||
)
|
||||
return isinstance(self._raw, (types.ChannelParticipantCreator, types.ChatParticipantCreator))
|
||||
|
||||
@property
|
||||
def admin_rights(self) -> Optional[set[AdminRight]]:
|
||||
"""
|
||||
The set of administrator rights this participant has been granted, if they are an administrator.
|
||||
"""
|
||||
if isinstance(
|
||||
self._raw, (types.ChannelParticipantCreator, types.ChannelParticipantAdmin)
|
||||
):
|
||||
if isinstance(self._raw, (types.ChannelParticipantCreator, types.ChannelParticipantAdmin)):
|
||||
return AdminRight._from_raw(self._raw.admin_rights)
|
||||
elif isinstance(
|
||||
self._raw, (types.ChatParticipantCreator, types.ChatParticipantAdmin)
|
||||
):
|
||||
elif isinstance(self._raw, (types.ChatParticipantCreator, types.ChatParticipantAdmin)):
|
||||
return AdminRight._chat_rights()
|
||||
else:
|
||||
return None
|
||||
|
@ -194,13 +188,9 @@ class Participant(metaclass=NoPublicConstructor):
|
|||
participant = self.user or self.banned or self.left
|
||||
assert participant
|
||||
if isinstance(participant, User):
|
||||
await self._client.set_participant_admin_rights(
|
||||
self._chat, participant, rights
|
||||
)
|
||||
await self._client.set_participant_admin_rights(self._chat, participant, rights)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"participant of type {participant.__class__.__name__} cannot be made admin"
|
||||
)
|
||||
raise TypeError(f"participant of type {participant.__class__.__name__} cannot be made admin")
|
||||
|
||||
async def set_restrictions(
|
||||
self,
|
||||
|
@ -213,6 +203,4 @@ class Participant(metaclass=NoPublicConstructor):
|
|||
"""
|
||||
participant = self.user or self.banned or self.left
|
||||
assert participant
|
||||
await self._client.set_participant_restrictions(
|
||||
self._chat, participant, restrictions, until=until
|
||||
)
|
||||
await self._client.set_participant_restrictions(self._chat, participant, restrictions, until=until)
|
||||
|
|
|
@ -15,9 +15,7 @@ if TYPE_CHECKING:
|
|||
from ...client.client import Client
|
||||
|
||||
|
||||
def build_chat_map(
|
||||
client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]
|
||||
) -> dict[int, Peer]:
|
||||
def build_chat_map(client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]) -> dict[int, Peer]:
|
||||
users_iter = (User._from_raw(u) for u in users)
|
||||
chats_iter = (
|
||||
(
|
||||
|
@ -45,9 +43,7 @@ def build_chat_map(
|
|||
for x in v:
|
||||
print(x, file=sys.stderr)
|
||||
|
||||
raise RuntimeError(
|
||||
f"chat identifier collision: {k}; please report this"
|
||||
)
|
||||
raise RuntimeError(f"chat identifier collision: {k}; please report this")
|
||||
|
||||
return result
|
||||
|
||||
|
@ -81,11 +77,7 @@ def expand_peer(client: Client, peer: abcs.Peer, *, broadcast: Optional[bool]) -
|
|||
until_date=None,
|
||||
)
|
||||
|
||||
return (
|
||||
Channel._from_raw(channel)
|
||||
if broadcast
|
||||
else Group._from_raw(client, channel)
|
||||
)
|
||||
return Channel._from_raw(channel) if broadcast else Group._from_raw(client, channel)
|
||||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
||||
|
|
|
@ -24,13 +24,7 @@ class Group(Peer, metaclass=NoPublicConstructor):
|
|||
def __init__(
|
||||
self,
|
||||
client: Client,
|
||||
chat: (
|
||||
types.ChatEmpty
|
||||
| types.Chat
|
||||
| types.ChatForbidden
|
||||
| types.Channel
|
||||
| types.ChannelForbidden
|
||||
),
|
||||
chat: (types.ChatEmpty | types.Chat | types.ChatForbidden | types.Channel | types.ChannelForbidden),
|
||||
) -> None:
|
||||
self._client = client
|
||||
self._raw = chat
|
||||
|
@ -96,6 +90,4 @@ class Group(Peer, metaclass=NoPublicConstructor):
|
|||
"""
|
||||
Alias for :meth:`telethon.Client.set_chat_default_restrictions`.
|
||||
"""
|
||||
await self._client.set_chat_default_restrictions(
|
||||
self, restrictions, until=until
|
||||
)
|
||||
await self._client.set_chat_default_restrictions(self, restrictions, until=until)
|
||||
|
|
|
@ -1,16 +1,10 @@
|
|||
try:
|
||||
import cryptg
|
||||
import cryptg # type: ignore [import-untyped]
|
||||
|
||||
def ige_encrypt(
|
||||
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||
) -> bytes: # noqa: F811
|
||||
return cryptg.encrypt_ige(
|
||||
bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv
|
||||
)
|
||||
def ige_encrypt(plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: # noqa: F811
|
||||
return cryptg.encrypt_ige(bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv)
|
||||
|
||||
def ige_decrypt(
|
||||
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||
) -> bytes: # noqa: F811
|
||||
def ige_decrypt(ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes: # noqa: F811
|
||||
return cryptg.decrypt_ige(
|
||||
bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext,
|
||||
key,
|
||||
|
@ -18,11 +12,9 @@ try:
|
|||
)
|
||||
|
||||
except ImportError:
|
||||
import pyaes
|
||||
import pyaes # type: ignore [import-untyped]
|
||||
|
||||
def ige_encrypt(
|
||||
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||
) -> bytes:
|
||||
def ige_encrypt(plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes:
|
||||
assert len(plaintext) % 16 == 0
|
||||
assert len(iv) == 32
|
||||
|
||||
|
@ -35,10 +27,7 @@ except ImportError:
|
|||
for block_offset in range(0, len(plaintext), 16):
|
||||
plaintext_block = plaintext[block_offset : block_offset + 16]
|
||||
ciphertext_block = bytes(
|
||||
a ^ b
|
||||
for a, b in zip(
|
||||
aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2
|
||||
)
|
||||
a ^ b for a, b in zip(aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2)
|
||||
)
|
||||
iv1 = ciphertext_block
|
||||
iv2 = plaintext_block
|
||||
|
@ -47,9 +36,7 @@ except ImportError:
|
|||
|
||||
return bytes(ciphertext)
|
||||
|
||||
def ige_decrypt(
|
||||
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||
) -> bytes:
|
||||
def ige_decrypt(ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes) -> bytes:
|
||||
assert len(ciphertext) % 16 == 0
|
||||
assert len(iv) == 32
|
||||
|
||||
|
@ -62,10 +49,7 @@ except ImportError:
|
|||
for block_offset in range(0, len(ciphertext), 16):
|
||||
ciphertext_block = ciphertext[block_offset : block_offset + 16]
|
||||
plaintext_block = bytes(
|
||||
a ^ b
|
||||
for a, b in zip(
|
||||
aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1
|
||||
)
|
||||
a ^ b for a, b in zip(aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1)
|
||||
)
|
||||
iv1 = ciphertext_block
|
||||
iv2 = plaintext_block
|
||||
|
|
|
@ -20,8 +20,4 @@ class AuthKey:
|
|||
return self.data
|
||||
|
||||
def calc_new_nonce_hash(self, new_nonce: int, number: int) -> int:
|
||||
return int.from_bytes(
|
||||
sha1(new_nonce.to_bytes(32) + number.to_bytes(1) + self.aux_hash).digest()[
|
||||
4:
|
||||
]
|
||||
)
|
||||
return int.from_bytes(sha1(new_nonce.to_bytes(32) + number.to_bytes(1) + self.aux_hash).digest()[4:])
|
||||
|
|
|
@ -19,9 +19,7 @@ class CalcKey(NamedTuple):
|
|||
|
||||
|
||||
# https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
|
||||
def calc_key(
|
||||
auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side
|
||||
) -> CalcKey:
|
||||
def calc_key(auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side) -> CalcKey:
|
||||
x = int(side)
|
||||
|
||||
# sha256_a = SHA256 (msg_key + substr (auth_key, x, 36))
|
||||
|
@ -43,12 +41,8 @@ def determine_padding_v2_length(length: int) -> int:
|
|||
return 16 + (16 - (length % 16))
|
||||
|
||||
|
||||
def _do_encrypt_data_v2(
|
||||
plaintext: bytes, auth_key: AuthKey, random_padding: bytes
|
||||
) -> bytes:
|
||||
padded_plaintext = (
|
||||
plaintext + random_padding[: determine_padding_v2_length(len(plaintext))]
|
||||
)
|
||||
def _do_encrypt_data_v2(plaintext: bytes, auth_key: AuthKey, random_padding: bytes) -> bytes:
|
||||
padded_plaintext = plaintext + random_padding[: determine_padding_v2_length(len(plaintext))]
|
||||
|
||||
side = Side.CLIENT
|
||||
x = int(side)
|
||||
|
@ -70,9 +64,7 @@ def encrypt_data_v2(plaintext: bytes, auth_key: AuthKey) -> bytes:
|
|||
return _do_encrypt_data_v2(plaintext, auth_key, random_padding)
|
||||
|
||||
|
||||
def decrypt_data_v2(
|
||||
ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey
|
||||
) -> bytes:
|
||||
def decrypt_data_v2(ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey) -> bytes:
|
||||
side = Side.SERVER
|
||||
x = int(side)
|
||||
|
||||
|
|
|
@ -34,17 +34,13 @@ def encrypt_hashed(data: bytes, key: PublicKey, random_bytes: bytes) -> bytes:
|
|||
temp_key = random_bytes[192 + 32 * attempt : 192 + 32 * attempt + 32]
|
||||
|
||||
# data_with_hash := data_pad_reversed + SHA256(temp_key + data_with_padding); -- after this assignment, data_with_hash is exactly 224 bytes long.
|
||||
data_with_hash = (
|
||||
data_pad_reversed + sha256(temp_key + data_with_padding).digest()
|
||||
)
|
||||
data_with_hash = data_pad_reversed + sha256(temp_key + data_with_padding).digest()
|
||||
|
||||
# aes_encrypted := AES256_IGE(data_with_hash, temp_key, 0); -- AES256-IGE encryption with zero IV.
|
||||
aes_encrypted = ige_encrypt(data_with_hash, temp_key, bytes(32))
|
||||
|
||||
# temp_key_xor := temp_key XOR SHA256(aes_encrypted); -- adjusted key, 32 bytes
|
||||
temp_key_xor = bytes(
|
||||
a ^ b for a, b in zip(temp_key, sha256(aes_encrypted).digest())
|
||||
)
|
||||
temp_key_xor = bytes(a ^ b for a, b in zip(temp_key, sha256(aes_encrypted).digest()))
|
||||
|
||||
# key_aes_encrypted := temp_key_xor + aes_encrypted; -- exactly 256 bytes (2048 bits) long
|
||||
key_aes_encrypted = temp_key_xor + aes_encrypted
|
||||
|
@ -87,6 +83,4 @@ j4WcDuXc2CTHgH8gFTNhp/Y8/SpDOhvn9QIDAQAB
|
|||
)
|
||||
|
||||
|
||||
RSA_KEYS = {
|
||||
compute_fingerprint(key): key for key in (PRODUCTION_RSA_KEY, TESTMODE_RSA_KEY)
|
||||
}
|
||||
RSA_KEYS = {compute_fingerprint(key): key for key in (PRODUCTION_RSA_KEY, TESTMODE_RSA_KEY)}
|
||||
|
|
|
@ -20,9 +20,7 @@ def h(*data: bytes | bytearray | memoryview) -> bytes:
|
|||
|
||||
|
||||
# SH(data, salt) := H(salt | data | salt)
|
||||
def sh(
|
||||
data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview
|
||||
) -> bytes:
|
||||
def sh(data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview) -> bytes:
|
||||
return h(salt, data, salt)
|
||||
|
||||
|
||||
|
|
|
@ -108,13 +108,9 @@ def _do_step2(data: Step1, response: bytes, random_bytes: bytes) -> tuple[bytes,
|
|||
)
|
||||
|
||||
try:
|
||||
fingerprint = next(
|
||||
fp for fp in res_pq.server_public_key_fingerprints if fp in RSA_KEYS
|
||||
)
|
||||
fingerprint = next(fp for fp in res_pq.server_public_key_fingerprints if fp in RSA_KEYS)
|
||||
except StopIteration:
|
||||
raise ValueError(
|
||||
f"unknown fingerprints: {res_pq.server_public_key_fingerprints}"
|
||||
)
|
||||
raise ValueError(f"unknown fingerprints: {res_pq.server_public_key_fingerprints}")
|
||||
|
||||
key = RSA_KEYS[fingerprint]
|
||||
ciphertext = encrypt_hashed(pq_inner_data, key, random_bytes)
|
||||
|
@ -133,9 +129,7 @@ def step2(data: Step1, response: bytes) -> tuple[bytes, Step2]:
|
|||
return _do_step2(data, response, os.urandom(288))
|
||||
|
||||
|
||||
def _do_step3(
|
||||
data: Step2, response: bytes, random_bytes: bytes, now: int
|
||||
) -> tuple[bytes, Step3]:
|
||||
def _do_step3(data: Step2, response: bytes, random_bytes: bytes, now: int) -> tuple[bytes, Step3]:
|
||||
assert len(random_bytes) == 272
|
||||
|
||||
nonce = data.nonce
|
||||
|
@ -158,9 +152,7 @@ def _do_step3(
|
|||
check_server_nonce(server_dh_params.server_nonce, server_nonce)
|
||||
|
||||
if len(server_dh_params.encrypted_answer) % 16 != 0:
|
||||
raise ValueError(
|
||||
f"encrypted response not padded with size: {len(server_dh_params.encrypted_answer)}"
|
||||
)
|
||||
raise ValueError(f"encrypted response not padded with size: {len(server_dh_params.encrypted_answer)}")
|
||||
|
||||
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
|
||||
assert isinstance(server_dh_params.encrypted_answer, bytes)
|
||||
|
@ -172,9 +164,7 @@ def _do_step3(
|
|||
server_dh_inner = AbcServerDhInnerData._read_from(plain_text_reader)
|
||||
assert isinstance(server_dh_inner, ServerDhInnerData)
|
||||
|
||||
expected_answer_hash = sha1(
|
||||
plain_text_answer[20 : 20 + plain_text_reader._pos]
|
||||
).digest()
|
||||
expected_answer_hash = sha1(plain_text_answer[20 : 20 + plain_text_reader._pos]).digest()
|
||||
|
||||
if got_answer_hash != expected_answer_hash:
|
||||
raise ValueError("invalid answer hash")
|
||||
|
@ -213,15 +203,11 @@ def _do_step3(
|
|||
)
|
||||
|
||||
client_dh_inner_hashed = sha1(client_dh_inner).digest() + client_dh_inner
|
||||
client_dh_inner_hashed += random_bytes[
|
||||
: (16 - (len(client_dh_inner_hashed) % 16)) % 16
|
||||
]
|
||||
client_dh_inner_hashed += random_bytes[: (16 - (len(client_dh_inner_hashed) % 16)) % 16]
|
||||
|
||||
client_dh_encrypted = encrypt_ige(client_dh_inner_hashed, key, iv)
|
||||
|
||||
return set_client_dh_params(
|
||||
nonce=nonce, server_nonce=server_nonce, encrypted_data=client_dh_encrypted
|
||||
), Step3(
|
||||
return set_client_dh_params(nonce=nonce, server_nonce=server_nonce, encrypted_data=client_dh_encrypted), Step3(
|
||||
nonce=nonce,
|
||||
server_nonce=server_nonce,
|
||||
new_nonce=new_nonce,
|
||||
|
@ -277,10 +263,7 @@ def create_key(data: Step3, response: bytes) -> CreatedKey:
|
|||
|
||||
first_salt = struct.unpack(
|
||||
"<q",
|
||||
bytes(
|
||||
a ^ b
|
||||
for a, b in zip(new_nonce.to_bytes(32)[:8], server_nonce.to_bytes(16)[:8])
|
||||
),
|
||||
bytes(a ^ b for a, b in zip(new_nonce.to_bytes(32)[:8], server_nonce.to_bytes(16)[:8])),
|
||||
)[0]
|
||||
|
||||
if dh_gen.nonce_number == 1:
|
||||
|
@ -310,4 +293,4 @@ def check_new_nonce_hash(got: int, expected: int) -> None:
|
|||
|
||||
def check_g_in_range(value: int, low: int, high: int) -> None:
|
||||
if not (low < value < high):
|
||||
raise ValueError(f"g parameter {value} not in range({low+1}, {high})")
|
||||
raise ValueError(f"g parameter {value} not in range({low + 1}, {high})")
|
||||
|
|
|
@ -107,9 +107,7 @@ class Encrypted(Mtp):
|
|||
) -> None:
|
||||
self._auth_key = auth_key
|
||||
self._time_offset: int = time_offset or 0
|
||||
self._salts: list[FutureSalt] = [
|
||||
FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)
|
||||
]
|
||||
self._salts: list[FutureSalt] = [FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)]
|
||||
self._start_salt_time: Optional[tuple[int, float]] = None
|
||||
self._compression_threshold = compression_threshold
|
||||
self._deserialization: list[Deserialization] = []
|
||||
|
@ -203,9 +201,7 @@ class Encrypted(Mtp):
|
|||
if self._msg_count == 1:
|
||||
del self._buffer[:CONTAINER_HEADER_LEN]
|
||||
|
||||
self._buffer[:HEADER_LEN] = struct.pack(
|
||||
"<qq", self._get_current_salt(), self._client_id
|
||||
)
|
||||
self._buffer[:HEADER_LEN] = struct.pack("<qq", self._get_current_salt(), self._client_id)
|
||||
|
||||
if self._msg_count != 1:
|
||||
self._buffer[HEADER_LEN : HEADER_LEN + CONTAINER_HEADER_LEN] = struct.pack(
|
||||
|
@ -283,11 +279,7 @@ class Encrypted(Mtp):
|
|||
|
||||
if isinstance(bad_msg, BadServerSalt):
|
||||
self._salts.clear()
|
||||
self._salts.append(
|
||||
FutureSalt(
|
||||
valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt
|
||||
)
|
||||
)
|
||||
self._salts.append(FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=bad_msg.new_server_salt))
|
||||
self._salt_request_msg_id = None
|
||||
elif bad_msg.error_code in (16, 17):
|
||||
self._correct_time_offset(message.msg_id)
|
||||
|
@ -324,9 +316,7 @@ class Encrypted(Mtp):
|
|||
# Response to internal request, do not propagate.
|
||||
self._salt_request_msg_id = None
|
||||
else:
|
||||
self._deserialization.append(
|
||||
RpcResult(MsgId(salts.req_msg_id), message.body)
|
||||
)
|
||||
self._deserialization.append(RpcResult(MsgId(salts.req_msg_id), message.body))
|
||||
|
||||
self._start_salt_time = (salts.now, self._adjusted_now())
|
||||
self._salts = list(salts.salts)
|
||||
|
@ -346,11 +336,7 @@ class Encrypted(Mtp):
|
|||
def _handle_new_session_created(self, message: Message) -> None:
|
||||
new_session = NewSessionCreated.from_bytes(message.body)
|
||||
self._salts.clear()
|
||||
self._salts.append(
|
||||
FutureSalt(
|
||||
valid_since=0, valid_until=0x7FFFFFFF, salt=new_session.server_salt
|
||||
)
|
||||
)
|
||||
self._salts.append(FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=new_session.server_salt))
|
||||
|
||||
def _handle_container(self, message: Message) -> None:
|
||||
container = MsgContainer.from_bytes(message.body)
|
||||
|
@ -376,11 +362,7 @@ class Encrypted(Mtp):
|
|||
self._deserialization.append(Update(message.body))
|
||||
|
||||
def _try_request_salts(self) -> None:
|
||||
if (
|
||||
len(self._salts) == 1
|
||||
and self._salt_request_msg_id is None
|
||||
and self._get_current_salt() != 0
|
||||
):
|
||||
if len(self._salts) == 1 and self._salt_request_msg_id is None and self._get_current_salt() != 0:
|
||||
# If salts are requested in a container leading to bad_msg,
|
||||
# the bad_msg_id will refer to the container, not the salts request.
|
||||
#
|
||||
|
@ -388,9 +370,7 @@ class Encrypted(Mtp):
|
|||
# This would break, because we couldn't identify the response.
|
||||
#
|
||||
# So salts are only requested once we have a valid salt to reduce the chances of this happening.
|
||||
self._salt_request_msg_id = self._serialize_msg(
|
||||
bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True
|
||||
)
|
||||
self._salt_request_msg_id = self._serialize_msg(bytes(get_future_salts(num=NUM_FUTURE_SALTS)), True)
|
||||
|
||||
def push(self, request: bytes) -> Optional[MsgId]:
|
||||
if self._start_salt_time and len(self._salts) >= 2:
|
||||
|
@ -435,9 +415,7 @@ class Encrypted(Mtp):
|
|||
|
||||
return MsgId(self._last_msg_id), encrypt_data_v2(result, self._auth_key)
|
||||
|
||||
def deserialize(
|
||||
self, payload: bytes | bytearray | memoryview
|
||||
) -> list[Deserialization]:
|
||||
def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]:
|
||||
check_message_buffer(payload)
|
||||
|
||||
plaintext = decrypt_data_v2(payload, self._auth_key)
|
||||
|
|
|
@ -31,9 +31,7 @@ class Plain(Mtp):
|
|||
self._buffer.clear()
|
||||
return MsgId(0), result
|
||||
|
||||
def deserialize(
|
||||
self, payload: bytes | bytearray | memoryview
|
||||
) -> list[Deserialization]:
|
||||
def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]:
|
||||
check_message_buffer(payload)
|
||||
|
||||
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
|
||||
|
@ -48,8 +46,6 @@ class Plain(Mtp):
|
|||
raise ValueError(f"bad length: expected >= 0, got: {length}")
|
||||
|
||||
if 20 + length > len(payload):
|
||||
raise ValueError(
|
||||
f"message too short, expected: {20 + length}, got {len(payload)}"
|
||||
)
|
||||
raise ValueError(f"message too short, expected: {20 + length}, got {len(payload)}")
|
||||
|
||||
return [RpcResult(MsgId(0), bytes(payload[20 : 20 + length]))]
|
||||
|
|
|
@ -116,11 +116,7 @@ class RpcError(ValueError):
|
|||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return (
|
||||
self._code == other._code
|
||||
and self._name == other._name
|
||||
and self._value == other._value
|
||||
)
|
||||
return self._code == other._code and self._name == other._name and self._value == other._value
|
||||
|
||||
|
||||
# https://core.telegram.org/mtproto/service_messages_about_messages
|
||||
|
@ -156,9 +152,7 @@ class BadMessage(ValueError):
|
|||
self.msg_id = msg_id
|
||||
self._code = code
|
||||
self._caused_by = caused_by
|
||||
self.severity = (
|
||||
logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR
|
||||
)
|
||||
self.severity = logging.WARNING if self._code in NON_FATAL_MSG_IDS else logging.ERROR
|
||||
|
||||
@property
|
||||
def code(self) -> int:
|
||||
|
@ -201,9 +195,7 @@ class Mtp(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def deserialize(
|
||||
self, payload: bytes | bytearray | memoryview
|
||||
) -> list[Deserialization]:
|
||||
def deserialize(self, payload: bytes | bytearray | memoryview) -> list[Deserialization]:
|
||||
"""
|
||||
Deserialize incoming buffer payload.
|
||||
"""
|
||||
|
|
|
@ -42,10 +42,7 @@ class Intermediate(Transport):
|
|||
raise MissingBytes(expected=length, got=len(input))
|
||||
|
||||
if length <= 4:
|
||||
if (
|
||||
length >= 4
|
||||
and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0
|
||||
):
|
||||
if length >= 4 and (status := struct.unpack("<i", input[4 : 4 + length])[0]) < 0:
|
||||
raise BadStatus(status=-status)
|
||||
|
||||
raise ValueError(f"bad length, expected > 0, got: {length}")
|
||||
|
|
|
@ -12,9 +12,7 @@ MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
|
|||
|
||||
def check_message_buffer(message: bytes | bytearray | memoryview) -> None:
|
||||
if len(message) < 20:
|
||||
raise ValueError(
|
||||
f"server payload is too small to be a valid message: {message.hex()}"
|
||||
)
|
||||
raise ValueError(f"server payload is too small to be a valid message: {message.hex()}")
|
||||
|
||||
|
||||
# https://core.telegram.org/mtproto/description#content-related-message
|
||||
|
|
|
@ -313,13 +313,7 @@ class Sender:
|
|||
|
||||
def _on_ping_timeout(self) -> None:
|
||||
ping_id = generate_random_id()
|
||||
self._enqueue_body(
|
||||
bytes(
|
||||
ping_delay_disconnect(
|
||||
ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT
|
||||
)
|
||||
)
|
||||
)
|
||||
self._enqueue_body(bytes(ping_delay_disconnect(ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT)))
|
||||
self._next_ping = asyncio.get_running_loop().time() + PING_DELAY
|
||||
|
||||
def _process_mtp_buffer(self, updates: list[Updates]) -> None:
|
||||
|
@ -335,9 +329,7 @@ class Sender:
|
|||
else:
|
||||
self._process_bad_message(result)
|
||||
|
||||
def _process_update(
|
||||
self, updates: list[Updates], update: bytes | bytearray | memoryview
|
||||
) -> None:
|
||||
def _process_update(self, updates: list[Updates], update: bytes | bytearray | memoryview) -> None:
|
||||
try:
|
||||
updates.append(Updates.from_bytes(update))
|
||||
except ValueError:
|
||||
|
@ -441,9 +433,7 @@ class Sender:
|
|||
req.state.msg_id == msg_id or req.state.container_msg_id == msg_id
|
||||
):
|
||||
raise RuntimeError("got response for unsent request")
|
||||
elif isinstance(req.state, Sent) and (
|
||||
req.state.msg_id == msg_id or req.state.container_msg_id == msg_id
|
||||
):
|
||||
elif isinstance(req.state, Sent) and (req.state.msg_id == msg_id or req.state.container_msg_id == msg_id):
|
||||
yield self._requests.pop(i)
|
||||
|
||||
@property
|
||||
|
|
|
@ -65,9 +65,7 @@ class ChatHashCache:
|
|||
return self._has_peer(peer.peer)
|
||||
elif isinstance(peer, types.NotifyForumTopic):
|
||||
return self._has_peer(peer.peer)
|
||||
elif isinstance(
|
||||
peer, (types.NotifyUsers, types.NotifyChats, types.NotifyBroadcasts)
|
||||
):
|
||||
elif isinstance(peer, (types.NotifyUsers, types.NotifyChats, types.NotifyBroadcasts)):
|
||||
return True
|
||||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
@ -120,9 +118,7 @@ class ChatHashCache:
|
|||
elif isinstance(participant, types.ChannelParticipantAdmin):
|
||||
return (
|
||||
self._has(participant.user_id)
|
||||
and (
|
||||
participant.inviter_id is None or self._has(participant.inviter_id)
|
||||
)
|
||||
and (participant.inviter_id is None or self._has(participant.inviter_id))
|
||||
and self._has(participant.promoted_by)
|
||||
)
|
||||
elif isinstance(participant, types.ChannelParticipantBanned):
|
||||
|
|
|
@ -43,12 +43,8 @@ class PeerRef(abc.ABC):
|
|||
|
||||
__slots__ = ("identifier", "authorization")
|
||||
|
||||
def __init__(
|
||||
self, identifier: PeerIdentifier, authorization: PeerAuth = None
|
||||
) -> None:
|
||||
assert (
|
||||
identifier >= 0
|
||||
), "PeerRef identifiers must be positive; see the documentation for Peers"
|
||||
def __init__(self, identifier: PeerIdentifier, authorization: PeerAuth = None) -> None:
|
||||
assert identifier >= 0, "PeerRef identifiers must be positive; see the documentation for Peers"
|
||||
self.identifier = identifier
|
||||
self.authorization = authorization
|
||||
|
||||
|
@ -83,9 +79,7 @@ class PeerRef(abc.ABC):
|
|||
authorization: Optional[int] = None
|
||||
else:
|
||||
try:
|
||||
(authorization,) = struct.unpack(
|
||||
"!q", base64.urlsafe_b64decode(auth.encode("ascii") + b"=")
|
||||
)
|
||||
(authorization,) = struct.unpack("!q", base64.urlsafe_b64decode(auth.encode("ascii") + b"="))
|
||||
except Exception:
|
||||
raise ValueError(f"invalid PeerRef string: {string!r}")
|
||||
|
||||
|
@ -137,21 +131,14 @@ class PeerRef(abc.ABC):
|
|||
if self.authorization is None:
|
||||
auth = "0"
|
||||
else:
|
||||
auth = (
|
||||
base64.urlsafe_b64encode(struct.pack("!q", self.authorization))
|
||||
.decode("ascii")
|
||||
.rstrip("=")
|
||||
)
|
||||
auth = base64.urlsafe_b64encode(struct.pack("!q", self.authorization)).decode("ascii").rstrip("=")
|
||||
|
||||
return f"{self.identifier}.{auth}"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.identifier == other.identifier
|
||||
and self.authorization == other.authorization
|
||||
)
|
||||
return self.identifier == other.identifier and self.authorization == other.authorization
|
||||
|
||||
@property
|
||||
def _ref(self) -> UserRef | GroupRef | ChannelRef:
|
||||
|
@ -182,16 +169,12 @@ class UserRef(PeerRef):
|
|||
def _to_input_peer(self) -> abcs.InputPeer:
|
||||
if self.identifier == SELF_USER_SENTINEL_ID:
|
||||
return types.InputPeerSelf()
|
||||
return types.InputPeerUser(
|
||||
user_id=self.identifier, access_hash=self.authorization or 0
|
||||
)
|
||||
return types.InputPeerUser(user_id=self.identifier, access_hash=self.authorization or 0)
|
||||
|
||||
def _to_input_user(self) -> abcs.InputUser:
|
||||
if self.identifier == SELF_USER_SENTINEL_ID:
|
||||
return types.InputUserSelf()
|
||||
return types.InputUser(
|
||||
user_id=self.identifier, access_hash=self.authorization or 0
|
||||
)
|
||||
return types.InputUser(user_id=self.identifier, access_hash=self.authorization or 0)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{USER_PREFIX}{self._encode_str()}"
|
||||
|
@ -252,14 +235,10 @@ class ChannelRef(PeerRef):
|
|||
return types.PeerChannel(channel_id=self.identifier)
|
||||
|
||||
def _to_input_peer(self) -> abcs.InputPeer:
|
||||
return types.InputPeerChannel(
|
||||
channel_id=self.identifier, access_hash=self.authorization or 0
|
||||
)
|
||||
return types.InputPeerChannel(channel_id=self.identifier, access_hash=self.authorization or 0)
|
||||
|
||||
def _to_input_channel(self) -> types.InputChannel:
|
||||
return types.InputChannel(
|
||||
channel_id=self.identifier, access_hash=self.authorization or 0
|
||||
)
|
||||
return types.InputChannel(channel_id=self.identifier, access_hash=self.authorization or 0)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{CHANNEL_PREFIX}{self._encode_str()}"
|
||||
|
|
|
@ -27,9 +27,7 @@ def update_short(short: types.UpdateShort) -> types.UpdatesCombined:
|
|||
)
|
||||
|
||||
|
||||
def update_short_message(
|
||||
short: types.UpdateShortMessage, self_id: int
|
||||
) -> types.UpdatesCombined:
|
||||
def update_short_message(short: types.UpdateShortMessage, self_id: int) -> types.UpdatesCombined:
|
||||
return update_short(
|
||||
types.UpdateShort(
|
||||
update=types.UpdateNewMessage(
|
||||
|
@ -46,9 +44,7 @@ def update_short_message(
|
|||
noforwards=False,
|
||||
reactions=None,
|
||||
id=short.id,
|
||||
from_id=types.PeerUser(
|
||||
user_id=self_id if short.out else short.user_id
|
||||
),
|
||||
from_id=types.PeerUser(user_id=self_id if short.out else short.user_id),
|
||||
peer_id=types.PeerChat(
|
||||
chat_id=short.user_id,
|
||||
),
|
||||
|
|
|
@ -38,9 +38,7 @@ class PtsInfo:
|
|||
self.entry = entry
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})"
|
||||
)
|
||||
return f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})"
|
||||
|
||||
|
||||
class State:
|
||||
|
@ -70,9 +68,7 @@ class PossibleGap:
|
|||
self.updates = updates
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})"
|
||||
)
|
||||
return f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})"
|
||||
|
||||
|
||||
class PrematureEndReason(Enum):
|
||||
|
|
|
@ -91,13 +91,9 @@ class MessageBox:
|
|||
self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline)
|
||||
if state.qts != NO_SEQ:
|
||||
self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline)
|
||||
self.map.update(
|
||||
(s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels
|
||||
)
|
||||
self.map.update((s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels)
|
||||
|
||||
self.date = datetime.datetime.fromtimestamp(
|
||||
state.date, tz=datetime.timezone.utc
|
||||
)
|
||||
self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc)
|
||||
self.seq = state.seq
|
||||
self.possible_gaps.clear()
|
||||
self.getting_diff_for.clear()
|
||||
|
@ -136,28 +132,18 @@ class MessageBox:
|
|||
default_deadline = next_updates_deadline()
|
||||
|
||||
if self.possible_gaps:
|
||||
deadline = min(
|
||||
default_deadline, *(gap.deadline for gap in self.possible_gaps.values())
|
||||
)
|
||||
deadline = min(default_deadline, *(gap.deadline for gap in self.possible_gaps.values()))
|
||||
elif self.next_deadline in self.map:
|
||||
deadline = min(default_deadline, self.map[self.next_deadline].deadline)
|
||||
else:
|
||||
deadline = default_deadline
|
||||
|
||||
if now >= deadline:
|
||||
self.getting_diff_for.update(
|
||||
entry
|
||||
for entry, gap in self.possible_gaps.items()
|
||||
if now >= gap.deadline
|
||||
)
|
||||
self.getting_diff_for.update(
|
||||
entry for entry, state in self.map.items() if now >= state.deadline
|
||||
)
|
||||
self.getting_diff_for.update(entry for entry, gap in self.possible_gaps.items() if now >= gap.deadline)
|
||||
self.getting_diff_for.update(entry for entry, state in self.map.items() if now >= state.deadline)
|
||||
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"deadlines met, now getting diff for: %r", self.getting_diff_for
|
||||
)
|
||||
self._trace("deadlines met, now getting diff for: %r", self.getting_diff_for)
|
||||
|
||||
for entry in self.getting_diff_for:
|
||||
self.possible_gaps.pop(entry, None)
|
||||
|
@ -171,19 +157,12 @@ class MessageBox:
|
|||
entry: Entry = ENTRY_ACCOUNT # for pyright to know it's not unbound
|
||||
for entry in entries:
|
||||
if entry not in self.map:
|
||||
raise RuntimeError(
|
||||
"Called reset_deadline on an entry for which we do not have state"
|
||||
)
|
||||
raise RuntimeError("Called reset_deadline on an entry for which we do not have state")
|
||||
self.map[entry].deadline = deadline
|
||||
|
||||
if self.next_deadline in entries:
|
||||
self.next_deadline = min(
|
||||
self.map.items(), key=lambda entry_state: entry_state[1].deadline
|
||||
)[0]
|
||||
elif (
|
||||
self.next_deadline in self.map
|
||||
and deadline < self.map[self.next_deadline].deadline
|
||||
):
|
||||
self.next_deadline = min(self.map.items(), key=lambda entry_state: entry_state[1].deadline)[0]
|
||||
elif self.next_deadline in self.map and deadline < self.map[self.next_deadline].deadline:
|
||||
self.next_deadline = entry
|
||||
|
||||
def reset_channel_deadline(self, channel_id: int, timeout: Optional[float]) -> None:
|
||||
|
@ -200,9 +179,7 @@ class MessageBox:
|
|||
assert isinstance(state, types.updates.State)
|
||||
self.map[ENTRY_ACCOUNT] = State(state.pts, deadline)
|
||||
self.map[ENTRY_SECRET] = State(state.qts, deadline)
|
||||
self.date = datetime.datetime.fromtimestamp(
|
||||
state.date, tz=datetime.timezone.utc
|
||||
)
|
||||
self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc)
|
||||
self.seq = state.seq
|
||||
|
||||
def try_set_channel_state(self, id: int, pts: int) -> None:
|
||||
|
@ -215,15 +192,11 @@ class MessageBox:
|
|||
def try_begin_get_diff(self, entry: Entry, reason: str) -> None:
|
||||
if entry not in self.map:
|
||||
if entry in self.possible_gaps:
|
||||
raise RuntimeError(
|
||||
"Should not have a possible_gap for an entry not in the state map"
|
||||
)
|
||||
raise RuntimeError("Should not have a possible_gap for an entry not in the state map")
|
||||
return
|
||||
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"marking entry=%r as needing difference because: %s", entry, reason
|
||||
)
|
||||
self._trace("marking entry=%r as needing difference because: %s", entry, reason)
|
||||
self.getting_diff_for.add(entry)
|
||||
self.possible_gaps.pop(entry, None)
|
||||
|
||||
|
@ -231,14 +204,10 @@ class MessageBox:
|
|||
try:
|
||||
self.getting_diff_for.remove(entry)
|
||||
except KeyError:
|
||||
raise RuntimeError(
|
||||
"Called end_get_diff on an entry which was not getting diff for"
|
||||
)
|
||||
raise RuntimeError("Called end_get_diff on an entry which was not getting diff for")
|
||||
|
||||
self.reset_deadlines({entry}, next_updates_deadline())
|
||||
assert (
|
||||
entry not in self.possible_gaps
|
||||
), "gaps shouldn't be created while getting difference"
|
||||
assert entry not in self.possible_gaps, "gaps shouldn't be created while getting difference"
|
||||
|
||||
def ensure_known_peer_hashes(
|
||||
self,
|
||||
|
@ -246,10 +215,7 @@ class MessageBox:
|
|||
chat_hashes: ChatHashCache,
|
||||
) -> None:
|
||||
if not chat_hashes.extend_from_updates(updates):
|
||||
can_recover = (
|
||||
not isinstance(updates, types.UpdateShort)
|
||||
or pts_info_from_update(updates.update) is not None
|
||||
)
|
||||
can_recover = not isinstance(updates, types.UpdateShort) or pts_info_from_update(updates.update) is not None
|
||||
if can_recover:
|
||||
self.try_begin_get_diff(ENTRY_ACCOUNT, "missing hash")
|
||||
raise Gap
|
||||
|
@ -275,9 +241,7 @@ class MessageBox:
|
|||
if combined.seq_start != NO_SEQ:
|
||||
if self.seq + 1 > combined.seq_start:
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"skipping updates as they should have already been handled"
|
||||
)
|
||||
self._trace("skipping updates as they should have already been handled")
|
||||
return result, combined.users, combined.chats
|
||||
elif self.seq + 1 < combined.seq_start:
|
||||
self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap")
|
||||
|
@ -305,17 +269,13 @@ class MessageBox:
|
|||
if __debug__:
|
||||
self._trace("updating seq as local pts was updated too")
|
||||
if combined.date != NO_DATE:
|
||||
self.date = datetime.datetime.fromtimestamp(
|
||||
combined.date, tz=datetime.timezone.utc
|
||||
)
|
||||
self.date = datetime.datetime.fromtimestamp(combined.date, tz=datetime.timezone.utc)
|
||||
if combined.seq != NO_SEQ:
|
||||
self.seq = combined.seq
|
||||
|
||||
if self.possible_gaps:
|
||||
if __debug__:
|
||||
self._trace(
|
||||
"trying to re-apply count=%r possible gaps", len(self.possible_gaps)
|
||||
)
|
||||
self._trace("trying to re-apply count=%r possible gaps", len(self.possible_gaps))
|
||||
|
||||
for key in list(self.possible_gaps.keys()):
|
||||
self.possible_gaps[key].updates.sort(key=update_sort_key)
|
||||
|
@ -332,9 +292,7 @@ class MessageBox:
|
|||
applied,
|
||||
)
|
||||
|
||||
self.possible_gaps = {
|
||||
entry: gap for entry, gap in self.possible_gaps.items() if gap.updates
|
||||
}
|
||||
self.possible_gaps = {entry: gap for entry, gap in self.possible_gaps.items() if gap.updates}
|
||||
|
||||
return result, combined.users, combined.chats
|
||||
|
||||
|
@ -384,8 +342,7 @@ class MessageBox:
|
|||
)
|
||||
if pts.entry not in self.possible_gaps:
|
||||
self.possible_gaps[pts.entry] = PossibleGap(
|
||||
deadline=asyncio.get_running_loop().time()
|
||||
+ POSSIBLE_GAP_TIMEOUT,
|
||||
deadline=asyncio.get_running_loop().time() + POSSIBLE_GAP_TIMEOUT,
|
||||
updates=[],
|
||||
)
|
||||
|
||||
|
@ -413,20 +370,14 @@ class MessageBox:
|
|||
for entry in (ENTRY_ACCOUNT, ENTRY_SECRET):
|
||||
if entry in self.getting_diff_for:
|
||||
if entry not in self.map:
|
||||
raise RuntimeError(
|
||||
"Should not try to get difference for an entry without known state"
|
||||
)
|
||||
raise RuntimeError("Should not try to get difference for an entry without known state")
|
||||
|
||||
gd = functions.updates.get_difference(
|
||||
pts=self.map[ENTRY_ACCOUNT].pts,
|
||||
pts_limit=None,
|
||||
pts_total_limit=None,
|
||||
date=int(self.date.timestamp()),
|
||||
qts=(
|
||||
self.map[ENTRY_SECRET].pts
|
||||
if ENTRY_SECRET in self.map
|
||||
else NO_SEQ
|
||||
),
|
||||
qts=(self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else NO_SEQ),
|
||||
qts_limit=None,
|
||||
)
|
||||
if __debug__:
|
||||
|
@ -447,9 +398,7 @@ class MessageBox:
|
|||
result: tuple[list[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]
|
||||
if isinstance(diff, types.updates.DifferenceEmpty):
|
||||
finish = True
|
||||
self.date = datetime.datetime.fromtimestamp(
|
||||
diff.date, tz=datetime.timezone.utc
|
||||
)
|
||||
self.date = datetime.datetime.fromtimestamp(diff.date, tz=datetime.timezone.utc)
|
||||
self.seq = diff.seq
|
||||
result = [], [], []
|
||||
elif isinstance(diff, types.updates.Difference):
|
||||
|
@ -502,9 +451,7 @@ class MessageBox:
|
|||
assert isinstance(state, types.updates.State)
|
||||
self.map[ENTRY_ACCOUNT].pts = state.pts
|
||||
self.map[ENTRY_SECRET].pts = state.qts
|
||||
self.date = datetime.datetime.fromtimestamp(
|
||||
state.date, tz=datetime.timezone.utc
|
||||
)
|
||||
self.date = datetime.datetime.fromtimestamp(state.date, tz=datetime.timezone.utc)
|
||||
self.seq = state.seq
|
||||
|
||||
updates, users, chats = self.process_updates(
|
||||
|
@ -560,19 +507,13 @@ class MessageBox:
|
|||
channel=channel,
|
||||
filter=types.ChannelMessagesFilterEmpty(),
|
||||
pts=state.pts,
|
||||
limit=(
|
||||
BOT_CHANNEL_DIFF_LIMIT
|
||||
if chat_hashes.is_self_bot
|
||||
else USER_CHANNEL_DIFF_LIMIT
|
||||
),
|
||||
limit=(BOT_CHANNEL_DIFF_LIMIT if chat_hashes.is_self_bot else USER_CHANNEL_DIFF_LIMIT),
|
||||
)
|
||||
if __debug__:
|
||||
self._trace("requesting channel difference: %s", gd)
|
||||
return gd
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"should not try to get difference for an entry without known state"
|
||||
)
|
||||
raise RuntimeError("should not try to get difference for an entry without known state")
|
||||
else:
|
||||
self.end_get_diff(entry)
|
||||
self.map.pop(entry, None)
|
||||
|
@ -638,9 +579,7 @@ class MessageBox:
|
|||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
||||
def end_channel_difference(
|
||||
self, channel_id: int, reason: PrematureEndReason
|
||||
) -> None:
|
||||
def end_channel_difference(self, channel_id: int, reason: PrematureEndReason) -> None:
|
||||
entry: Entry = channel_id
|
||||
if __debug__:
|
||||
self._trace("ending channel=%r difference: %s", entry, reason)
|
||||
|
|
|
@ -38,9 +38,7 @@ class SqliteSession(Storage):
|
|||
if version == 7:
|
||||
session = self._load_v7(c)
|
||||
else:
|
||||
raise ValueError(
|
||||
"only migration from sqlite session format 7 supported"
|
||||
)
|
||||
raise ValueError("only migration from sqlite session format 7 supported")
|
||||
|
||||
self._reset(c)
|
||||
self._get_or_init_version(c)
|
||||
|
@ -105,11 +103,7 @@ class SqliteSession(Storage):
|
|||
DataCenter(id=id, ipv4_addr=ipv4_addr, ipv6_addr=ipv6_addr, auth=auth)
|
||||
for (id, ipv4_addr, ipv6_addr, auth) in datacenter
|
||||
],
|
||||
user=(
|
||||
User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3])
|
||||
if user
|
||||
else None
|
||||
),
|
||||
user=(User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3]) if user else None),
|
||||
state=(
|
||||
UpdateState(
|
||||
pts=state[0],
|
||||
|
@ -166,9 +160,7 @@ class SqliteSession(Storage):
|
|||
|
||||
@staticmethod
|
||||
def _get_or_init_version(c: sqlite3.Cursor) -> int:
|
||||
c.execute(
|
||||
"select name from sqlite_master where type='table' and name='version'"
|
||||
)
|
||||
c.execute("select name from sqlite_master where type='table' and name='version'")
|
||||
if c.fetchone():
|
||||
c.execute("select version from version")
|
||||
tup = c.fetchone()
|
||||
|
|
|
@ -23,9 +23,7 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
|
|||
from ..mtproto.layer import TYPE_MAPPING as MTPROTO_TYPES
|
||||
|
||||
if API_TYPES.keys() & MTPROTO_TYPES.keys():
|
||||
raise RuntimeError(
|
||||
"generated api and mtproto schemas cannot have colliding constructor identifiers"
|
||||
)
|
||||
raise RuntimeError("generated api and mtproto schemas cannot have colliding constructor identifiers")
|
||||
ALL_TYPES = API_TYPES | MTPROTO_TYPES
|
||||
|
||||
# Signatures don't fully match, but this is a private method
|
||||
|
@ -39,9 +37,7 @@ class Reader:
|
|||
__slots__ = ("_view", "_pos", "_len")
|
||||
|
||||
def __init__(self, buffer: "Buffer") -> None:
|
||||
self._view = (
|
||||
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
|
||||
)
|
||||
self._view = memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
|
||||
self._pos = 0
|
||||
self._len = len(self._view)
|
||||
|
||||
|
|
|
@ -14,9 +14,7 @@ def _bootstrap_get_deserializer(
|
|||
from ..mtproto.layer import RESPONSE_MAPPING as MTPROTO_DESER
|
||||
|
||||
if API_DESER.keys() & MTPROTO_DESER.keys():
|
||||
raise RuntimeError(
|
||||
"generated api and mtproto schemas cannot have colliding constructor identifiers"
|
||||
)
|
||||
raise RuntimeError("generated api and mtproto schemas cannot have colliding constructor identifiers")
|
||||
ALL_DESER = API_DESER | MTPROTO_DESER
|
||||
|
||||
Request._get_deserializer = ALL_DESER.get # type: ignore [assignment]
|
||||
|
|
|
@ -49,9 +49,7 @@ class Serializable(abc.ABC):
|
|||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return all(
|
||||
getattr(self, attr) == getattr(other, attr) for attr in self.__slots__
|
||||
)
|
||||
return all(getattr(self, attr) == getattr(other, attr) for attr in self.__slots__)
|
||||
|
||||
|
||||
def serialize_bytes_to(buffer: bytearray, data: bytes | bytearray | memoryview) -> None:
|
||||
|
|
|
@ -19,6 +19,7 @@ and those you can define when using :meth:`telethon.Client.send_message`:
|
|||
buttons.Callback('Demo', b'data')
|
||||
])
|
||||
"""
|
||||
|
||||
from .._impl.client.types.buttons import (
|
||||
Callback,
|
||||
RequestGeoLocation,
|
||||
|
|
|
@ -26,25 +26,16 @@ def test_auth_key_id() -> None:
|
|||
def test_calc_new_nonce_hash1() -> None:
|
||||
auth_key = get_auth_key()
|
||||
new_nonce = get_new_nonce()
|
||||
assert (
|
||||
auth_key.calc_new_nonce_hash(new_nonce, 1)
|
||||
== 258944117842285651226187582903746985063
|
||||
)
|
||||
assert auth_key.calc_new_nonce_hash(new_nonce, 1) == 258944117842285651226187582903746985063
|
||||
|
||||
|
||||
def test_calc_new_nonce_hash2() -> None:
|
||||
auth_key = get_auth_key()
|
||||
new_nonce = get_new_nonce()
|
||||
assert (
|
||||
auth_key.calc_new_nonce_hash(new_nonce, 2)
|
||||
== 324588944215647649895949797213421233055
|
||||
)
|
||||
assert auth_key.calc_new_nonce_hash(new_nonce, 2) == 324588944215647649895949797213421233055
|
||||
|
||||
|
||||
def test_calc_new_nonce_hash3() -> None:
|
||||
auth_key = get_auth_key()
|
||||
new_nonce = get_new_nonce()
|
||||
assert (
|
||||
auth_key.calc_new_nonce_hash(new_nonce, 3)
|
||||
== 100989356540453064705070297823778556733
|
||||
)
|
||||
assert auth_key.calc_new_nonce_hash(new_nonce, 3) == 100989356540453064705070297823778556733
|
||||
|
|
|
@ -66,14 +66,8 @@ def test_key_from_nonce() -> None:
|
|||
new_nonce = int.from_bytes(bytes(range(32)))
|
||||
|
||||
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
|
||||
assert (
|
||||
key
|
||||
== b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6'
|
||||
)
|
||||
assert (
|
||||
iv
|
||||
== b"Z\x84\x10\x8e\x98\x05el\xe8d\x07\x0e\x16nb\x18\xf6x>\x85\x11G\x1aZ\xb7\x80,\xf2\x00\x01\x02\x03"
|
||||
)
|
||||
assert key == b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6'
|
||||
assert iv == b"Z\x84\x10\x8e\x98\x05el\xe8d\x07\x0e\x16nb\x18\xf6x>\x85\x11G\x1aZ\xb7\x80,\xf2\x00\x01\x02\x03"
|
||||
|
||||
|
||||
def test_verify_ige_encryption() -> None:
|
||||
|
@ -88,5 +82,7 @@ def test_verify_ige_decryption() -> None:
|
|||
ciphertext = get_test_aes_key_or_iv()
|
||||
key = get_test_aes_key_or_iv()
|
||||
iv = get_test_aes_key_or_iv()
|
||||
expected = b"\xe5wz\xfa\xcd{,\x16\xf7\xac@\xca\xe6\x1e\xf6\x03\xfe\xe6\t\x8f\xb8\xa8\x86\n\xb9\xeeg,\xd7\xe5\xba\xcc"
|
||||
expected = (
|
||||
b"\xe5wz\xfa\xcd{,\x16\xf7\xac@\xca\xe6\x1e\xf6\x03\xfe\xe6\t\x8f\xb8\xa8\x86\n\xb9\xeeg,\xd7\xe5\xba\xcc"
|
||||
)
|
||||
assert decrypt_ige(ciphertext, key, iv) == expected
|
||||
|
|
|
@ -32,10 +32,7 @@ def test_parse_all_entities_markdown() -> None:
|
|||
markdown = "Some **bold** (__strong__), *italics* (_cursive_), inline `code`, a\n```rust\npre\n```\nblock, a [link](https://example.com), and [mentions](tg://user?id=12345678)"
|
||||
text, entities = parse_markdown_message(markdown)
|
||||
|
||||
assert (
|
||||
text
|
||||
== "Some bold (strong), italics (cursive), inline code, a\npre\nblock, a link, and mentions"
|
||||
)
|
||||
assert text == "Some bold (strong), italics (cursive), inline code, a\npre\nblock, a link, and mentions"
|
||||
assert entities == [
|
||||
types.MessageEntityBold(offset=5, length=4),
|
||||
types.MessageEntityBold(offset=11, length=6),
|
||||
|
@ -92,10 +89,7 @@ def test_parse_emoji_html() -> None:
|
|||
def test_parse_all_entities_html() -> None:
|
||||
html = 'Some <b>bold</b> (<strong>strong</strong>), <i>italics</i> (<em>cursive</em>), inline <code>code</code>, a <pre>pre</pre> block, a <a href="https://example.com">link</a>, <details>spoilers</details> and <a href="tg://user?id=12345678">mentions</a>'
|
||||
text, entities = parse_html_message(html)
|
||||
assert (
|
||||
text
|
||||
== "Some bold (strong), italics (cursive), inline code, a pre block, a link, spoilers and mentions"
|
||||
)
|
||||
assert text == "Some bold (strong), italics (cursive), inline code, a pre block, a link, spoilers and mentions"
|
||||
assert entities == [
|
||||
types.MessageEntityBold(offset=5, length=4),
|
||||
types.MessageEntityBold(offset=11, length=6),
|
||||
|
|
|
@ -38,6 +38,10 @@ build-backend = "setuptools.build_meta"
|
|||
version = {attr = "telethon_generator.version.__version__"}
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["F", "E", "W", "I"]
|
||||
ignore = [
|
||||
"E501", # formatter takes care of lines that are too long besides documentation
|
||||
]
|
||||
|
|
|
@ -17,9 +17,7 @@ from .serde.deserialization import (
|
|||
from .serde.serialization import generate_function, generate_write
|
||||
|
||||
|
||||
def generate_init(
|
||||
writer: SourceWriter, namespaces: set[str], classes: set[str]
|
||||
) -> None:
|
||||
def generate_init(writer: SourceWriter, namespaces: set[str], classes: set[str]) -> None:
|
||||
sorted_cls = list(sorted(classes))
|
||||
sorted_ns = list(sorted(namespaces))
|
||||
|
||||
|
@ -93,9 +91,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
writer = fs.open(type_path)
|
||||
|
||||
if type_path not in fs:
|
||||
writer.write(
|
||||
"# pyright: reportUnusedImport=false, reportConstantRedefinition=false"
|
||||
)
|
||||
writer.write("# pyright: reportUnusedImport=false, reportConstantRedefinition=false")
|
||||
writer.write("import struct")
|
||||
writer.write("from typing import Optional, Self, Sequence")
|
||||
writer.write("from .. import abcs")
|
||||
|
@ -106,9 +102,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
generated_type_names.add(f"{ns}{to_class_name(typedef.name)}")
|
||||
|
||||
# class Type(BaseType)
|
||||
writer.write(
|
||||
f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):"
|
||||
)
|
||||
writer.write(f"class {to_class_name(typedef.name)}({inner_type_fmt(typedef.ty)}):")
|
||||
|
||||
# __slots__ = ('params', ...)
|
||||
slots = " ".join(f"'{p.name}'," for p in property_params)
|
||||
|
@ -121,9 +115,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
|
||||
# def __init__()
|
||||
if property_params:
|
||||
params = "".join(
|
||||
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
|
||||
)
|
||||
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params)
|
||||
writer.write(f" def __init__(_s, *{params}) -> None:")
|
||||
for p in property_params:
|
||||
writer.write(f" _s.{p.name} = {p.name}")
|
||||
|
@ -151,9 +143,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
raise ValueError("nested function-namespaces are not supported")
|
||||
elif len(functiondef.namespace) == 1:
|
||||
function_namespaces.add(functiondef.namespace[0])
|
||||
function_path = (Path("functions") / functiondef.namespace[0]).with_suffix(
|
||||
".py"
|
||||
)
|
||||
function_path = (Path("functions") / functiondef.namespace[0]).with_suffix(".py")
|
||||
else:
|
||||
function_def_names.add(to_method_name(functiondef.name))
|
||||
function_path = Path("functions/_nons.py")
|
||||
|
@ -173,18 +163,14 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params)
|
||||
star = "*" if params else ""
|
||||
return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None))
|
||||
writer.write(
|
||||
f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:"
|
||||
)
|
||||
writer.write(f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:")
|
||||
writer.indent(2)
|
||||
generate_function(writer, functiondef)
|
||||
writer.dedent(2)
|
||||
|
||||
generate_init(fs.open(Path("abcs/__init__.py")), abc_namespaces, abc_class_names)
|
||||
generate_init(fs.open(Path("types/__init__.py")), type_namespaces, type_class_names)
|
||||
generate_init(
|
||||
fs.open(Path("functions/__init__.py")), function_namespaces, function_def_names
|
||||
)
|
||||
generate_init(fs.open(Path("functions/__init__.py")), function_namespaces, function_def_names)
|
||||
|
||||
writer = fs.open(Path("layer.py"))
|
||||
writer.write("# pyright: reportUnusedImport=false")
|
||||
|
@ -194,16 +180,12 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
)
|
||||
writer.write("from typing import cast, Type")
|
||||
writer.write(f"LAYER = {tl.layer!r}")
|
||||
writer.write(
|
||||
"TYPE_MAPPING = {t.constructor_id(): t for t in cast(tuple[Type[Serializable]], ("
|
||||
)
|
||||
writer.write("TYPE_MAPPING = {t.constructor_id(): t for t in cast(tuple[Type[Serializable]], (")
|
||||
for name in sorted(generated_type_names):
|
||||
writer.write(f" types.{name},")
|
||||
writer.write("))}")
|
||||
writer.write("RESPONSE_MAPPING = {")
|
||||
for functiondef in tl.functiondefs:
|
||||
writer.write(
|
||||
f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)},"
|
||||
)
|
||||
writer.write(f" {hex(functiondef.id)}: {function_deserializer_fmt(functiondef)},")
|
||||
writer.write("}")
|
||||
writer.write("__all__ = ['LAYER', 'TYPE_MAPPING', 'RESPONSE_MAPPING']")
|
||||
|
|
|
@ -50,9 +50,7 @@ _TRIVIAL_STRUCT_MAP = {"int": "i", "long": "q", "double": "d", "Bool": "I"}
|
|||
|
||||
def trivial_struct_fmt(ty: BaseParameter) -> str:
|
||||
try:
|
||||
return (
|
||||
_TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I"
|
||||
)
|
||||
return _TRIVIAL_STRUCT_MAP[ty.ty.name] if isinstance(ty, NormalParameter) else "I"
|
||||
except KeyError:
|
||||
raise ValueError("input param was not trivial")
|
||||
|
||||
|
|
|
@ -38,9 +38,7 @@ def reader_read_fmt(ty: Type, constructor_id: int) -> tuple[str, Optional[str]]:
|
|||
return f"reader.read_serializable({inner_type_fmt(ty)})", "type-abstract"
|
||||
|
||||
|
||||
def generate_normal_param_read(
|
||||
writer: SourceWriter, name: str, param: NormalParameter, constructor_id: int
|
||||
) -> None:
|
||||
def generate_normal_param_read(writer: SourceWriter, name: str, param: NormalParameter, constructor_id: int) -> None:
|
||||
flag_check = f"_{param.flag.name} & {1 << param.flag.index}" if param.flag else None
|
||||
if param.ty.name == "true":
|
||||
if not flag_check:
|
||||
|
@ -55,9 +53,7 @@ def generate_normal_param_read(
|
|||
|
||||
if param.ty.generic_arg:
|
||||
if param.ty.name not in ("Vector", "vector"):
|
||||
raise ValueError(
|
||||
"generic_arg deserialization for non-vectors is not supported"
|
||||
)
|
||||
raise ValueError("generic_arg deserialization for non-vectors is not supported")
|
||||
|
||||
if param.ty.bare:
|
||||
writer.write("__len = reader.read_fmt('<i', 4)[0]")
|
||||
|
@ -70,18 +66,12 @@ def generate_normal_param_read(
|
|||
if is_trivial(generic):
|
||||
fmt = trivial_struct_fmt(generic)
|
||||
size = struct.calcsize(f"<{fmt}")
|
||||
writer.write(
|
||||
f"_{name} = [*reader.read_fmt(f'<{{__len}}{fmt}', __len * {size})]"
|
||||
)
|
||||
writer.write(f"_{name} = [*reader.read_fmt(f'<{{__len}}{fmt}', __len * {size})]")
|
||||
if param.ty.generic_arg.name == "Bool":
|
||||
writer.write(
|
||||
f"assert all(__x in (0xbc799737, 0x0x997275b5) for __x in _{name})"
|
||||
)
|
||||
writer.write(f"assert all(__x in (0xbc799737, 0x0x997275b5) for __x in _{name})")
|
||||
writer.write(f"_{name} = [_{name} == 0x997275b5]")
|
||||
else:
|
||||
fmt_read, type_ignore = reader_read_fmt(
|
||||
param.ty.generic_arg, constructor_id
|
||||
)
|
||||
fmt_read, type_ignore = reader_read_fmt(param.ty.generic_arg, constructor_id)
|
||||
comment = f" # type: ignore [{type_ignore}]" if type_ignore else ""
|
||||
writer.write(f"_{name} = [{fmt_read} for _ in range(__len)]{comment}")
|
||||
else:
|
||||
|
@ -127,9 +117,7 @@ def param_value_fmt(param: Parameter) -> str:
|
|||
def function_deserializer_fmt(defn: Definition) -> str:
|
||||
if defn.ty.generic_arg:
|
||||
if defn.ty.name != ("Vector"):
|
||||
raise ValueError(
|
||||
"generic_arg return for non-boxed-vectors is not supported"
|
||||
)
|
||||
raise ValueError("generic_arg return for non-boxed-vectors is not supported")
|
||||
elif defn.ty.generic_ref:
|
||||
raise ValueError("return for generic refs inside vector is not supported")
|
||||
elif is_trivial(NormalParameter(ty=defn.ty.generic_arg, flag=None)):
|
||||
|
@ -138,13 +126,9 @@ def function_deserializer_fmt(defn: Definition) -> str:
|
|||
elif defn.ty.generic_arg.name == "long":
|
||||
return "deserialize_i64_list"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"return for trivial arg {defn.ty.generic_arg} is not supported"
|
||||
)
|
||||
raise ValueError(f"return for trivial arg {defn.ty.generic_arg} is not supported")
|
||||
elif defn.ty.generic_arg.bare:
|
||||
raise ValueError(
|
||||
"return for non-boxed serializables inside a vector is not supported"
|
||||
)
|
||||
raise ValueError("return for non-boxed serializables inside a vector is not supported")
|
||||
else:
|
||||
return f"list_deserializer({inner_type_fmt(defn.ty.generic_arg)})"
|
||||
elif defn.ty.generic_ref:
|
||||
|
|
|
@ -15,15 +15,11 @@ def param_value_expr(param: Parameter) -> str:
|
|||
return f"{pre}{mid}{suf}"
|
||||
|
||||
|
||||
def generate_buffer_append(
|
||||
writer: SourceWriter, buffer: str, name: str, ty: Type
|
||||
) -> None:
|
||||
def generate_buffer_append(writer: SourceWriter, buffer: str, name: str, ty: Type) -> None:
|
||||
if is_trivial(NormalParameter(ty=ty, flag=None)):
|
||||
fmt = trivial_struct_fmt(NormalParameter(ty=ty, flag=None))
|
||||
if ty.name == "Bool":
|
||||
writer.write(
|
||||
f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))"
|
||||
)
|
||||
writer.write(f"{buffer} += struct.pack(f'<{fmt}', (0x997275b5 if {name} else 0xbc799737))")
|
||||
else:
|
||||
writer.write(f"{buffer} += struct.pack(f'<{fmt}', {name})")
|
||||
elif ty.generic_ref or ty.name == "Object":
|
||||
|
@ -58,9 +54,7 @@ def generate_normal_param_write(
|
|||
|
||||
if param.ty.generic_arg:
|
||||
if param.ty.name not in ("Vector", "vector"):
|
||||
raise ValueError(
|
||||
"generic_arg deserialization for non-vectors is not supported"
|
||||
)
|
||||
raise ValueError("generic_arg deserialization for non-vectors is not supported")
|
||||
|
||||
if param.ty.bare:
|
||||
writer.write(f"{buffer} += struct.pack('<i', len({name}))")
|
||||
|
@ -76,9 +70,7 @@ def generate_normal_param_write(
|
|||
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *(0x997275b5 if {tmp} else 0xbc799737 for {tmp} in {name}))"
|
||||
)
|
||||
else:
|
||||
writer.write(
|
||||
f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *{name})"
|
||||
)
|
||||
writer.write(f"{buffer} += struct.pack(f'<{{len({name})}}{fmt}', *{name})")
|
||||
else:
|
||||
tmp = next(tmp_names)
|
||||
writer.write(f"for {tmp} in {name}:")
|
||||
|
@ -110,9 +102,7 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None:
|
|||
else f"(0 if self.{p.name} is None else {1 << p.ty.flag.index})"
|
||||
)
|
||||
for p in defn.params
|
||||
if isinstance(p.ty, NormalParameter)
|
||||
and p.ty.flag
|
||||
and p.ty.flag.name == param.name
|
||||
if isinstance(p.ty, NormalParameter) and p.ty.flag and p.ty.flag.name == param.name
|
||||
)
|
||||
writer.write(f"_{param.name} = {flags or 0}")
|
||||
|
||||
|
@ -123,9 +113,7 @@ def generate_write(writer: SourceWriter, defn: Definition) -> None:
|
|||
for param in iter:
|
||||
if not isinstance(param.ty, NormalParameter):
|
||||
raise RuntimeError("FlagsParameter should be considered trivial")
|
||||
generate_normal_param_write(
|
||||
writer, tmp_names, "buffer", f"self.{param.name}", param.ty
|
||||
)
|
||||
generate_normal_param_write(writer, tmp_names, "buffer", f"self.{param.name}", param.ty)
|
||||
|
||||
|
||||
def generate_function(writer: SourceWriter, defn: Definition) -> None:
|
||||
|
@ -148,9 +136,7 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None:
|
|||
else f"(0 if {p.name} is None else {1 << p.ty.flag.index})"
|
||||
)
|
||||
for p in defn.params
|
||||
if isinstance(p.ty, NormalParameter)
|
||||
and p.ty.flag
|
||||
and p.ty.flag.name == param.name
|
||||
if isinstance(p.ty, NormalParameter) and p.ty.flag and p.ty.flag.name == param.name
|
||||
)
|
||||
writer.write(f"{param.name} = {flags or 0}")
|
||||
|
||||
|
@ -161,7 +147,5 @@ def generate_function(writer: SourceWriter, defn: Definition) -> None:
|
|||
for param in iter:
|
||||
if not isinstance(param.ty, NormalParameter):
|
||||
raise RuntimeError("FlagsParameter should be considered trivial")
|
||||
generate_normal_param_write(
|
||||
writer, tmp_names, "_buffer", param.name, param.ty
|
||||
)
|
||||
generate_normal_param_write(writer, tmp_names, "_buffer", param.name, param.ty)
|
||||
writer.write("return Request(b'' + _buffer)")
|
||||
|
|
|
@ -35,6 +35,4 @@ def load_tl_file(path: str | Path) -> ParsedTl:
|
|||
else:
|
||||
functiondefs.append(definition)
|
||||
|
||||
return ParsedTl(
|
||||
layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)
|
||||
)
|
||||
return ParsedTl(layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs))
|
||||
|
|
|
@ -14,9 +14,7 @@ def gen_py_code(
|
|||
functiondefs: Optional[list[Definition]] = None,
|
||||
) -> str:
|
||||
fs = FakeFs()
|
||||
generate(
|
||||
fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or [])
|
||||
)
|
||||
generate(fs, ParsedTl(layer=0, typedefs=typedefs or [], functiondefs=functiondefs or []))
|
||||
generated = bytearray()
|
||||
for path, data in fs._files.items():
|
||||
if path.stem not in ("__init__", "layer"):
|
||||
|
@ -27,9 +25,7 @@ def gen_py_code(
|
|||
|
||||
|
||||
def test_generic_functions_use_bytes_parameters() -> None:
|
||||
definitions = get_definitions(
|
||||
"invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;"
|
||||
)
|
||||
definitions = get_definitions("invokeWithLayer#da9b0d0d {X:Type} layer:int query:!X = X;")
|
||||
result = gen_py_code(functiondefs=definitions)
|
||||
assert "invoke_with_layer" in result
|
||||
assert "query: _bytes" in result
|
||||
|
|
|
@ -54,18 +54,14 @@ def test_valid_param() -> None:
|
|||
assert Parameter.from_str("foo:!bar") == Parameter(
|
||||
name="foo",
|
||||
ty=NormalParameter(
|
||||
ty=Type(
|
||||
namespace=[], name="bar", bare=True, generic_ref=True, generic_arg=None
|
||||
),
|
||||
ty=Type(namespace=[], name="bar", bare=True, generic_ref=True, generic_arg=None),
|
||||
flag=None,
|
||||
),
|
||||
)
|
||||
assert Parameter.from_str("foo:bar.1?baz") == Parameter(
|
||||
name="foo",
|
||||
ty=NormalParameter(
|
||||
ty=Type(
|
||||
namespace=[], name="baz", bare=True, generic_ref=False, generic_arg=None
|
||||
),
|
||||
ty=Type(namespace=[], name="baz", bare=True, generic_ref=False, generic_arg=None),
|
||||
flag=Flag(
|
||||
name="bar",
|
||||
index=1,
|
||||
|
|
|
@ -11,9 +11,7 @@ def test_empty_simple() -> None:
|
|||
|
||||
|
||||
def test_simple() -> None:
|
||||
assert Type.from_str("foo") == Type(
|
||||
namespace=[], name="foo", bare=True, generic_ref=False, generic_arg=None
|
||||
)
|
||||
assert Type.from_str("foo") == Type(namespace=[], name="foo", bare=True, generic_ref=False, generic_arg=None)
|
||||
|
||||
|
||||
@mark.parametrize("ty", [".", "..", ".foo", "foo.", "foo..foo", ".foo."])
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Check formatting, type-check and run offline tests.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
@ -15,9 +16,7 @@ def run(*args: str) -> int:
|
|||
def main() -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
exit(
|
||||
run("isort", ".", "-c", "--profile", "black", "--gitignore")
|
||||
or run("black", ".", "--check", "--extend-exclude", BLACK_IGNORE)
|
||||
or run("mypy", "--strict", ".")
|
||||
run("mypy", "--strict", ".")
|
||||
or run("ruff", "check", ".")
|
||||
or run("sphinx", "-M", "dummy", "client/doc", tmp_dir, "-n", "-W")
|
||||
or run("pytest", ".", "-m", "not net")
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
Run `telethon_generator.codegen` on both `api.tl` and `mtproto.tl` to output
|
||||
corresponding Python code in the default directories under the `client/`.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
|
|
@ -110,13 +110,13 @@ def main() -> None:
|
|||
function.args.args[0].annotation = None
|
||||
|
||||
if isinstance(function, ast.AsyncFunctionDef):
|
||||
call = ast.Await(value=call)
|
||||
call = ast.Await(value=call) # type: ignore [arg-type]
|
||||
|
||||
match function.returns:
|
||||
case ast.Constant(value=None):
|
||||
call = ast.Expr(value=call)
|
||||
call = ast.Expr(value=call) # type: ignore [arg-type]
|
||||
case _:
|
||||
call = ast.Return(value=call)
|
||||
call = ast.Return(value=call) # type: ignore [arg-type]
|
||||
|
||||
function.body.append(call)
|
||||
class_body.append(function)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Run `sphinx-build` to create HTML documentation and detect errors.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Sort imports and format code.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user