Make pyright happy

This commit is contained in:
Lonami Exo 2024-03-16 19:05:58 +01:00
parent 854096e9d3
commit 033b56f1d3
55 changed files with 435 additions and 309 deletions

View File

@ -25,6 +25,7 @@ def serialize_builtin(value: Any) -> bytes:
def overhead(obj: Obj) -> None: def overhead(obj: Obj) -> None:
x: Any
for v in obj.__dict__.values(): for v in obj.__dict__.values():
for x in v if isinstance(v, list) else [v]: for x in v if isinstance(v, list) else [v]:
if isinstance(x, Obj): if isinstance(x, Obj):
@ -34,6 +35,7 @@ def overhead(obj: Obj) -> None:
def strategy_concat(obj: Obj) -> bytes: def strategy_concat(obj: Obj) -> bytes:
x: Any
res = b"" res = b""
for v in obj.__dict__.values(): for v in obj.__dict__.values():
for x in v if isinstance(v, list) else [v]: for x in v if isinstance(v, list) else [v]:
@ -45,6 +47,7 @@ def strategy_concat(obj: Obj) -> bytes:
def strategy_append(obj: Obj) -> bytes: def strategy_append(obj: Obj) -> bytes:
x: Any
res = bytearray() res = bytearray()
for v in obj.__dict__.values(): for v in obj.__dict__.values():
for x in v if isinstance(v, list) else [v]: for x in v if isinstance(v, list) else [v]:
@ -57,6 +60,7 @@ def strategy_append(obj: Obj) -> bytes:
def strategy_append_reuse(obj: Obj) -> bytes: def strategy_append_reuse(obj: Obj) -> bytes:
def do_append(o: Obj, res: bytearray) -> None: def do_append(o: Obj, res: bytearray) -> None:
x: Any
for v in o.__dict__.values(): for v in o.__dict__.values():
for x in v if isinstance(v, list) else [v]: for x in v if isinstance(v, list) else [v]:
if isinstance(x, Obj): if isinstance(x, Obj):
@ -70,15 +74,18 @@ def strategy_append_reuse(obj: Obj) -> bytes:
def strategy_join(obj: Obj) -> bytes: def strategy_join(obj: Obj) -> bytes:
return b"".join( def iterator() -> Iterator[bytes]:
strategy_join(x) if isinstance(x, Obj) else serialize_builtin(x) x: Any
for v in obj.__dict__.values() for v in obj.__dict__.values():
for x in (v if isinstance(v, list) else [v]) for x in v if isinstance(v, list) else [v]:
) yield strategy_join(x) if isinstance(x, Obj) else serialize_builtin(x)
return b"".join(iterator())
def strategy_join_flat(obj: Obj) -> bytes: def strategy_join_flat(obj: Obj) -> bytes:
def flatten(o: Obj) -> Iterator[bytes]: def flatten(o: Obj) -> Iterator[bytes]:
x: Any
for v in o.__dict__.values(): for v in o.__dict__.values():
for x in v if isinstance(v, list) else [v]: for x in v if isinstance(v, list) else [v]:
if isinstance(x, Obj): if isinstance(x, Obj):
@ -91,6 +98,7 @@ def strategy_join_flat(obj: Obj) -> bytes:
def strategy_write(obj: Obj) -> bytes: def strategy_write(obj: Obj) -> bytes:
def do_write(o: Obj, buffer: io.BytesIO) -> None: def do_write(o: Obj, buffer: io.BytesIO) -> None:
x: Any
for v in o.__dict__.values(): for v in o.__dict__.values():
for x in v if isinstance(v, list) else [v]: for x in v if isinstance(v, list) else [v]:
if isinstance(x, Obj): if isinstance(x, Obj):

View File

@ -6,7 +6,7 @@ DATA = 42
def overhead(n: int) -> None: def overhead(n: int) -> None:
n n = n
def strategy_bool(n: int) -> bool: def strategy_bool(n: int) -> bool:

View File

@ -3,7 +3,7 @@ from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from setuptools import build_meta as _orig from setuptools import build_meta as _orig
from setuptools.build_meta import * # noqa: F403 from setuptools.build_meta import * # noqa: F403 # pyright: ignore [reportWildcardImportFromLibrary]
def gen_types_if_needed() -> None: def gen_types_if_needed() -> None:

View File

@ -223,6 +223,7 @@ async def check_password(
if not two_factor_auth.check_p_and_g(algo.p, algo.g): if not two_factor_auth.check_p_and_g(algo.p, algo.g):
token = await get_password_information(self) token = await get_password_information(self)
algo = token._password.current_algo
if not isinstance( if not isinstance(
algo, algo,
types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow, types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow,

View File

@ -2059,5 +2059,4 @@ class Client:
exc: Optional[BaseException], exc: Optional[BaseException],
tb: Optional[TracebackType], tb: Optional[TracebackType],
) -> None: ) -> None:
exc_type, exc, tb
await disconnect(self) await disconnect(self)

View File

@ -444,6 +444,7 @@ class FileBytesList(AsyncList[bytes]):
if result.bytes: if result.bytes:
self._offset += MAX_CHUNK_SIZE self._offset += MAX_CHUNK_SIZE
assert isinstance(result.bytes, bytes)
self._buffer.append(result.bytes) self._buffer.append(result.bytes)
self._done = len(result.bytes) < MAX_CHUNK_SIZE self._done = len(result.bytes) < MAX_CHUNK_SIZE

View File

@ -39,11 +39,13 @@ async def send_message(
noforwards=not text.can_forward, noforwards=not text.can_forward,
update_stickersets_order=False, update_stickersets_order=False,
peer=peer, peer=peer,
reply_to=types.InputReplyToMessage( reply_to=(
types.InputReplyToMessage(
reply_to_msg_id=text.replied_message_id, top_msg_id=None reply_to_msg_id=text.replied_message_id, top_msg_id=None
) )
if text.replied_message_id if text.replied_message_id
else None, else None
),
message=message, message=message,
random_id=random_id, random_id=random_id,
reply_markup=getattr(text._raw, "reply_markup", None), reply_markup=getattr(text._raw, "reply_markup", None),
@ -63,11 +65,11 @@ async def send_message(
noforwards=False, noforwards=False,
update_stickersets_order=False, update_stickersets_order=False,
peer=peer, peer=peer,
reply_to=types.InputReplyToMessage( reply_to=(
reply_to_msg_id=reply_to, top_msg_id=None types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
)
if reply_to if reply_to
else None, else None
),
message=message, message=message,
random_id=random_id, random_id=random_id,
reply_markup=btns.build_keyboard(buttons), reply_markup=btns.build_keyboard(buttons),
@ -83,11 +85,14 @@ async def send_message(
{}, {},
out=result.out, out=result.out,
id=result.id, id=result.id,
from_id=types.PeerUser(user_id=self._session.user.id) from_id=(
types.PeerUser(user_id=self._session.user.id)
if self._session.user if self._session.user
else None, else None
),
peer_id=packed._to_peer(), peer_id=packed._to_peer(),
reply_to=types.MessageReplyHeader( reply_to=(
types.MessageReplyHeader(
reply_to_scheduled=False, reply_to_scheduled=False,
forum_topic=False, forum_topic=False,
reply_to_msg_id=reply_to, reply_to_msg_id=reply_to,
@ -95,7 +100,8 @@ async def send_message(
reply_to_top_id=None, reply_to_top_id=None,
) )
if reply_to if reply_to
else None, else None
),
date=result.date, date=result.date,
message=message, message=message,
media=result.media, media=result.media,
@ -593,8 +599,8 @@ def build_message_map(
else: else:
return MessageMap(client, peer, {}, {}) return MessageMap(client, peer, {}, {})
random_id_to_id = {} random_id_to_id: Dict[int, int] = {}
id_to_message = {} id_to_message: Dict[int, Message] = {}
for update in updates: for update in updates:
if isinstance(update, types.UpdateMessageId): if isinstance(update, types.UpdateMessageId):
random_id_to_id[update.random_id] = update.id random_id_to_id[update.random_id] = update.id

View File

@ -8,6 +8,7 @@ from typing import (
Callable, Callable,
List, List,
Optional, Optional,
Sequence,
Type, Type,
TypeVar, TypeVar,
) )
@ -103,8 +104,8 @@ def process_socket_updates(client: Client, all_updates: List[abcs.Updates]) -> N
def extend_update_queue( def extend_update_queue(
client: Client, client: Client,
updates: List[abcs.Update], updates: List[abcs.Update],
users: List[abcs.User], users: Sequence[abcs.User],
chats: List[abcs.Chat], chats: Sequence[abcs.Chat],
) -> None: ) -> None:
chat_map = build_chat_map(client, users, chats) chat_map = build_chat_map(client, users, chats)

View File

@ -91,21 +91,22 @@ async def get_chats(self: Client, chats: Sequence[ChatLike]) -> List[Chat]:
input_channels.append(packed._to_input_channel()) input_channels.append(packed._to_input_channel())
if input_users: if input_users:
users = await self(functions.users.get_users(id=input_users)) ret_users = await self(functions.users.get_users(id=input_users))
users = list(ret_users)
else: else:
users = [] users = []
if input_chats: if input_chats:
ret_chats = await self(functions.messages.get_chats(id=input_chats)) ret_chats = await self(functions.messages.get_chats(id=input_chats))
assert isinstance(ret_chats, types.messages.Chats) assert isinstance(ret_chats, types.messages.Chats)
groups = ret_chats.chats groups = list(ret_chats.chats)
else: else:
groups = [] groups = []
if input_channels: if input_channels:
ret_chats = await self(functions.channels.get_channels(id=input_channels)) ret_chats = await self(functions.channels.get_channels(id=input_channels))
assert isinstance(ret_chats, types.messages.Chats) assert isinstance(ret_chats, types.messages.Chats)
channels = ret_chats.chats channels = list(ret_chats.chats)
else: else:
channels = [] channels = []
@ -133,7 +134,7 @@ async def resolve_to_packed(
ty = PackedType.USER ty = PackedType.USER
elif isinstance(chat, Group): elif isinstance(chat, Group):
ty = PackedType.MEGAGROUP if chat.is_megagroup else PackedType.CHAT ty = PackedType.MEGAGROUP if chat.is_megagroup else PackedType.CHAT
elif isinstance(chat, Channel): else:
ty = PackedType.BROADCAST ty = PackedType.BROADCAST
return PackedChat(ty=ty, id=chat.id, access_hash=0) return PackedChat(ty=ty, id=chat.id, access_hash=0)

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Optional, Self, Union from typing import TYPE_CHECKING, Dict, Optional, Self, Sequence, Union
from ...tl import abcs, types from ...tl import abcs, types
from ..types import Chat, Message, expand_peer, peer_id from ..types import Chat, Message, expand_peer, peer_id
@ -69,7 +69,7 @@ class MessageDeleted(Event):
The chat is only known when the deletion occurs in broadcast channels or supergroups. The chat is only known when the deletion occurs in broadcast channels or supergroups.
""" """
def __init__(self, msg_ids: List[int], channel_id: Optional[int]) -> None: def __init__(self, msg_ids: Sequence[int], channel_id: Optional[int]) -> None:
self._msg_ids = msg_ids self._msg_ids = msg_ids
self._channel_id = channel_id self._channel_id = channel_id
@ -85,7 +85,7 @@ class MessageDeleted(Event):
return None return None
@property @property
def message_ids(self) -> List[int]: def message_ids(self) -> Sequence[int]:
""" """
The message identifiers of the messages that were deleted. The message identifiers of the messages that were deleted.
""" """

View File

@ -40,7 +40,7 @@ class ButtonCallback(Event):
@property @property
def data(self) -> bytes: def data(self) -> bytes:
assert self._raw.data is not None assert isinstance(self._raw.data, bytes)
return self._raw.data return self._raw.data
async def answer( async def answer(

View File

@ -1,7 +1,19 @@
from collections import deque from collections import deque
from html import escape from html import escape
from html.parser import HTMLParser from html.parser import HTMLParser
from typing import Any, Deque, Dict, Iterable, List, Optional, Tuple, Type, cast from typing import (
Any,
Callable,
Deque,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from ...tl.abcs import MessageEntity from ...tl.abcs import MessageEntity
from ...tl.types import ( from ...tl.types import (
@ -30,13 +42,11 @@ class HTMLToTelegramParser(HTMLParser):
self._open_tags: Deque[str] = deque() self._open_tags: Deque[str] = deque()
self._open_tags_meta: Deque[Optional[str]] = deque() self._open_tags_meta: Deque[Optional[str]] = deque()
def handle_starttag( def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
self, tag: str, attrs_seq: List[Tuple[str, Optional[str]]]
) -> None:
self._open_tags.appendleft(tag) self._open_tags.appendleft(tag)
self._open_tags_meta.appendleft(None) self._open_tags_meta.appendleft(None)
attrs = dict(attrs_seq) attributes = dict(attrs)
EntityType: Optional[Type[MessageEntity]] = None EntityType: Optional[Type[MessageEntity]] = None
args = {} args = {}
if tag == "strong" or tag == "b": if tag == "strong" or tag == "b":
@ -61,7 +71,7 @@ class HTMLToTelegramParser(HTMLParser):
# inside <pre> tags # inside <pre> tags
pre = self._building_entities["pre"] pre = self._building_entities["pre"]
assert isinstance(pre, MessageEntityPre) assert isinstance(pre, MessageEntityPre)
if cls := attrs.get("class"): if cls := attributes.get("class"):
pre.language = cls[len("language-") :] pre.language = cls[len("language-") :]
except KeyError: except KeyError:
EntityType = MessageEntityCode EntityType = MessageEntityCode
@ -69,7 +79,7 @@ class HTMLToTelegramParser(HTMLParser):
EntityType = MessageEntityPre EntityType = MessageEntityPre
args["language"] = "" args["language"] = ""
elif tag == "a": elif tag == "a":
url = attrs.get("href") url = attributes.get("href")
if not url: if not url:
return return
if url.startswith("mailto:"): if url.startswith("mailto:"):
@ -94,18 +104,18 @@ class HTMLToTelegramParser(HTMLParser):
**args, **args,
) )
def handle_data(self, text: str) -> None: def handle_data(self, data: str) -> None:
previous_tag = self._open_tags[0] if len(self._open_tags) > 0 else "" previous_tag = self._open_tags[0] if len(self._open_tags) > 0 else ""
if previous_tag == "a": if previous_tag == "a":
url = self._open_tags_meta[0] url = self._open_tags_meta[0]
if url: if url:
text = url data = url
for entity in self._building_entities.values(): for entity in self._building_entities.values():
assert hasattr(entity, "length") assert hasattr(entity, "length")
entity.length += len(text) setattr(entity, "length", getattr(entity, "length", 0) + len(data))
self.text += text self.text += data
def handle_endtag(self, tag: str) -> None: def handle_endtag(self, tag: str) -> None:
try: try:
@ -114,7 +124,7 @@ class HTMLToTelegramParser(HTMLParser):
except IndexError: except IndexError:
pass pass
entity = self._building_entities.pop(tag, None) entity = self._building_entities.pop(tag, None)
if entity and hasattr(entity, "length") and entity.length: if entity and getattr(entity, "length", None):
self.entities.append(entity) self.entities.append(entity)
@ -135,7 +145,9 @@ def parse(html: str) -> Tuple[str, List[MessageEntity]]:
return del_surrogate(text), parser.entities return del_surrogate(text), parser.entities
ENTITY_TO_FORMATTER = { ENTITY_TO_FORMATTER: Dict[
Type[MessageEntity], Union[Tuple[str, str], Callable[[Any, str], Tuple[str, str]]]
] = {
MessageEntityBold: ("<strong>", "</strong>"), MessageEntityBold: ("<strong>", "</strong>"),
MessageEntityItalic: ("<em>", "</em>"), MessageEntityItalic: ("<em>", "</em>"),
MessageEntityCode: ("<code>", "</code>"), MessageEntityCode: ("<code>", "</code>"),
@ -173,18 +185,20 @@ def unparse(text: str, entities: Iterable[MessageEntity]) -> str:
text = add_surrogate(text) text = add_surrogate(text)
insert_at: List[Tuple[int, str]] = [] insert_at: List[Tuple[int, str]] = []
for entity in entities: for e in entities:
assert hasattr(entity, "offset") and hasattr(entity, "length") offset, length = getattr(e, "offset", None), getattr(e, "length", None)
s = entity.offset assert isinstance(offset, int) and isinstance(length, int)
e = entity.offset + entity.length
delimiter = ENTITY_TO_FORMATTER.get(type(entity), None) h = offset
t = offset + length
delimiter = ENTITY_TO_FORMATTER.get(type(e), None)
if delimiter: if delimiter:
if callable(delimiter): if callable(delimiter):
delim = delimiter(entity, text[s:e]) delim = delimiter(e, text[h:t])
else: else:
delim = delimiter delim = delimiter
insert_at.append((s, delim[0])) insert_at.append((h, delim[0]))
insert_at.append((e, delim[1])) insert_at.append((t, delim[1]))
insert_at.sort(key=lambda t: t[0]) insert_at.sort(key=lambda t: t[0])
next_escape_bound = len(text) next_escape_bound = len(text)

View File

@ -1,5 +1,5 @@
import re import re
from typing import Any, Iterator, List, Tuple from typing import Any, Dict, Iterator, List, Tuple, Type
import markdown_it import markdown_it
import markdown_it.token import markdown_it.token
@ -19,7 +19,7 @@ from ...tl.types import (
from .strings import add_surrogate, del_surrogate, within_surrogate from .strings import add_surrogate, del_surrogate, within_surrogate
MARKDOWN = markdown_it.MarkdownIt().enable("strikethrough") MARKDOWN = markdown_it.MarkdownIt().enable("strikethrough")
DELIMITERS = { DELIMITERS: Dict[Type[MessageEntity], Tuple[str, str]] = {
MessageEntityBlockquote: ("> ", ""), MessageEntityBlockquote: ("> ", ""),
MessageEntityBold: ("**", "**"), MessageEntityBold: ("**", "**"),
MessageEntityCode: ("`", "`"), MessageEntityCode: ("`", "`"),
@ -81,7 +81,9 @@ def parse(message: str) -> Tuple[str, List[MessageEntity]]:
else: else:
for entity in reversed(entities): for entity in reversed(entities):
if isinstance(entity, ty): if isinstance(entity, ty):
entity.length = len(message) - entity.offset setattr(
entity, "length", len(message) - getattr(entity, "offset", 0)
)
break break
parsed = MARKDOWN.parse(add_surrogate(message.strip())) parsed = MARKDOWN.parse(add_surrogate(message.strip()))
@ -156,24 +158,25 @@ def unparse(text: str, entities: List[MessageEntity]) -> str:
text = add_surrogate(text) text = add_surrogate(text)
insert_at: List[Tuple[int, str]] = [] insert_at: List[Tuple[int, str]] = []
for entity in entities: for e in entities:
assert hasattr(entity, "offset") offset, length = getattr(e, "offset", None), getattr(e, "length", None)
assert hasattr(entity, "length") assert isinstance(offset, int) and isinstance(length, int)
s = entity.offset
e = entity.offset + entity.length h = offset
delimiter = DELIMITERS.get(type(entity), None) t = offset + length
delimiter = DELIMITERS.get(type(e), None)
if delimiter: if delimiter:
insert_at.append((s, delimiter[0])) insert_at.append((h, delimiter[0]))
insert_at.append((e, delimiter[1])) insert_at.append((t, delimiter[1]))
elif isinstance(entity, MessageEntityPre): elif isinstance(e, MessageEntityPre):
insert_at.append((s, f"```{entity.language}\n")) insert_at.append((h, f"```{e.language}\n"))
insert_at.append((e, "```\n")) insert_at.append((t, "```\n"))
elif isinstance(entity, MessageEntityTextUrl): elif isinstance(e, MessageEntityTextUrl):
insert_at.append((s, "[")) insert_at.append((h, "["))
insert_at.append((e, f"]({entity.url})")) insert_at.append((t, f"]({e.url})"))
elif isinstance(entity, MessageEntityMentionName): elif isinstance(e, MessageEntityMentionName):
insert_at.append((s, "[")) insert_at.append((h, "["))
insert_at.append((e, f"](tg://user?id={entity.user_id})")) insert_at.append((t, f"](tg://user?id={e.user_id})"))
insert_at.sort(key=lambda t: t[0]) insert_at.sort(key=lambda t: t[0])
while insert_at: while insert_at:

View File

@ -8,9 +8,11 @@ def add_surrogate(text: str) -> str:
return "".join( return "".join(
# SMP -> Surrogate Pairs (Telegram offsets are calculated with these). # SMP -> Surrogate Pairs (Telegram offsets are calculated with these).
# See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more. # See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more.
(
"".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le"))) "".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
if (0x10000 <= ord(x) <= 0x10FFFF) if (0x10000 <= ord(x) <= 0x10FFFF)
else x else x
)
for x in text for x in text
) )
@ -43,32 +45,38 @@ def strip_text(text: str, entities: List[MessageEntity]) -> str:
if not entities: if not entities:
return text.strip() return text.strip()
assert all(isinstance(getattr(e, "offset"), int) for e in entities)
while text and text[-1].isspace(): while text and text[-1].isspace():
e = entities[-1] e = entities[-1]
assert hasattr(e, "offset") and hasattr(e, "length") offset, length = getattr(e, "offset", None), getattr(e, "length", None)
if e.offset + e.length == len(text): assert isinstance(offset, int) and isinstance(length, int)
if e.length == 1:
if offset + length == len(text):
if length == 1:
del entities[-1] del entities[-1]
if not entities: if not entities:
return text.strip() return text.strip()
else: else:
e.length -= 1 length -= 1
text = text[:-1] text = text[:-1]
while text and text[0].isspace(): while text and text[0].isspace():
for i in reversed(range(len(entities))): for i in reversed(range(len(entities))):
e = entities[i] e = entities[i]
assert hasattr(e, "offset") and hasattr(e, "length") offset, length = getattr(e, "offset", None), getattr(e, "length", None)
if e.offset != 0: assert isinstance(offset, int) and isinstance(length, int)
e.offset -= 1
if offset != 0:
setattr(e, "offset", offset - 1)
continue continue
if e.length == 1: if length == 1:
del entities[0] del entities[0]
if not entities: if not entities:
return text.lstrip() return text.lstrip()
else: else:
e.length -= 1 setattr(e, "length", length - 1)
text = text[1:] text = text[1:]

View File

@ -29,6 +29,7 @@ class Callback(InlineButton):
This data will be received by :class:`telethon.events.ButtonCallback` when the button is pressed. This data will be received by :class:`telethon.events.ButtonCallback` when the button is pressed.
""" """
assert isinstance(self._raw, types.KeyboardButtonCallback) assert isinstance(self._raw, types.KeyboardButtonCallback)
assert isinstance(self._raw.data, bytes)
return self._raw.data return self._raw.data
@data.setter @data.setter

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import itertools import itertools
import sys import sys
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Sequence, Union
from ....session import PackedChat from ....session import PackedChat
from ....tl import abcs, types from ....tl import abcs, types
@ -19,13 +19,15 @@ ChatLike = Union[Chat, PackedChat, int, str]
def build_chat_map( def build_chat_map(
client: Client, users: List[abcs.User], chats: List[abcs.Chat] client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]
) -> Dict[int, Chat]: ) -> Dict[int, Chat]:
users_iter = (User._from_raw(u) for u in users) users_iter = (User._from_raw(u) for u in users)
chats_iter = ( chats_iter = (
(
Channel._from_raw(c) Channel._from_raw(c)
if isinstance(c, (types.Channel, types.ChannelForbidden)) and c.broadcast if isinstance(c, (types.Channel, types.ChannelForbidden)) and c.broadcast
else Group._from_raw(client, c) else Group._from_raw(client, c)
)
for c in chats for c in chats
) )

View File

@ -5,7 +5,17 @@ import urllib.parse
from inspect import isawaitable from inspect import isawaitable
from io import BufferedWriter from io import BufferedWriter
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Protocol, Self, Union from typing import (
TYPE_CHECKING,
Any,
Coroutine,
List,
Optional,
Protocol,
Self,
Sequence,
Union,
)
from ...tl import abcs, types from ...tl import abcs, types
from .meta import NoPublicConstructor from .meta import NoPublicConstructor
@ -43,7 +53,7 @@ stripped_size_header = bytes.fromhex(
stripped_size_footer = bytes.fromhex("FFD9") stripped_size_footer = bytes.fromhex("FFD9")
def expand_stripped_size(data: bytes) -> bytes: def expand_stripped_size(data: bytes | bytearray | memoryview) -> bytes:
header = bytearray(stripped_size_header) header = bytearray(stripped_size_header)
header[164] = data[1] header[164] = data[1]
header[166] = data[2] header[166] = data[2]
@ -87,13 +97,14 @@ class InFileLike(Protocol):
It's only used in function parameters. It's only used in function parameters.
""" """
def read(self, n: int) -> Union[bytes, Coroutine[Any, Any, bytes]]: def read(self, n: int, /) -> Union[bytes, Coroutine[Any, Any, bytes]]:
""" """
Read from the file or buffer. Read from the file or buffer.
:param n: :param n:
Maximum amount of bytes that should be returned. Maximum amount of bytes that should be returned.
""" """
raise NotImplementedError
class OutFileLike(Protocol): class OutFileLike(Protocol):
@ -115,18 +126,21 @@ class OutFileLike(Protocol):
class OutWrapper: class OutWrapper:
__slots__ = ("_fd", "_owned") __slots__ = ("_fd", "_owned_fd")
_fd: Union[OutFileLike, BufferedWriter]
_owned_fd: Optional[BufferedWriter]
def __init__(self, file: Union[str, Path, OutFileLike]): def __init__(self, file: Union[str, Path, OutFileLike]):
if isinstance(file, str): if isinstance(file, str):
file = Path(file) file = Path(file)
if isinstance(file, Path): if isinstance(file, Path):
self._fd: Union[OutFileLike, BufferedWriter] = file.open("wb") self._fd = file.open("wb")
self._owned = True self._owned_fd = self._fd
else: else:
self._fd = file self._fd = file
self._owned = False self._owned_fd = None
async def write(self, chunk: bytes) -> None: async def write(self, chunk: bytes) -> None:
ret = self._fd.write(chunk) ret = self._fd.write(chunk)
@ -134,9 +148,9 @@ class OutWrapper:
await ret await ret
def close(self) -> None: def close(self) -> None:
if self._owned: if self._owned_fd is not None:
assert hasattr(self._fd, "close") assert hasattr(self._owned_fd, "close")
self._fd.close() self._owned_fd.close()
class File(metaclass=NoPublicConstructor): class File(metaclass=NoPublicConstructor):
@ -150,7 +164,7 @@ class File(metaclass=NoPublicConstructor):
def __init__( def __init__(
self, self,
*, *,
attributes: List[abcs.DocumentAttribute], attributes: Sequence[abcs.DocumentAttribute],
size: int, size: int,
name: str, name: str,
mime: str, mime: str,
@ -158,7 +172,7 @@ class File(metaclass=NoPublicConstructor):
muted: bool, muted: bool,
input_media: abcs.InputMedia, input_media: abcs.InputMedia,
thumb: Optional[abcs.PhotoSize], thumb: Optional[abcs.PhotoSize],
thumbs: Optional[List[abcs.PhotoSize]], thumbs: Optional[Sequence[abcs.PhotoSize]],
raw: Optional[Union[abcs.MessageMedia, abcs.Photo, abcs.Document]], raw: Optional[Union[abcs.MessageMedia, abcs.Photo, abcs.Document]],
client: Optional[Client], client: Optional[Client],
): ):
@ -405,9 +419,9 @@ class File(metaclass=NoPublicConstructor):
id=self._input_media.id.id, id=self._input_media.id.id,
access_hash=self._input_media.id.access_hash, access_hash=self._input_media.id.access_hash,
file_reference=self._input_media.id.file_reference, file_reference=self._input_media.id.file_reference,
thumb_size=self._thumb.type thumb_size=(
if isinstance(self._thumb, thumb_types) self._thumb.type if isinstance(self._thumb, thumb_types) else ""
else "", ),
) )
elif isinstance(self._input_media, types.InputMediaPhoto): elif isinstance(self._input_media, types.InputMediaPhoto):
assert isinstance(self._input_media.id, types.InputPhoto) assert isinstance(self._input_media.id, types.InputPhoto)

View File

@ -2,7 +2,17 @@ from __future__ import annotations
import datetime import datetime
import time import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Self, Tuple, Union from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Self,
Sequence,
Tuple,
Union,
)
from ...tl import abcs, types from ...tl import abcs, types
from ..parsers import ( from ..parsers import (
@ -502,7 +512,7 @@ class Message(metaclass=NoPublicConstructor):
def build_msg_map( def build_msg_map(
client: Client, messages: List[abcs.Message], chat_map: Dict[int, Chat] client: Client, messages: Sequence[abcs.Message], chat_map: Dict[int, Chat]
) -> Dict[int, Message]: ) -> Dict[int, Message]:
return { return {
msg.id: msg msg.id: msg

View File

@ -1,8 +1,9 @@
""" """
Class definitions stolen from `trio`, with some modifications. Class definitions stolen from `trio`, with some modifications.
""" """
import abc import abc
from typing import Type, TypeVar from typing import Any, Type, TypeVar
T = TypeVar("T") T = TypeVar("T")
@ -28,7 +29,7 @@ class Final(abc.ABCMeta):
class NoPublicConstructor(Final): class NoPublicConstructor(Final):
def __call__(cls) -> None: def __call__(cls, *args: Any, **kwds: Any) -> Any:
raise TypeError( raise TypeError(
f"{cls.__module__}.{cls.__qualname__} has no public constructor" f"{cls.__module__}.{cls.__qualname__} has no public constructor"
) )

View File

@ -100,12 +100,8 @@ class Participant(metaclass=NoPublicConstructor):
), ),
): ):
return self._raw.user_id return self._raw.user_id
elif isinstance(
self._raw, (types.ChannelParticipantBanned, types.ChannelParticipantLeft)
):
return peer_id(self._raw.peer)
else: else:
raise RuntimeError("unexpected case") return peer_id(self._raw.peer)
@property @property
def user(self) -> Optional[User]: def user(self) -> Optional[User]:

View File

@ -1,7 +1,28 @@
try:
import cryptg
def ige_encrypt(
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes: # noqa: F811
return cryptg.encrypt_ige(
bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv
)
def ige_decrypt(
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes: # noqa: F811
return cryptg.decrypt_ige(
bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext,
key,
iv,
)
except ImportError:
import pyaes import pyaes
def ige_encrypt(
def ige_encrypt(plaintext: bytes, key: bytes, iv: bytes) -> bytes: plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes:
assert len(plaintext) % 16 == 0 assert len(plaintext) % 16 == 0
assert len(iv) == 32 assert len(iv) == 32
@ -26,8 +47,9 @@ def ige_encrypt(plaintext: bytes, key: bytes, iv: bytes) -> bytes:
return bytes(ciphertext) return bytes(ciphertext)
def ige_decrypt(
def ige_decrypt(ciphertext: bytes, key: bytes, iv: bytes) -> bytes: ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes:
assert len(ciphertext) % 16 == 0 assert len(ciphertext) % 16 == 0
assert len(iv) == 32 assert len(iv) == 32
@ -51,22 +73,3 @@ def ige_decrypt(ciphertext: bytes, key: bytes, iv: bytes) -> bytes:
plaintext += plaintext_block plaintext += plaintext_block
return bytes(plaintext) return bytes(plaintext)
try:
import cryptg
def ige_encrypt(plaintext: bytes, key: bytes, iv: bytes) -> bytes: # noqa: F811
return cryptg.encrypt_ige(
bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv
)
def ige_decrypt(ciphertext: bytes, key: bytes, iv: bytes) -> bytes: # noqa: F811
return cryptg.decrypt_ige(
bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext,
key,
iv,
)
except ImportError:
pass

View File

@ -1,7 +1,7 @@
import os import os
from collections import namedtuple
from enum import IntEnum from enum import IntEnum
from hashlib import sha1, sha256 from hashlib import sha1, sha256
from typing import NamedTuple
from .aes import ige_decrypt, ige_encrypt from .aes import ige_decrypt, ige_encrypt
from .auth_key import AuthKey from .auth_key import AuthKey
@ -13,15 +13,19 @@ class Side(IntEnum):
SERVER = 8 SERVER = 8
CalcKey = namedtuple("CalcKey", ("key", "iv")) class CalcKey(NamedTuple):
key: bytes
iv: bytes
# https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector # https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
def calc_key(auth_key: AuthKey, msg_key: bytes, side: Side) -> CalcKey: def calc_key(
auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side
) -> CalcKey:
x = int(side) x = int(side)
# sha256_a = SHA256 (msg_key + substr (auth_key, x, 36)) # sha256_a = SHA256 (msg_key + substr (auth_key, x, 36))
sha256_a = sha256(msg_key + auth_key.data[x : x + 36]).digest() sha256_a = sha256(bytes(msg_key) + auth_key.data[x : x + 36]).digest()
# sha256_b = SHA256 (substr (auth_key, 40+x, 36) + msg_key) # sha256_b = SHA256 (substr (auth_key, 40+x, 36) + msg_key)
sha256_b = sha256(auth_key.data[x + 40 : x + 76] + msg_key).digest() sha256_b = sha256(auth_key.data[x + 40 : x + 76] + msg_key).digest()
@ -66,7 +70,9 @@ def encrypt_data_v2(plaintext: bytes, auth_key: AuthKey) -> bytes:
return _do_encrypt_data_v2(plaintext, auth_key, random_padding) return _do_encrypt_data_v2(plaintext, auth_key, random_padding)
def decrypt_data_v2(ciphertext: bytes, auth_key: AuthKey) -> bytes: def decrypt_data_v2(
ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey
) -> bytes:
side = Side.SERVER side = Side.SERVER
x = int(side) x = int(side)

View File

@ -1,45 +1,58 @@
# Ported from https://github.com/Lonami/grammers/blob/d91dc82/lib/grammers-crypto/src/two_factor_auth.rs # Ported from https://github.com/Lonami/grammers/blob/d91dc82/lib/grammers-crypto/src/two_factor_auth.rs
from collections import namedtuple
from hashlib import pbkdf2_hmac, sha256 from hashlib import pbkdf2_hmac, sha256
from typing import NamedTuple
from .factorize import factorize from .factorize import factorize
TwoFactorAuth = namedtuple("TwoFactorAuth", ("m1", "g_a"))
class TwoFactorAuth(NamedTuple):
m1: bytes
g_a: bytes
def pad_to_256(data: bytes) -> bytes: def pad_to_256(data: bytes | bytearray | memoryview) -> bytes:
return bytes(256 - len(data)) + data return bytes(256 - len(data)) + data
# H(data) := sha256(data) # H(data) := sha256(data)
def h(*data: bytes) -> bytes: def h(*data: bytes | bytearray | memoryview) -> bytes:
return sha256(b"".join(data)).digest() return sha256(b"".join(data)).digest()
# SH(data, salt) := H(salt | data | salt) # SH(data, salt) := H(salt | data | salt)
def sh(data: bytes, salt: bytes) -> bytes: def sh(
data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview
) -> bytes:
return h(salt, data, salt) return h(salt, data, salt)
# PH1(password, salt1, salt2) := SH(SH(password, salt1), salt2) # PH1(password, salt1, salt2) := SH(SH(password, salt1), salt2)
def ph1(password: bytes, salt1: bytes, salt2: bytes) -> bytes: def ph1(
password: bytes | bytearray | memoryview,
salt1: bytes | bytearray | memoryview,
salt2: bytes | bytearray | memoryview,
) -> bytes:
return sh(sh(password, salt1), salt2) return sh(sh(password, salt1), salt2)
# PH2(password, salt1, salt2) := SH(pbkdf2(sha512, PH1(password, salt1, salt2), salt1, 100000), salt2) # PH2(password, salt1, salt2) := SH(pbkdf2(sha512, PH1(password, salt1, salt2), salt1, 100000), salt2)
def ph2(password: bytes, salt1: bytes, salt2: bytes) -> bytes: def ph2(
password: bytes | bytearray | memoryview,
salt1: bytes | bytearray | memoryview,
salt2: bytes | bytearray | memoryview,
) -> bytes:
return sh(pbkdf2_hmac("sha512", ph1(password, salt1, salt2), salt1, 100000), salt2) return sh(pbkdf2_hmac("sha512", ph1(password, salt1, salt2), salt1, 100000), salt2)
# https://core.telegram.org/api/srp # https://core.telegram.org/api/srp
def calculate_2fa( def calculate_2fa(
*, *,
salt1: bytes, salt1: bytes | bytearray | memoryview,
salt2: bytes, salt2: bytes | bytearray | memoryview,
g: int, g: int,
p: bytes, p: bytes | bytearray | memoryview,
g_b: bytes, g_b: bytes | bytearray | memoryview,
a: bytes, a: bytes | bytearray | memoryview,
password: bytes, password: bytes,
) -> TwoFactorAuth: ) -> TwoFactorAuth:
big_p = int.from_bytes(p) big_p = int.from_bytes(p)
@ -100,16 +113,16 @@ def calculate_2fa(
return TwoFactorAuth(m1, g_a) return TwoFactorAuth(m1, g_a)
def check_p_len(p: bytes) -> bool: def check_p_len(p: bytes | bytearray | memoryview) -> bool:
return len(p) == 256 return len(p) == 256
def check_known_prime(p: bytes, g: int) -> bool: def check_known_prime(p: bytes | bytearray | memoryview, g: int) -> bool:
good_prime = b"\xc7\x1c\xae\xb9\xc6\xb1\xc9\x04\x8elR/p\xf1?s\x98\r@#\x8e>!\xc1I4\xd07V=\x93\x0fH\x19\x8a\n\xa7\xc1@X\"\x94\x93\xd2%0\xf4\xdb\xfa3on\n\xc9%\x13\x95C\xae\xd4L\xce|7 \xfdQ\xf6\x94XpZ\xc6\x8c\xd4\xfekk\x13\xab\xdc\x97FQ)i2\x84T\xf1\x8f\xaf\x8cY_d$w\xfe\x96\xbb*\x94\x1d[\xcd\x1dJ\xc8\xccI\x88\x07\x08\xfa\x9b7\x8e<O:\x90`\xbe\xe6|\xf9\xa4\xa4\xa6\x95\x81\x10Q\x90~\x16'S\xb5k\x0fkA\r\xbat\xd8\xa8K*\x14\xb3\x14N\x0e\xf1(GT\xfd\x17\xed\x95\rYe\xb4\xb9\xddFX-\xb1\x17\x8d\x16\x9ck\xc4e\xb0\xd6\xff\x9c\xa3\x92\x8f\xef[\x9a\xe4\xe4\x18\xfc\x15\xe8>\xbe\xa0\xf8\x7f\xa9\xff^\xedp\x05\r\xed(I\xf4{\xf9Y\xd9V\x85\x0c\xe9)\x85\x1f\r\x81\x15\xf65\xb1\x05\xee.N\x15\xd0K$T\xbfoO\xad\xf04\xb1\x04\x03\x11\x9c\xd8\xe3\xb9/\xcc[" good_prime = b"\xc7\x1c\xae\xb9\xc6\xb1\xc9\x04\x8elR/p\xf1?s\x98\r@#\x8e>!\xc1I4\xd07V=\x93\x0fH\x19\x8a\n\xa7\xc1@X\"\x94\x93\xd2%0\xf4\xdb\xfa3on\n\xc9%\x13\x95C\xae\xd4L\xce|7 \xfdQ\xf6\x94XpZ\xc6\x8c\xd4\xfekk\x13\xab\xdc\x97FQ)i2\x84T\xf1\x8f\xaf\x8cY_d$w\xfe\x96\xbb*\x94\x1d[\xcd\x1dJ\xc8\xccI\x88\x07\x08\xfa\x9b7\x8e<O:\x90`\xbe\xe6|\xf9\xa4\xa4\xa6\x95\x81\x10Q\x90~\x16'S\xb5k\x0fkA\r\xbat\xd8\xa8K*\x14\xb3\x14N\x0e\xf1(GT\xfd\x17\xed\x95\rYe\xb4\xb9\xddFX-\xb1\x17\x8d\x16\x9ck\xc4e\xb0\xd6\xff\x9c\xa3\x92\x8f\xef[\x9a\xe4\xe4\x18\xfc\x15\xe8>\xbe\xa0\xf8\x7f\xa9\xff^\xedp\x05\r\xed(I\xf4{\xf9Y\xd9V\x85\x0c\xe9)\x85\x1f\r\x81\x15\xf65\xb1\x05\xee.N\x15\xd0K$T\xbfoO\xad\xf04\xb1\x04\x03\x11\x9c\xd8\xe3\xb9/\xcc["
return p == good_prime and g in (3, 4, 5, 7) return p == good_prime and g in (3, 4, 5, 7)
def check_p_prime_and_subgroup(p: bytes, g: int) -> bool: def check_p_prime_and_subgroup(p: bytes | bytearray | memoryview, g: int) -> bool:
if check_known_prime(p, g): if check_known_prime(p, g):
return True return True
@ -133,7 +146,7 @@ def check_p_prime_and_subgroup(p: bytes, g: int) -> bool:
return candidate and factorize((big_p - 1) // 2)[0] == 1 return candidate and factorize((big_p - 1) // 2)[0] == 1
def check_p_and_g(p: bytes, g: int) -> bool: def check_p_and_g(p: bytes | bytearray | memoryview, g: int) -> bool:
if not check_p_len(p): if not check_p_len(p):
return False return False

View File

@ -164,6 +164,7 @@ def _do_step3(
) )
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce) key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
assert isinstance(server_dh_params.encrypted_answer, bytes)
plain_text_answer = decrypt_ige(server_dh_params.encrypted_answer, key, iv) plain_text_answer = decrypt_ige(server_dh_params.encrypted_answer, key, iv)
got_answer_hash = plain_text_answer[:20] got_answer_hash = plain_text_answer[:20]

View File

@ -260,7 +260,7 @@ class Encrypted(Mtp):
self._store_own_updates(result) self._store_own_updates(result)
self._deserialization.append(RpcResult(msg_id, result)) self._deserialization.append(RpcResult(msg_id, result))
def _store_own_updates(self, body: bytes) -> None: def _store_own_updates(self, body: bytes | bytearray | memoryview) -> None:
constructor_id = struct.unpack_from("I", body)[0] constructor_id = struct.unpack_from("I", body)[0]
if constructor_id in UPDATE_IDS: if constructor_id in UPDATE_IDS:
self._deserialization.append(Update(body)) self._deserialization.append(Update(body))
@ -332,7 +332,7 @@ class Encrypted(Mtp):
) )
self._start_salt_time = (salts.now, self._adjusted_now()) self._start_salt_time = (salts.now, self._adjusted_now())
self._salts = salts.salts self._salts = list(salts.salts)
self._salts.sort(key=lambda salt: -salt.valid_since) self._salts.sort(key=lambda salt: -salt.valid_since)
def _handle_future_salt(self, message: Message) -> None: def _handle_future_salt(self, message: Message) -> None:
@ -439,7 +439,9 @@ class Encrypted(Mtp):
msg_id, buffer = result msg_id, buffer = result
return msg_id, encrypt_data_v2(buffer, self._auth_key) return msg_id, encrypt_data_v2(buffer, self._auth_key)
def deserialize(self, payload: bytes) -> List[Deserialization]: def deserialize(
self, payload: bytes | bytearray | memoryview
) -> List[Deserialization]:
check_message_buffer(payload) check_message_buffer(payload)
plaintext = decrypt_data_v2(payload, self._auth_key) plaintext = decrypt_data_v2(payload, self._auth_key)

View File

@ -31,7 +31,9 @@ class Plain(Mtp):
self._buffer.clear() self._buffer.clear()
return MsgId(0), result return MsgId(0), result
def deserialize(self, payload: bytes) -> List[Deserialization]: def deserialize(
self, payload: bytes | bytearray | memoryview
) -> List[Deserialization]:
check_message_buffer(payload) check_message_buffer(payload)
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload) auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)

View File

@ -15,7 +15,7 @@ class Update:
__slots__ = ("body",) __slots__ = ("body",)
def __init__(self, body: bytes): def __init__(self, body: bytes | bytearray | memoryview):
self.body = body self.body = body
@ -26,7 +26,7 @@ class RpcResult:
__slots__ = ("msg_id", "body") __slots__ = ("msg_id", "body")
def __init__(self, msg_id: MsgId, body: bytes): def __init__(self, msg_id: MsgId, body: bytes | bytearray | memoryview):
self.msg_id = msg_id self.msg_id = msg_id
self.body = body self.body = body
@ -201,7 +201,9 @@ class Mtp(ABC):
""" """
@abstractmethod @abstractmethod
def deserialize(self, payload: bytes) -> List[Deserialization]: def deserialize(
self, payload: bytes | bytearray | memoryview
) -> List[Deserialization]:
""" """
Deserialize incoming buffer payload. Deserialize incoming buffer payload.
""" """

View File

@ -12,7 +12,7 @@ class Transport(ABC):
pass pass
@abstractmethod @abstractmethod
def unpack(self, input: bytes, output: bytearray) -> int: def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
pass pass

View File

@ -36,7 +36,7 @@ class Abridged(Transport):
write(struct.pack("<i", 0x7F | (length << 8))) write(struct.pack("<i", 0x7F | (length << 8)))
write(input) write(input)
def unpack(self, input: bytes, output: bytearray) -> int: def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if not input: if not input:
raise MissingBytes(expected=1, got=0) raise MissingBytes(expected=1, got=0)

View File

@ -35,7 +35,7 @@ class Full(Transport):
write(struct.pack("<I", crc32(tmp))) write(struct.pack("<I", crc32(tmp)))
self._send_seq += 1 self._send_seq += 1
def unpack(self, input: bytes, output: bytearray) -> int: def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if len(input) < 4: if len(input) < 4:
raise MissingBytes(expected=4, got=len(input)) raise MissingBytes(expected=4, got=len(input))

View File

@ -32,7 +32,7 @@ class Intermediate(Transport):
write(struct.pack("<i", len(input))) write(struct.pack("<i", len(input)))
write(input) write(input)
def unpack(self, input: bytes, output: bytearray) -> int: def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if len(input) < 4: if len(input) < 4:
raise MissingBytes(expected=4, got=len(input)) raise MissingBytes(expected=4, got=len(input))

View File

@ -10,7 +10,7 @@ CONTAINER_MAX_LENGTH = 100
MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
def check_message_buffer(message: bytes) -> None: def check_message_buffer(message: bytes | bytearray | memoryview) -> None:
if len(message) < 20: if len(message) < 20:
raise ValueError( raise ValueError(
f"server payload is too small to be a valid message: {message.hex()}" f"server payload is too small to be a valid message: {message.hex()}"

View File

@ -70,6 +70,7 @@ class AsyncReader(Protocol):
:param n: :param n:
Amount of bytes to read at most. Amount of bytes to read at most.
""" """
raise NotImplementedError
class AsyncWriter(Protocol): class AsyncWriter(Protocol):
@ -77,7 +78,7 @@ class AsyncWriter(Protocol):
A :class:`asyncio.StreamWriter`-like class. A :class:`asyncio.StreamWriter`-like class.
""" """
def write(self, data: bytes) -> None: def write(self, data: bytes | bytearray | memoryview) -> None:
""" """
Must behave like :meth:`asyncio.StreamWriter.write`. Must behave like :meth:`asyncio.StreamWriter.write`.
@ -127,7 +128,7 @@ class Connector(Protocol):
""" """
async def __call__(self, ip: str, port: int) -> Tuple[AsyncReader, AsyncWriter]: async def __call__(self, ip: str, port: int) -> Tuple[AsyncReader, AsyncWriter]:
pass raise NotImplementedError
class RequestState(ABC): class RequestState(ABC):
@ -340,12 +341,12 @@ class Sender:
self._process_result(result) self._process_result(result)
elif isinstance(result, RpcError): elif isinstance(result, RpcError):
self._process_error(result) self._process_error(result)
elif isinstance(result, BadMessage):
self._process_bad_message(result)
else: else:
raise RuntimeError("unexpected case") self._process_bad_message(result)
def _process_update(self, updates: List[Updates], update: bytes) -> None: def _process_update(
self, updates: List[Updates], update: bytes | bytearray | memoryview
) -> None:
try: try:
updates.append(Updates.from_bytes(update)) updates.append(Updates.from_bytes(update))
except ValueError: except ValueError:

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple from typing import Any, Dict, Optional, Sequence, Tuple
from ...tl import abcs, types from ...tl import abcs, types
from .packed import PackedChat, PackedType from .packed import PackedChat, PackedType
@ -131,7 +131,7 @@ class ChatHashCache:
else: else:
raise RuntimeError("unexpected case") raise RuntimeError("unexpected case")
def extend(self, users: List[abcs.User], chats: List[abcs.Chat]) -> bool: def extend(self, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]) -> bool:
# See https://core.telegram.org/api/min for "issues" with "min constructors". # See https://core.telegram.org/api/min for "issues" with "min constructors".
success = True success = True
@ -299,6 +299,7 @@ class ChatHashCache:
success &= self._has_notify_peer(peer) success &= self._has_notify_peer(peer)
for field in ("users",): for field in ("users",):
user: Any
users = getattr(message.action, field, None) users = getattr(message.action, field, None)
if isinstance(users, list): if isinstance(users, list):
for user in users: for user in users:

View File

@ -4,11 +4,35 @@ from typing import List, Literal, Union
from ...tl import abcs from ...tl import abcs
NO_DATE = 0 # used on adapted messages.affected* from lower layers
NO_SEQ = 0
NO_PTS = 0
# https://core.telegram.org/method/updates.getChannelDifference
BOT_CHANNEL_DIFF_LIMIT = 100000
USER_CHANNEL_DIFF_LIMIT = 100
POSSIBLE_GAP_TIMEOUT = 0.5
# https://core.telegram.org/api/updates
NO_UPDATES_TIMEOUT = 15 * 60
ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT"
ENTRY_SECRET: Literal["SECRET"] = "SECRET"
Entry = Union[Literal["ACCOUNT", "SECRET"], int]
# Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET.
# We don't define a name for this as libraries shouldn't do that though.
LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2
class PtsInfo: class PtsInfo:
__slots__ = ("pts", "pts_count", "entry") __slots__ = ("pts", "pts_count", "entry")
def __init__(self, entry: "Entry", pts: int, pts_count: int) -> None: entry: Entry # pyright needs this or it infers int | str
def __init__(self, entry: Entry, pts: int, pts_count: int) -> None:
self.pts = pts self.pts = pts
self.pts_count = pts_count self.pts_count = pts_count
self.entry = entry self.entry = entry
@ -59,26 +83,3 @@ class PrematureEndReason(Enum):
class Gap(ValueError): class Gap(ValueError):
def __repr__(self) -> str: def __repr__(self) -> str:
return "Gap()" return "Gap()"
NO_DATE = 0 # used on adapted messages.affected* from lower layers
NO_SEQ = 0
NO_PTS = 0
# https://core.telegram.org/method/updates.getChannelDifference
BOT_CHANNEL_DIFF_LIMIT = 100000
USER_CHANNEL_DIFF_LIMIT = 100
POSSIBLE_GAP_TIMEOUT = 0.5
# https://core.telegram.org/api/updates
NO_UPDATES_TIMEOUT = 15 * 60
ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT"
ENTRY_SECRET: Literal["SECRET"] = "SECRET"
Entry = Union[Literal["ACCOUNT", "SECRET"], int]
# Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET.
# We don't define a name for this as libraries shouldn't do that though.
LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2

View File

@ -2,7 +2,7 @@ import asyncio
import datetime import datetime
import logging import logging
import time import time
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, List, Optional, Sequence, Set, Tuple
from ...tl import Request, abcs, functions, types from ...tl import Request, abcs, functions, types
from ..chat import ChatHashCache from ..chat import ChatHashCache
@ -168,6 +168,7 @@ class MessageBox:
if not entries: if not entries:
return return
entry: Entry = ENTRY_ACCOUNT # for pyright to know it's not unbound
for entry in entries: for entry in entries:
if entry not in self.map: if entry not in self.map:
raise RuntimeError( raise RuntimeError(
@ -258,7 +259,7 @@ class MessageBox:
self, self,
updates: abcs.Updates, updates: abcs.Updates,
chat_hashes: ChatHashCache, chat_hashes: ChatHashCache,
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: ) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
result: List[abcs.Update] = [] result: List[abcs.Update] = []
combined = adapt(updates, chat_hashes) combined = adapt(updates, chat_hashes)
@ -289,7 +290,7 @@ class MessageBox:
sorted_updates = list(sorted(combined.updates, key=update_sort_key)) sorted_updates = list(sorted(combined.updates, key=update_sort_key))
any_pts_applied = False any_pts_applied = False
reset_deadlines_for = set() reset_deadlines_for: Set[Entry] = set()
for update in sorted_updates: for update in sorted_updates:
entry, applied = self.apply_pts_info(update) entry, applied = self.apply_pts_info(update)
if entry is not None: if entry is not None:
@ -420,9 +421,11 @@ class MessageBox:
pts_limit=None, pts_limit=None,
pts_total_limit=None, pts_total_limit=None,
date=int(self.date.timestamp()), date=int(self.date.timestamp()),
qts=self.map[ENTRY_SECRET].pts qts=(
self.map[ENTRY_SECRET].pts
if ENTRY_SECRET in self.map if ENTRY_SECRET in self.map
else NO_SEQ, else NO_SEQ
),
qts_limit=None, qts_limit=None,
) )
if __debug__: if __debug__:
@ -435,12 +438,12 @@ class MessageBox:
self, self,
diff: abcs.updates.Difference, diff: abcs.updates.Difference,
chat_hashes: ChatHashCache, chat_hashes: ChatHashCache,
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: ) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
if __debug__: if __debug__:
self._trace("applying account difference: %s", diff) self._trace("applying account difference: %s", diff)
finish: bool finish: bool
result: Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]] result: Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]
if isinstance(diff, types.updates.DifferenceEmpty): if isinstance(diff, types.updates.DifferenceEmpty):
finish = True finish = True
self.date = datetime.datetime.fromtimestamp( self.date = datetime.datetime.fromtimestamp(
@ -493,7 +496,7 @@ class MessageBox:
self, self,
diff: types.updates.Difference, diff: types.updates.Difference,
chat_hashes: ChatHashCache, chat_hashes: ChatHashCache,
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: ) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
state = diff.state state = diff.state
assert isinstance(state, types.updates.State) assert isinstance(state, types.updates.State)
self.map[ENTRY_ACCOUNT].pts = state.pts self.map[ENTRY_ACCOUNT].pts = state.pts
@ -556,9 +559,11 @@ class MessageBox:
channel=channel, channel=channel,
filter=types.ChannelMessagesFilterEmpty(), filter=types.ChannelMessagesFilterEmpty(),
pts=state.pts, pts=state.pts,
limit=BOT_CHANNEL_DIFF_LIMIT limit=(
BOT_CHANNEL_DIFF_LIMIT
if chat_hashes.is_self_bot if chat_hashes.is_self_bot
else USER_CHANNEL_DIFF_LIMIT, else USER_CHANNEL_DIFF_LIMIT
),
) )
if __debug__: if __debug__:
self._trace("requesting channel difference: %s", gd) self._trace("requesting channel difference: %s", gd)
@ -577,7 +582,7 @@ class MessageBox:
channel_id: int, channel_id: int,
diff: abcs.updates.ChannelDifference, diff: abcs.updates.ChannelDifference,
chat_hashes: ChatHashCache, chat_hashes: ChatHashCache,
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: ) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
entry: Entry = channel_id entry: Entry = channel_id
if __debug__: if __debug__:
self._trace("applying channel=%r difference: %s", entry, diff) self._trace("applying channel=%r difference: %s", entry, diff)

View File

@ -4,7 +4,7 @@ from .memory import MemorySession
from .storage import Storage from .storage import Storage
try: try:
from .sqlite import SqliteSession from .sqlite import SqliteSession # pyright: ignore [reportAssignmentType]
except ImportError as e: except ImportError as e:
import_err = e import_err = e

View File

@ -32,7 +32,7 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
class Reader: class Reader:
__slots__ = ("_view", "_pos", "_len") __slots__ = ("_view", "_pos", "_len")
def __init__(self, buffer: bytes | memoryview) -> None: def __init__(self, buffer: bytes | bytearray | memoryview) -> None:
self._view = ( self._view = (
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
) )

View File

@ -9,10 +9,10 @@ class HasSlots(Protocol):
__slots__: Tuple[str, ...] __slots__: Tuple[str, ...]
def obj_repr(obj: HasSlots) -> str: def obj_repr(self: HasSlots) -> str:
fields = ((attr, getattr(obj, attr)) for attr in obj.__slots__) fields = ((attr, getattr(self, attr)) for attr in self.__slots__)
params = ", ".join(f"{name}={field!r}" for name, field in fields) params = ", ".join(f"{name}={field!r}" for name, field in fields)
return f"{obj.__class__.__name__}({params})" return f"{self.__class__.__name__}({params})"
class Serializable(abc.ABC): class Serializable(abc.ABC):
@ -36,7 +36,7 @@ class Serializable(abc.ABC):
pass pass
@classmethod @classmethod
def from_bytes(cls, blob: bytes) -> Self: def from_bytes(cls, blob: bytes | bytearray | memoryview) -> Self:
return Reader(blob).read_serializable(cls) return Reader(blob).read_serializable(cls)
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
@ -54,7 +54,7 @@ class Serializable(abc.ABC):
) )
def serialize_bytes_to(buffer: bytearray, data: bytes) -> None: def serialize_bytes_to(buffer: bytearray, data: bytes | bytearray | memoryview) -> None:
length = len(data) length = len(data)
if length < 0xFE: if length < 0xFE:
buffer += struct.pack("<B", length) buffer += struct.pack("<B", length)

View File

@ -65,7 +65,7 @@ def test_key_from_nonce() -> None:
server_nonce = int.from_bytes(bytes(range(16))) server_nonce = int.from_bytes(bytes(range(16)))
new_nonce = int.from_bytes(bytes(range(32))) new_nonce = int.from_bytes(bytes(range(32)))
(key, iv) = generate_key_data_from_nonce(server_nonce, new_nonce) key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
assert ( assert (
key key
== b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6' == b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6'

View File

@ -7,7 +7,7 @@ from telethon._impl.mtproto import Abridged
class Output(bytearray): class Output(bytearray):
__slots__ = () __slots__ = ()
def __call__(self, data: bytes) -> None: def __call__(self, data: bytes | bytearray | memoryview) -> None:
self += data self += data

View File

@ -7,7 +7,7 @@ from telethon._impl.mtproto import Full
class Output(bytearray): class Output(bytearray):
__slots__ = () __slots__ = ()
def __call__(self, data: bytes) -> None: def __call__(self, data: bytes | bytearray | memoryview) -> None:
self += data self += data
@ -20,7 +20,7 @@ def setup_unpack(n: int) -> Tuple[bytes, Full, bytes, bytearray]:
transport, expected_output, input = setup_pack(n) transport, expected_output, input = setup_pack(n)
transport.pack(expected_output, input) transport.pack(expected_output, input)
return expected_output, Full(), input, bytearray() return expected_output, Full(), bytes(input), bytearray()
def test_pack_empty() -> None: def test_pack_empty() -> None:

View File

@ -7,7 +7,7 @@ from telethon._impl.mtproto import Intermediate
class Output(bytearray): class Output(bytearray):
__slots__ = () __slots__ = ()
def __call__(self, data: bytes) -> None: def __call__(self, data: bytes | bytearray | memoryview) -> None:
self += data self += data

View File

@ -46,14 +46,14 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
ignored_types = {"true", "boolTrue", "boolFalse"} # also "compiler built-ins" ignored_types = {"true", "boolTrue", "boolFalse"} # also "compiler built-ins"
abc_namespaces = set() abc_namespaces: Set[str] = set()
type_namespaces = set() type_namespaces: Set[str] = set()
function_namespaces = set() function_namespaces: Set[str] = set()
abc_class_names = set() abc_class_names: Set[str] = set()
type_class_names = set() type_class_names: Set[str] = set()
function_def_names = set() function_def_names: Set[str] = set()
generated_type_names = set() generated_type_names: Set[str] = set()
for typedef in tl.typedefs: for typedef in tl.typedefs:
if typedef.ty.full_name not in generated_types: if typedef.ty.full_name not in generated_types:
@ -67,6 +67,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
abc_path = Path("abcs/_nons.py") abc_path = Path("abcs/_nons.py")
if abc_path not in fs: if abc_path not in fs:
fs.write(abc_path, "# pyright: reportUnusedImport=false\n")
fs.write(abc_path, "from abc import ABCMeta\n") fs.write(abc_path, "from abc import ABCMeta\n")
fs.write(abc_path, "from ..core import Serializable\n") fs.write(abc_path, "from ..core import Serializable\n")
@ -93,11 +94,14 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
writer = fs.open(type_path) writer = fs.open(type_path)
if type_path not in fs: if type_path not in fs:
writer.write(
"# pyright: reportUnusedImport=false, reportConstantRedefinition=false"
)
writer.write("import struct") writer.write("import struct")
writer.write("from typing import List, Optional, Self, Sequence") writer.write("from typing import Optional, Self, Sequence")
writer.write("from .. import abcs") writer.write("from .. import abcs")
writer.write("from ..core import Reader, Serializable, serialize_bytes_to") writer.write("from ..core import Reader, Serializable, serialize_bytes_to")
writer.write("_bytes = bytes") writer.write("_bytes = bytes | bytearray | memoryview")
ns = f"{typedef.namespace[0]}." if typedef.namespace else "" ns = f"{typedef.namespace[0]}." if typedef.namespace else ""
generated_type_names.add(f"{ns}{to_class_name(typedef.name)}") generated_type_names.add(f"{ns}{to_class_name(typedef.name)}")
@ -113,14 +117,13 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
# def constructor_id() # def constructor_id()
writer.write(" @classmethod") writer.write(" @classmethod")
writer.write(" def constructor_id(_) -> int:") writer.write(" def constructor_id(cls) -> int:")
writer.write(f" return {hex(typedef.id)}") writer.write(f" return {hex(typedef.id)}")
# def __init__() # def __init__()
if property_params: if property_params:
params = "".join( params = "".join(
f", {p.name}: {param_type_fmt(p.ty, immutable=False)}" f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
for p in property_params
) )
writer.write(f" def __init__(_s, *{params}) -> None:") writer.write(f" def __init__(_s, *{params}) -> None:")
for p in property_params: for p in property_params:
@ -159,22 +162,18 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
writer = fs.open(function_path) writer = fs.open(function_path)
if function_path not in fs: if function_path not in fs:
writer.write("# pyright: reportUnusedImport=false")
writer.write("import struct") writer.write("import struct")
writer.write("from typing import List, Optional, Self, Sequence") writer.write("from typing import Optional, Self, Sequence")
writer.write("from .. import abcs") writer.write("from .. import abcs")
writer.write("from ..core import Request, serialize_bytes_to") writer.write("from ..core import Request, serialize_bytes_to")
writer.write("_bytes = bytes") writer.write("_bytes = bytes | bytearray | memoryview")
# def name(params, ...) # def name(params, ...)
required_params = [p for p in functiondef.params if not is_computed(p.ty)] required_params = [p for p in functiondef.params if not is_computed(p.ty)]
params = "".join( params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params)
f", {p.name}: {param_type_fmt(p.ty, immutable=True)}"
for p in required_params
)
star = "*" if params else "" star = "*" if params else ""
return_ty = param_type_fmt( return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None))
NormalParameter(ty=functiondef.ty, flag=None), immutable=False
)
writer.write( writer.write(
f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:" f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:"
) )
@ -189,6 +188,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
) )
writer = fs.open(Path("layer.py")) writer = fs.open(Path("layer.py"))
writer.write("# pyright: reportUnusedImport=false")
writer.write("from . import abcs, types") writer.write("from . import abcs, types")
writer.write( writer.write(
"from .core import Serializable, Reader, deserialize_bool, deserialize_i32_list, deserialize_i64_list, deserialize_identity, single_deserializer, list_deserializer" "from .core import Serializable, Reader, deserialize_bool, deserialize_i32_list, deserialize_i64_list, deserialize_identity, single_deserializer, list_deserializer"

View File

@ -86,7 +86,7 @@ def inner_type_fmt(ty: Type) -> str:
return f"abcs.{ns}{to_class_name(ty.name)}" return f"abcs.{ns}{to_class_name(ty.name)}"
def param_type_fmt(ty: BaseParameter, *, immutable: bool) -> str: def param_type_fmt(ty: BaseParameter) -> str:
if isinstance(ty, FlagsParameter): if isinstance(ty, FlagsParameter):
return "int" return "int"
elif not isinstance(ty, NormalParameter): elif not isinstance(ty, NormalParameter):
@ -104,10 +104,7 @@ def param_type_fmt(ty: BaseParameter, *, immutable: bool) -> str:
res = "_bytes" if inner_ty.name == "Object" else inner_type_fmt(inner_ty) res = "_bytes" if inner_ty.name == "Object" else inner_type_fmt(inner_ty)
if ty.ty.generic_arg: if ty.ty.generic_arg:
if immutable:
res = f"Sequence[{res}]" res = f"Sequence[{res}]"
else:
res = f"List[{res}]"
if ty.flag and ty.ty.name != "true": if ty.flag and ty.ty.name != "true":
res = f"Optional[{res}]" res = f"Optional[{res}]"

View File

@ -15,7 +15,8 @@ class ParsedTl:
def load_tl_file(path: Union[str, Path]) -> ParsedTl: def load_tl_file(path: Union[str, Path]) -> ParsedTl:
typedefs, functiondefs = [], [] typedefs: List[TypeDef] = []
functiondefs: List[FunctionDef] = []
with open(path, "r", encoding="utf-8") as fd: with open(path, "r", encoding="utf-8") as fd:
contents = fd.read() contents = fd.read()
@ -31,10 +32,8 @@ def load_tl_file(path: Union[str, Path]) -> ParsedTl:
raise raise
elif isinstance(definition, TypeDef): elif isinstance(definition, TypeDef):
typedefs.append(definition) typedefs.append(definition)
elif isinstance(definition, FunctionDef):
functiondefs.append(definition)
else: else:
raise TypeError(f"unexpected type: {type(definition)}") functiondefs.append(definition)
return ParsedTl( return ParsedTl(
layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs) layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)

View File

@ -59,8 +59,8 @@ class Definition:
raise ValueError("invalid id") raise ValueError("invalid id")
type_defs: List[str] = [] type_defs: List[str] = []
flag_defs = [] flag_defs: List[str] = []
params = [] params: List[Parameter] = []
for param_str in middle.split(): for param_str in middle.split():
try: try:

View File

@ -79,7 +79,7 @@ def test_recursive_vec() -> None:
""" """
) )
result = gen_py_code(typedefs=definitions) result = gen_py_code(typedefs=definitions)
assert "value: List[abcs.JsonObjectValue]" in result assert "value: Sequence[abcs.JsonObjectValue]" in result
def test_object_blob_special_case() -> None: def test_object_blob_special_case() -> None:

View File

View File

@ -10,6 +10,7 @@ Imports of new definitions and formatting must be added with other tools.
Properties and private methods can use a different parameter name than `self` Properties and private methods can use a different parameter name than `self`
to avoid being included. to avoid being included.
""" """
import ast import ast
import subprocess import subprocess
import sys import sys
@ -31,6 +32,8 @@ class FunctionMethodsVisitor(ast.NodeVisitor):
match node.args.args: match node.args.args:
case [ast.arg(arg="self", annotation=ast.Name(id="Client")), *_]: case [ast.arg(arg="self", annotation=ast.Name(id="Client")), *_]:
self.methods.append(node) self.methods.append(node)
case _:
pass
class MethodVisitor(ast.NodeVisitor): class MethodVisitor(ast.NodeVisitor):
@ -59,6 +62,8 @@ class MethodVisitor(ast.NodeVisitor):
match node.body: match node.body:
case [ast.Expr(value=ast.Constant(value=str(doc))), *_]: case [ast.Expr(value=ast.Constant(value=str(doc))), *_]:
self.method_docs[node.name] = doc self.method_docs[node.name] = doc
case _:
pass
def main() -> None: def main() -> None:
@ -81,10 +86,10 @@ def main() -> None:
m_visitor.visit(ast.parse(contents)) m_visitor.visit(ast.parse(contents))
class_body = [] class_body: List[ast.stmt] = []
for function in sorted(fm_visitor.methods, key=lambda f: f.name): for function in sorted(fm_visitor.methods, key=lambda f: f.name):
function.body = [] function_body: List[ast.stmt] = []
if doc := m_visitor.method_docs.get(function.name): if doc := m_visitor.method_docs.get(function.name):
function.body.append(ast.Expr(value=ast.Constant(value=doc))) function.body.append(ast.Expr(value=ast.Constant(value=doc)))
@ -108,7 +113,7 @@ def main() -> None:
case _: case _:
call = ast.Return(value=call) call = ast.Return(value=call)
function.body.append(call) function.body = function_body
class_body.append(function) class_body.append(function)
generated = ast.unparse( generated = ast.unparse(

View File

@ -2,6 +2,7 @@
Scan the `client/` directory for __init__.py files. Scan the `client/` directory for __init__.py files.
For every depth-1 import, add it, in order, to the __all__ variable. For every depth-1 import, add it, in order, to the __all__ variable.
""" """
import ast import ast
import os import os
import re import re
@ -25,7 +26,7 @@ def main() -> None:
rf"(tl|mtproto){re.escape(os.path.sep)}(abcs|functions|types)" rf"(tl|mtproto){re.escape(os.path.sep)}(abcs|functions|types)"
) )
files = [] files: List[str] = []
for file in impl_root.rglob("__init__.py"): for file in impl_root.rglob("__init__.py"):
file_str = str(file) file_str = str(file)
if autogenerated_re.search(file_str): if autogenerated_re.search(file_str):
@ -52,6 +53,8 @@ def main() -> None:
+ "]\n" + "]\n"
] ]
break break
case _:
pass
with file.open("w", encoding="utf-8", newline="\n") as fd: with file.open("w", encoding="utf-8", newline="\n") as fd:
fd.writelines(lines) fd.writelines(lines)

19
typings/setuptools.pyi Normal file
View File

@ -0,0 +1,19 @@
from typing import Any, Dict, Optional
class build_meta:
@staticmethod
def build_wheel(
wheel_directory: str,
config_settings: Optional[Dict[Any, Any]] = None,
metadata_directory: Optional[str] = None,
) -> str: ...
@staticmethod
def build_sdist(
sdist_directory: str, config_settings: Optional[Dict[Any, Any]] = None
) -> str: ...
@staticmethod
def build_editable(
wheel_directory: str,
config_settings: Optional[Dict[Any, Any]] = None,
metadata_directory: Optional[str] = None,
) -> str: ...