Make pyright happy

This commit is contained in:
Lonami Exo 2024-03-16 19:05:58 +01:00
parent 854096e9d3
commit 033b56f1d3
55 changed files with 435 additions and 309 deletions

View File

@ -25,6 +25,7 @@ def serialize_builtin(value: Any) -> bytes:
def overhead(obj: Obj) -> None:
x: Any
for v in obj.__dict__.values():
for x in v if isinstance(v, list) else [v]:
if isinstance(x, Obj):
@ -34,6 +35,7 @@ def overhead(obj: Obj) -> None:
def strategy_concat(obj: Obj) -> bytes:
x: Any
res = b""
for v in obj.__dict__.values():
for x in v if isinstance(v, list) else [v]:
@ -45,6 +47,7 @@ def strategy_concat(obj: Obj) -> bytes:
def strategy_append(obj: Obj) -> bytes:
x: Any
res = bytearray()
for v in obj.__dict__.values():
for x in v if isinstance(v, list) else [v]:
@ -57,6 +60,7 @@ def strategy_append(obj: Obj) -> bytes:
def strategy_append_reuse(obj: Obj) -> bytes:
def do_append(o: Obj, res: bytearray) -> None:
x: Any
for v in o.__dict__.values():
for x in v if isinstance(v, list) else [v]:
if isinstance(x, Obj):
@ -70,15 +74,18 @@ def strategy_append_reuse(obj: Obj) -> bytes:
def strategy_join(obj: Obj) -> bytes:
return b"".join(
strategy_join(x) if isinstance(x, Obj) else serialize_builtin(x)
for v in obj.__dict__.values()
for x in (v if isinstance(v, list) else [v])
)
def iterator() -> Iterator[bytes]:
x: Any
for v in obj.__dict__.values():
for x in v if isinstance(v, list) else [v]:
yield strategy_join(x) if isinstance(x, Obj) else serialize_builtin(x)
return b"".join(iterator())
def strategy_join_flat(obj: Obj) -> bytes:
def flatten(o: Obj) -> Iterator[bytes]:
x: Any
for v in o.__dict__.values():
for x in v if isinstance(v, list) else [v]:
if isinstance(x, Obj):
@ -91,6 +98,7 @@ def strategy_join_flat(obj: Obj) -> bytes:
def strategy_write(obj: Obj) -> bytes:
def do_write(o: Obj, buffer: io.BytesIO) -> None:
x: Any
for v in o.__dict__.values():
for x in v if isinstance(v, list) else [v]:
if isinstance(x, Obj):

View File

@ -6,7 +6,7 @@ DATA = 42
def overhead(n: int) -> None:
n
n = n
def strategy_bool(n: int) -> bool:

View File

@ -3,7 +3,7 @@ from pathlib import Path
from typing import Any, Dict, Optional
from setuptools import build_meta as _orig
from setuptools.build_meta import * # noqa: F403
from setuptools.build_meta import * # noqa: F403 # pyright: ignore [reportWildcardImportFromLibrary]
def gen_types_if_needed() -> None:

View File

@ -223,6 +223,7 @@ async def check_password(
if not two_factor_auth.check_p_and_g(algo.p, algo.g):
token = await get_password_information(self)
algo = token._password.current_algo
if not isinstance(
algo,
types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow,

View File

@ -2059,5 +2059,4 @@ class Client:
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
exc_type, exc, tb
await disconnect(self)

View File

@ -444,6 +444,7 @@ class FileBytesList(AsyncList[bytes]):
if result.bytes:
self._offset += MAX_CHUNK_SIZE
assert isinstance(result.bytes, bytes)
self._buffer.append(result.bytes)
self._done = len(result.bytes) < MAX_CHUNK_SIZE

View File

@ -39,11 +39,13 @@ async def send_message(
noforwards=not text.can_forward,
update_stickersets_order=False,
peer=peer,
reply_to=types.InputReplyToMessage(
reply_to=(
types.InputReplyToMessage(
reply_to_msg_id=text.replied_message_id, top_msg_id=None
)
if text.replied_message_id
else None,
else None
),
message=message,
random_id=random_id,
reply_markup=getattr(text._raw, "reply_markup", None),
@ -63,11 +65,11 @@ async def send_message(
noforwards=False,
update_stickersets_order=False,
peer=peer,
reply_to=types.InputReplyToMessage(
reply_to_msg_id=reply_to, top_msg_id=None
)
reply_to=(
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
if reply_to
else None,
else None
),
message=message,
random_id=random_id,
reply_markup=btns.build_keyboard(buttons),
@ -83,11 +85,14 @@ async def send_message(
{},
out=result.out,
id=result.id,
from_id=types.PeerUser(user_id=self._session.user.id)
from_id=(
types.PeerUser(user_id=self._session.user.id)
if self._session.user
else None,
else None
),
peer_id=packed._to_peer(),
reply_to=types.MessageReplyHeader(
reply_to=(
types.MessageReplyHeader(
reply_to_scheduled=False,
forum_topic=False,
reply_to_msg_id=reply_to,
@ -95,7 +100,8 @@ async def send_message(
reply_to_top_id=None,
)
if reply_to
else None,
else None
),
date=result.date,
message=message,
media=result.media,
@ -593,8 +599,8 @@ def build_message_map(
else:
return MessageMap(client, peer, {}, {})
random_id_to_id = {}
id_to_message = {}
random_id_to_id: Dict[int, int] = {}
id_to_message: Dict[int, Message] = {}
for update in updates:
if isinstance(update, types.UpdateMessageId):
random_id_to_id[update.random_id] = update.id

View File

@ -8,6 +8,7 @@ from typing import (
Callable,
List,
Optional,
Sequence,
Type,
TypeVar,
)
@ -103,8 +104,8 @@ def process_socket_updates(client: Client, all_updates: List[abcs.Updates]) -> N
def extend_update_queue(
client: Client,
updates: List[abcs.Update],
users: List[abcs.User],
chats: List[abcs.Chat],
users: Sequence[abcs.User],
chats: Sequence[abcs.Chat],
) -> None:
chat_map = build_chat_map(client, users, chats)

View File

@ -91,21 +91,22 @@ async def get_chats(self: Client, chats: Sequence[ChatLike]) -> List[Chat]:
input_channels.append(packed._to_input_channel())
if input_users:
users = await self(functions.users.get_users(id=input_users))
ret_users = await self(functions.users.get_users(id=input_users))
users = list(ret_users)
else:
users = []
if input_chats:
ret_chats = await self(functions.messages.get_chats(id=input_chats))
assert isinstance(ret_chats, types.messages.Chats)
groups = ret_chats.chats
groups = list(ret_chats.chats)
else:
groups = []
if input_channels:
ret_chats = await self(functions.channels.get_channels(id=input_channels))
assert isinstance(ret_chats, types.messages.Chats)
channels = ret_chats.chats
channels = list(ret_chats.chats)
else:
channels = []
@ -133,7 +134,7 @@ async def resolve_to_packed(
ty = PackedType.USER
elif isinstance(chat, Group):
ty = PackedType.MEGAGROUP if chat.is_megagroup else PackedType.CHAT
elif isinstance(chat, Channel):
else:
ty = PackedType.BROADCAST
return PackedChat(ty=ty, id=chat.id, access_hash=0)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Optional, Self, Union
from typing import TYPE_CHECKING, Dict, Optional, Self, Sequence, Union
from ...tl import abcs, types
from ..types import Chat, Message, expand_peer, peer_id
@ -69,7 +69,7 @@ class MessageDeleted(Event):
The chat is only known when the deletion occurs in broadcast channels or supergroups.
"""
def __init__(self, msg_ids: List[int], channel_id: Optional[int]) -> None:
def __init__(self, msg_ids: Sequence[int], channel_id: Optional[int]) -> None:
self._msg_ids = msg_ids
self._channel_id = channel_id
@ -85,7 +85,7 @@ class MessageDeleted(Event):
return None
@property
def message_ids(self) -> List[int]:
def message_ids(self) -> Sequence[int]:
"""
The message identifiers of the messages that were deleted.
"""

View File

@ -40,7 +40,7 @@ class ButtonCallback(Event):
@property
def data(self) -> bytes:
assert self._raw.data is not None
assert isinstance(self._raw.data, bytes)
return self._raw.data
async def answer(

View File

@ -1,7 +1,19 @@
from collections import deque
from html import escape
from html.parser import HTMLParser
from typing import Any, Deque, Dict, Iterable, List, Optional, Tuple, Type, cast
from typing import (
Any,
Callable,
Deque,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from ...tl.abcs import MessageEntity
from ...tl.types import (
@ -30,13 +42,11 @@ class HTMLToTelegramParser(HTMLParser):
self._open_tags: Deque[str] = deque()
self._open_tags_meta: Deque[Optional[str]] = deque()
def handle_starttag(
self, tag: str, attrs_seq: List[Tuple[str, Optional[str]]]
) -> None:
def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
self._open_tags.appendleft(tag)
self._open_tags_meta.appendleft(None)
attrs = dict(attrs_seq)
attributes = dict(attrs)
EntityType: Optional[Type[MessageEntity]] = None
args = {}
if tag == "strong" or tag == "b":
@ -61,7 +71,7 @@ class HTMLToTelegramParser(HTMLParser):
# inside <pre> tags
pre = self._building_entities["pre"]
assert isinstance(pre, MessageEntityPre)
if cls := attrs.get("class"):
if cls := attributes.get("class"):
pre.language = cls[len("language-") :]
except KeyError:
EntityType = MessageEntityCode
@ -69,7 +79,7 @@ class HTMLToTelegramParser(HTMLParser):
EntityType = MessageEntityPre
args["language"] = ""
elif tag == "a":
url = attrs.get("href")
url = attributes.get("href")
if not url:
return
if url.startswith("mailto:"):
@ -94,18 +104,18 @@ class HTMLToTelegramParser(HTMLParser):
**args,
)
def handle_data(self, text: str) -> None:
def handle_data(self, data: str) -> None:
previous_tag = self._open_tags[0] if len(self._open_tags) > 0 else ""
if previous_tag == "a":
url = self._open_tags_meta[0]
if url:
text = url
data = url
for entity in self._building_entities.values():
assert hasattr(entity, "length")
entity.length += len(text)
setattr(entity, "length", getattr(entity, "length", 0) + len(data))
self.text += text
self.text += data
def handle_endtag(self, tag: str) -> None:
try:
@ -114,7 +124,7 @@ class HTMLToTelegramParser(HTMLParser):
except IndexError:
pass
entity = self._building_entities.pop(tag, None)
if entity and hasattr(entity, "length") and entity.length:
if entity and getattr(entity, "length", None):
self.entities.append(entity)
@ -135,7 +145,9 @@ def parse(html: str) -> Tuple[str, List[MessageEntity]]:
return del_surrogate(text), parser.entities
ENTITY_TO_FORMATTER = {
ENTITY_TO_FORMATTER: Dict[
Type[MessageEntity], Union[Tuple[str, str], Callable[[Any, str], Tuple[str, str]]]
] = {
MessageEntityBold: ("<strong>", "</strong>"),
MessageEntityItalic: ("<em>", "</em>"),
MessageEntityCode: ("<code>", "</code>"),
@ -173,18 +185,20 @@ def unparse(text: str, entities: Iterable[MessageEntity]) -> str:
text = add_surrogate(text)
insert_at: List[Tuple[int, str]] = []
for entity in entities:
assert hasattr(entity, "offset") and hasattr(entity, "length")
s = entity.offset
e = entity.offset + entity.length
delimiter = ENTITY_TO_FORMATTER.get(type(entity), None)
for e in entities:
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
assert isinstance(offset, int) and isinstance(length, int)
h = offset
t = offset + length
delimiter = ENTITY_TO_FORMATTER.get(type(e), None)
if delimiter:
if callable(delimiter):
delim = delimiter(entity, text[s:e])
delim = delimiter(e, text[h:t])
else:
delim = delimiter
insert_at.append((s, delim[0]))
insert_at.append((e, delim[1]))
insert_at.append((h, delim[0]))
insert_at.append((t, delim[1]))
insert_at.sort(key=lambda t: t[0])
next_escape_bound = len(text)

View File

@ -1,5 +1,5 @@
import re
from typing import Any, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple, Type
import markdown_it
import markdown_it.token
@ -19,7 +19,7 @@ from ...tl.types import (
from .strings import add_surrogate, del_surrogate, within_surrogate
MARKDOWN = markdown_it.MarkdownIt().enable("strikethrough")
DELIMITERS = {
DELIMITERS: Dict[Type[MessageEntity], Tuple[str, str]] = {
MessageEntityBlockquote: ("> ", ""),
MessageEntityBold: ("**", "**"),
MessageEntityCode: ("`", "`"),
@ -81,7 +81,9 @@ def parse(message: str) -> Tuple[str, List[MessageEntity]]:
else:
for entity in reversed(entities):
if isinstance(entity, ty):
entity.length = len(message) - entity.offset
setattr(
entity, "length", len(message) - getattr(entity, "offset", 0)
)
break
parsed = MARKDOWN.parse(add_surrogate(message.strip()))
@ -156,24 +158,25 @@ def unparse(text: str, entities: List[MessageEntity]) -> str:
text = add_surrogate(text)
insert_at: List[Tuple[int, str]] = []
for entity in entities:
assert hasattr(entity, "offset")
assert hasattr(entity, "length")
s = entity.offset
e = entity.offset + entity.length
delimiter = DELIMITERS.get(type(entity), None)
for e in entities:
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
assert isinstance(offset, int) and isinstance(length, int)
h = offset
t = offset + length
delimiter = DELIMITERS.get(type(e), None)
if delimiter:
insert_at.append((s, delimiter[0]))
insert_at.append((e, delimiter[1]))
elif isinstance(entity, MessageEntityPre):
insert_at.append((s, f"```{entity.language}\n"))
insert_at.append((e, "```\n"))
elif isinstance(entity, MessageEntityTextUrl):
insert_at.append((s, "["))
insert_at.append((e, f"]({entity.url})"))
elif isinstance(entity, MessageEntityMentionName):
insert_at.append((s, "["))
insert_at.append((e, f"](tg://user?id={entity.user_id})"))
insert_at.append((h, delimiter[0]))
insert_at.append((t, delimiter[1]))
elif isinstance(e, MessageEntityPre):
insert_at.append((h, f"```{e.language}\n"))
insert_at.append((t, "```\n"))
elif isinstance(e, MessageEntityTextUrl):
insert_at.append((h, "["))
insert_at.append((t, f"]({e.url})"))
elif isinstance(e, MessageEntityMentionName):
insert_at.append((h, "["))
insert_at.append((t, f"](tg://user?id={e.user_id})"))
insert_at.sort(key=lambda t: t[0])
while insert_at:

View File

@ -8,9 +8,11 @@ def add_surrogate(text: str) -> str:
return "".join(
# SMP -> Surrogate Pairs (Telegram offsets are calculated with these).
# See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more.
(
"".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
if (0x10000 <= ord(x) <= 0x10FFFF)
else x
)
for x in text
)
@ -43,32 +45,38 @@ def strip_text(text: str, entities: List[MessageEntity]) -> str:
if not entities:
return text.strip()
assert all(isinstance(getattr(e, "offset"), int) for e in entities)
while text and text[-1].isspace():
e = entities[-1]
assert hasattr(e, "offset") and hasattr(e, "length")
if e.offset + e.length == len(text):
if e.length == 1:
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
assert isinstance(offset, int) and isinstance(length, int)
if offset + length == len(text):
if length == 1:
del entities[-1]
if not entities:
return text.strip()
else:
e.length -= 1
length -= 1
text = text[:-1]
while text and text[0].isspace():
for i in reversed(range(len(entities))):
e = entities[i]
assert hasattr(e, "offset") and hasattr(e, "length")
if e.offset != 0:
e.offset -= 1
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
assert isinstance(offset, int) and isinstance(length, int)
if offset != 0:
setattr(e, "offset", offset - 1)
continue
if e.length == 1:
if length == 1:
del entities[0]
if not entities:
return text.lstrip()
else:
e.length -= 1
setattr(e, "length", length - 1)
text = text[1:]

View File

@ -29,6 +29,7 @@ class Callback(InlineButton):
This data will be received by :class:`telethon.events.ButtonCallback` when the button is pressed.
"""
assert isinstance(self._raw, types.KeyboardButtonCallback)
assert isinstance(self._raw.data, bytes)
return self._raw.data
@data.setter

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import itertools
import sys
from collections import defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Sequence, Union
from ....session import PackedChat
from ....tl import abcs, types
@ -19,13 +19,15 @@ ChatLike = Union[Chat, PackedChat, int, str]
def build_chat_map(
client: Client, users: List[abcs.User], chats: List[abcs.Chat]
client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]
) -> Dict[int, Chat]:
users_iter = (User._from_raw(u) for u in users)
chats_iter = (
(
Channel._from_raw(c)
if isinstance(c, (types.Channel, types.ChannelForbidden)) and c.broadcast
else Group._from_raw(client, c)
)
for c in chats
)

View File

@ -5,7 +5,17 @@ import urllib.parse
from inspect import isawaitable
from io import BufferedWriter
from pathlib import Path
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Protocol, Self, Union
from typing import (
TYPE_CHECKING,
Any,
Coroutine,
List,
Optional,
Protocol,
Self,
Sequence,
Union,
)
from ...tl import abcs, types
from .meta import NoPublicConstructor
@ -43,7 +53,7 @@ stripped_size_header = bytes.fromhex(
stripped_size_footer = bytes.fromhex("FFD9")
def expand_stripped_size(data: bytes) -> bytes:
def expand_stripped_size(data: bytes | bytearray | memoryview) -> bytes:
header = bytearray(stripped_size_header)
header[164] = data[1]
header[166] = data[2]
@ -87,13 +97,14 @@ class InFileLike(Protocol):
It's only used in function parameters.
"""
def read(self, n: int) -> Union[bytes, Coroutine[Any, Any, bytes]]:
def read(self, n: int, /) -> Union[bytes, Coroutine[Any, Any, bytes]]:
"""
Read from the file or buffer.
:param n:
Maximum amount of bytes that should be returned.
"""
raise NotImplementedError
class OutFileLike(Protocol):
@ -115,18 +126,21 @@ class OutFileLike(Protocol):
class OutWrapper:
__slots__ = ("_fd", "_owned")
__slots__ = ("_fd", "_owned_fd")
_fd: Union[OutFileLike, BufferedWriter]
_owned_fd: Optional[BufferedWriter]
def __init__(self, file: Union[str, Path, OutFileLike]):
if isinstance(file, str):
file = Path(file)
if isinstance(file, Path):
self._fd: Union[OutFileLike, BufferedWriter] = file.open("wb")
self._owned = True
self._fd = file.open("wb")
self._owned_fd = self._fd
else:
self._fd = file
self._owned = False
self._owned_fd = None
async def write(self, chunk: bytes) -> None:
ret = self._fd.write(chunk)
@ -134,9 +148,9 @@ class OutWrapper:
await ret
def close(self) -> None:
if self._owned:
assert hasattr(self._fd, "close")
self._fd.close()
if self._owned_fd is not None:
assert hasattr(self._owned_fd, "close")
self._owned_fd.close()
class File(metaclass=NoPublicConstructor):
@ -150,7 +164,7 @@ class File(metaclass=NoPublicConstructor):
def __init__(
self,
*,
attributes: List[abcs.DocumentAttribute],
attributes: Sequence[abcs.DocumentAttribute],
size: int,
name: str,
mime: str,
@ -158,7 +172,7 @@ class File(metaclass=NoPublicConstructor):
muted: bool,
input_media: abcs.InputMedia,
thumb: Optional[abcs.PhotoSize],
thumbs: Optional[List[abcs.PhotoSize]],
thumbs: Optional[Sequence[abcs.PhotoSize]],
raw: Optional[Union[abcs.MessageMedia, abcs.Photo, abcs.Document]],
client: Optional[Client],
):
@ -405,9 +419,9 @@ class File(metaclass=NoPublicConstructor):
id=self._input_media.id.id,
access_hash=self._input_media.id.access_hash,
file_reference=self._input_media.id.file_reference,
thumb_size=self._thumb.type
if isinstance(self._thumb, thumb_types)
else "",
thumb_size=(
self._thumb.type if isinstance(self._thumb, thumb_types) else ""
),
)
elif isinstance(self._input_media, types.InputMediaPhoto):
assert isinstance(self._input_media.id, types.InputPhoto)

View File

@ -2,7 +2,17 @@ from __future__ import annotations
import datetime
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Self, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Self,
Sequence,
Tuple,
Union,
)
from ...tl import abcs, types
from ..parsers import (
@ -502,7 +512,7 @@ class Message(metaclass=NoPublicConstructor):
def build_msg_map(
client: Client, messages: List[abcs.Message], chat_map: Dict[int, Chat]
client: Client, messages: Sequence[abcs.Message], chat_map: Dict[int, Chat]
) -> Dict[int, Message]:
return {
msg.id: msg

View File

@ -1,8 +1,9 @@
"""
Class definitions stolen from `trio`, with some modifications.
"""
import abc
from typing import Type, TypeVar
from typing import Any, Type, TypeVar
T = TypeVar("T")
@ -28,7 +29,7 @@ class Final(abc.ABCMeta):
class NoPublicConstructor(Final):
def __call__(cls) -> None:
def __call__(cls, *args: Any, **kwds: Any) -> Any:
raise TypeError(
f"{cls.__module__}.{cls.__qualname__} has no public constructor"
)

View File

@ -100,12 +100,8 @@ class Participant(metaclass=NoPublicConstructor):
),
):
return self._raw.user_id
elif isinstance(
self._raw, (types.ChannelParticipantBanned, types.ChannelParticipantLeft)
):
return peer_id(self._raw.peer)
else:
raise RuntimeError("unexpected case")
return peer_id(self._raw.peer)
@property
def user(self) -> Optional[User]:

View File

@ -1,7 +1,28 @@
try:
import cryptg
def ige_encrypt(
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes: # noqa: F811
return cryptg.encrypt_ige(
bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv
)
def ige_decrypt(
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes: # noqa: F811
return cryptg.decrypt_ige(
bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext,
key,
iv,
)
except ImportError:
import pyaes
def ige_encrypt(plaintext: bytes, key: bytes, iv: bytes) -> bytes:
def ige_encrypt(
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes:
assert len(plaintext) % 16 == 0
assert len(iv) == 32
@ -26,8 +47,9 @@ def ige_encrypt(plaintext: bytes, key: bytes, iv: bytes) -> bytes:
return bytes(ciphertext)
def ige_decrypt(ciphertext: bytes, key: bytes, iv: bytes) -> bytes:
def ige_decrypt(
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
) -> bytes:
assert len(ciphertext) % 16 == 0
assert len(iv) == 32
@ -51,22 +73,3 @@ def ige_decrypt(ciphertext: bytes, key: bytes, iv: bytes) -> bytes:
plaintext += plaintext_block
return bytes(plaintext)
try:
import cryptg
def ige_encrypt(plaintext: bytes, key: bytes, iv: bytes) -> bytes: # noqa: F811
return cryptg.encrypt_ige(
bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv
)
def ige_decrypt(ciphertext: bytes, key: bytes, iv: bytes) -> bytes: # noqa: F811
return cryptg.decrypt_ige(
bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext,
key,
iv,
)
except ImportError:
pass

View File

@ -1,7 +1,7 @@
import os
from collections import namedtuple
from enum import IntEnum
from hashlib import sha1, sha256
from typing import NamedTuple
from .aes import ige_decrypt, ige_encrypt
from .auth_key import AuthKey
@ -13,15 +13,19 @@ class Side(IntEnum):
SERVER = 8
CalcKey = namedtuple("CalcKey", ("key", "iv"))
class CalcKey(NamedTuple):
key: bytes
iv: bytes
# https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
def calc_key(auth_key: AuthKey, msg_key: bytes, side: Side) -> CalcKey:
def calc_key(
auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side
) -> CalcKey:
x = int(side)
# sha256_a = SHA256 (msg_key + substr (auth_key, x, 36))
sha256_a = sha256(msg_key + auth_key.data[x : x + 36]).digest()
sha256_a = sha256(bytes(msg_key) + auth_key.data[x : x + 36]).digest()
# sha256_b = SHA256 (substr (auth_key, 40+x, 36) + msg_key)
sha256_b = sha256(auth_key.data[x + 40 : x + 76] + msg_key).digest()
@ -66,7 +70,9 @@ def encrypt_data_v2(plaintext: bytes, auth_key: AuthKey) -> bytes:
return _do_encrypt_data_v2(plaintext, auth_key, random_padding)
def decrypt_data_v2(ciphertext: bytes, auth_key: AuthKey) -> bytes:
def decrypt_data_v2(
ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey
) -> bytes:
side = Side.SERVER
x = int(side)

View File

@ -1,45 +1,58 @@
# Ported from https://github.com/Lonami/grammers/blob/d91dc82/lib/grammers-crypto/src/two_factor_auth.rs
from collections import namedtuple
from hashlib import pbkdf2_hmac, sha256
from typing import NamedTuple
from .factorize import factorize
TwoFactorAuth = namedtuple("TwoFactorAuth", ("m1", "g_a"))
class TwoFactorAuth(NamedTuple):
m1: bytes
g_a: bytes
def pad_to_256(data: bytes) -> bytes:
def pad_to_256(data: bytes | bytearray | memoryview) -> bytes:
return bytes(256 - len(data)) + data
# H(data) := sha256(data)
def h(*data: bytes) -> bytes:
def h(*data: bytes | bytearray | memoryview) -> bytes:
return sha256(b"".join(data)).digest()
# SH(data, salt) := H(salt | data | salt)
def sh(data: bytes, salt: bytes) -> bytes:
def sh(
data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview
) -> bytes:
return h(salt, data, salt)
# PH1(password, salt1, salt2) := SH(SH(password, salt1), salt2)
def ph1(password: bytes, salt1: bytes, salt2: bytes) -> bytes:
def ph1(
password: bytes | bytearray | memoryview,
salt1: bytes | bytearray | memoryview,
salt2: bytes | bytearray | memoryview,
) -> bytes:
return sh(sh(password, salt1), salt2)
# PH2(password, salt1, salt2) := SH(pbkdf2(sha512, PH1(password, salt1, salt2), salt1, 100000), salt2)
def ph2(password: bytes, salt1: bytes, salt2: bytes) -> bytes:
def ph2(
password: bytes | bytearray | memoryview,
salt1: bytes | bytearray | memoryview,
salt2: bytes | bytearray | memoryview,
) -> bytes:
return sh(pbkdf2_hmac("sha512", ph1(password, salt1, salt2), salt1, 100000), salt2)
# https://core.telegram.org/api/srp
def calculate_2fa(
*,
salt1: bytes,
salt2: bytes,
salt1: bytes | bytearray | memoryview,
salt2: bytes | bytearray | memoryview,
g: int,
p: bytes,
g_b: bytes,
a: bytes,
p: bytes | bytearray | memoryview,
g_b: bytes | bytearray | memoryview,
a: bytes | bytearray | memoryview,
password: bytes,
) -> TwoFactorAuth:
big_p = int.from_bytes(p)
@ -100,16 +113,16 @@ def calculate_2fa(
return TwoFactorAuth(m1, g_a)
def check_p_len(p: bytes) -> bool:
def check_p_len(p: bytes | bytearray | memoryview) -> bool:
return len(p) == 256
def check_known_prime(p: bytes, g: int) -> bool:
def check_known_prime(p: bytes | bytearray | memoryview, g: int) -> bool:
good_prime = b"\xc7\x1c\xae\xb9\xc6\xb1\xc9\x04\x8elR/p\xf1?s\x98\r@#\x8e>!\xc1I4\xd07V=\x93\x0fH\x19\x8a\n\xa7\xc1@X\"\x94\x93\xd2%0\xf4\xdb\xfa3on\n\xc9%\x13\x95C\xae\xd4L\xce|7 \xfdQ\xf6\x94XpZ\xc6\x8c\xd4\xfekk\x13\xab\xdc\x97FQ)i2\x84T\xf1\x8f\xaf\x8cY_d$w\xfe\x96\xbb*\x94\x1d[\xcd\x1dJ\xc8\xccI\x88\x07\x08\xfa\x9b7\x8e<O:\x90`\xbe\xe6|\xf9\xa4\xa4\xa6\x95\x81\x10Q\x90~\x16'S\xb5k\x0fkA\r\xbat\xd8\xa8K*\x14\xb3\x14N\x0e\xf1(GT\xfd\x17\xed\x95\rYe\xb4\xb9\xddFX-\xb1\x17\x8d\x16\x9ck\xc4e\xb0\xd6\xff\x9c\xa3\x92\x8f\xef[\x9a\xe4\xe4\x18\xfc\x15\xe8>\xbe\xa0\xf8\x7f\xa9\xff^\xedp\x05\r\xed(I\xf4{\xf9Y\xd9V\x85\x0c\xe9)\x85\x1f\r\x81\x15\xf65\xb1\x05\xee.N\x15\xd0K$T\xbfoO\xad\xf04\xb1\x04\x03\x11\x9c\xd8\xe3\xb9/\xcc["
return p == good_prime and g in (3, 4, 5, 7)
def check_p_prime_and_subgroup(p: bytes, g: int) -> bool:
def check_p_prime_and_subgroup(p: bytes | bytearray | memoryview, g: int) -> bool:
if check_known_prime(p, g):
return True
@ -133,7 +146,7 @@ def check_p_prime_and_subgroup(p: bytes, g: int) -> bool:
return candidate and factorize((big_p - 1) // 2)[0] == 1
def check_p_and_g(p: bytes, g: int) -> bool:
def check_p_and_g(p: bytes | bytearray | memoryview, g: int) -> bool:
if not check_p_len(p):
return False

View File

@ -164,6 +164,7 @@ def _do_step3(
)
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
assert isinstance(server_dh_params.encrypted_answer, bytes)
plain_text_answer = decrypt_ige(server_dh_params.encrypted_answer, key, iv)
got_answer_hash = plain_text_answer[:20]

View File

@ -260,7 +260,7 @@ class Encrypted(Mtp):
self._store_own_updates(result)
self._deserialization.append(RpcResult(msg_id, result))
def _store_own_updates(self, body: bytes) -> None:
def _store_own_updates(self, body: bytes | bytearray | memoryview) -> None:
constructor_id = struct.unpack_from("I", body)[0]
if constructor_id in UPDATE_IDS:
self._deserialization.append(Update(body))
@ -332,7 +332,7 @@ class Encrypted(Mtp):
)
self._start_salt_time = (salts.now, self._adjusted_now())
self._salts = salts.salts
self._salts = list(salts.salts)
self._salts.sort(key=lambda salt: -salt.valid_since)
def _handle_future_salt(self, message: Message) -> None:
@ -439,7 +439,9 @@ class Encrypted(Mtp):
msg_id, buffer = result
return msg_id, encrypt_data_v2(buffer, self._auth_key)
def deserialize(self, payload: bytes) -> List[Deserialization]:
def deserialize(
self, payload: bytes | bytearray | memoryview
) -> List[Deserialization]:
check_message_buffer(payload)
plaintext = decrypt_data_v2(payload, self._auth_key)

View File

@ -31,7 +31,9 @@ class Plain(Mtp):
self._buffer.clear()
return MsgId(0), result
def deserialize(self, payload: bytes) -> List[Deserialization]:
def deserialize(
self, payload: bytes | bytearray | memoryview
) -> List[Deserialization]:
check_message_buffer(payload)
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)

View File

@ -15,7 +15,7 @@ class Update:
__slots__ = ("body",)
def __init__(self, body: bytes):
def __init__(self, body: bytes | bytearray | memoryview):
self.body = body
@ -26,7 +26,7 @@ class RpcResult:
__slots__ = ("msg_id", "body")
def __init__(self, msg_id: MsgId, body: bytes):
def __init__(self, msg_id: MsgId, body: bytes | bytearray | memoryview):
self.msg_id = msg_id
self.body = body
@ -201,7 +201,9 @@ class Mtp(ABC):
"""
@abstractmethod
def deserialize(self, payload: bytes) -> List[Deserialization]:
def deserialize(
self, payload: bytes | bytearray | memoryview
) -> List[Deserialization]:
"""
Deserialize incoming buffer payload.
"""

View File

@ -12,7 +12,7 @@ class Transport(ABC):
pass
@abstractmethod
def unpack(self, input: bytes, output: bytearray) -> int:
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
pass

View File

@ -36,7 +36,7 @@ class Abridged(Transport):
write(struct.pack("<i", 0x7F | (length << 8)))
write(input)
def unpack(self, input: bytes, output: bytearray) -> int:
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if not input:
raise MissingBytes(expected=1, got=0)

View File

@ -35,7 +35,7 @@ class Full(Transport):
write(struct.pack("<I", crc32(tmp)))
self._send_seq += 1
def unpack(self, input: bytes, output: bytearray) -> int:
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if len(input) < 4:
raise MissingBytes(expected=4, got=len(input))

View File

@ -32,7 +32,7 @@ class Intermediate(Transport):
write(struct.pack("<i", len(input)))
write(input)
def unpack(self, input: bytes, output: bytearray) -> int:
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
if len(input) < 4:
raise MissingBytes(expected=4, got=len(input))

View File

@ -10,7 +10,7 @@ CONTAINER_MAX_LENGTH = 100
MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
def check_message_buffer(message: bytes) -> None:
def check_message_buffer(message: bytes | bytearray | memoryview) -> None:
if len(message) < 20:
raise ValueError(
f"server payload is too small to be a valid message: {message.hex()}"

View File

@ -70,6 +70,7 @@ class AsyncReader(Protocol):
:param n:
Amount of bytes to read at most.
"""
raise NotImplementedError
class AsyncWriter(Protocol):
@ -77,7 +78,7 @@ class AsyncWriter(Protocol):
A :class:`asyncio.StreamWriter`-like class.
"""
def write(self, data: bytes) -> None:
def write(self, data: bytes | bytearray | memoryview) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.write`.
@ -127,7 +128,7 @@ class Connector(Protocol):
"""
async def __call__(self, ip: str, port: int) -> Tuple[AsyncReader, AsyncWriter]:
pass
raise NotImplementedError
class RequestState(ABC):
@ -340,12 +341,12 @@ class Sender:
self._process_result(result)
elif isinstance(result, RpcError):
self._process_error(result)
elif isinstance(result, BadMessage):
self._process_bad_message(result)
else:
raise RuntimeError("unexpected case")
self._process_bad_message(result)
def _process_update(self, updates: List[Updates], update: bytes) -> None:
def _process_update(
self, updates: List[Updates], update: bytes | bytearray | memoryview
) -> None:
try:
updates.append(Updates.from_bytes(update))
except ValueError:

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
from ...tl import abcs, types
from .packed import PackedChat, PackedType
@ -131,7 +131,7 @@ class ChatHashCache:
else:
raise RuntimeError("unexpected case")
def extend(self, users: List[abcs.User], chats: List[abcs.Chat]) -> bool:
def extend(self, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]) -> bool:
# See https://core.telegram.org/api/min for "issues" with "min constructors".
success = True
@ -299,6 +299,7 @@ class ChatHashCache:
success &= self._has_notify_peer(peer)
for field in ("users",):
user: Any
users = getattr(message.action, field, None)
if isinstance(users, list):
for user in users:

View File

@ -4,11 +4,35 @@ from typing import List, Literal, Union
from ...tl import abcs
NO_DATE = 0 # used on adapted messages.affected* from lower layers
NO_SEQ = 0
NO_PTS = 0
# https://core.telegram.org/method/updates.getChannelDifference
BOT_CHANNEL_DIFF_LIMIT = 100000
USER_CHANNEL_DIFF_LIMIT = 100
POSSIBLE_GAP_TIMEOUT = 0.5
# https://core.telegram.org/api/updates
NO_UPDATES_TIMEOUT = 15 * 60
ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT"
ENTRY_SECRET: Literal["SECRET"] = "SECRET"
Entry = Union[Literal["ACCOUNT", "SECRET"], int]
# Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET.
# We don't define a name for this as libraries shouldn't do that though.
LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2
class PtsInfo:
__slots__ = ("pts", "pts_count", "entry")
def __init__(self, entry: "Entry", pts: int, pts_count: int) -> None:
entry: Entry # pyright needs this or it infers int | str
def __init__(self, entry: Entry, pts: int, pts_count: int) -> None:
self.pts = pts
self.pts_count = pts_count
self.entry = entry
@ -59,26 +83,3 @@ class PrematureEndReason(Enum):
class Gap(ValueError):
def __repr__(self) -> str:
return "Gap()"
NO_DATE = 0 # used on adapted messages.affected* from lower layers
NO_SEQ = 0
NO_PTS = 0
# https://core.telegram.org/method/updates.getChannelDifference
BOT_CHANNEL_DIFF_LIMIT = 100000
USER_CHANNEL_DIFF_LIMIT = 100
POSSIBLE_GAP_TIMEOUT = 0.5
# https://core.telegram.org/api/updates
NO_UPDATES_TIMEOUT = 15 * 60
ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT"
ENTRY_SECRET: Literal["SECRET"] = "SECRET"
Entry = Union[Literal["ACCOUNT", "SECRET"], int]
# Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET.
# We don't define a name for this as libraries shouldn't do that though.
LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2

View File

@ -2,7 +2,7 @@ import asyncio
import datetime
import logging
import time
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Sequence, Set, Tuple
from ...tl import Request, abcs, functions, types
from ..chat import ChatHashCache
@ -168,6 +168,7 @@ class MessageBox:
if not entries:
return
entry: Entry = ENTRY_ACCOUNT # for pyright to know it's not unbound
for entry in entries:
if entry not in self.map:
raise RuntimeError(
@ -258,7 +259,7 @@ class MessageBox:
self,
updates: abcs.Updates,
chat_hashes: ChatHashCache,
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
result: List[abcs.Update] = []
combined = adapt(updates, chat_hashes)
@ -289,7 +290,7 @@ class MessageBox:
sorted_updates = list(sorted(combined.updates, key=update_sort_key))
any_pts_applied = False
reset_deadlines_for = set()
reset_deadlines_for: Set[Entry] = set()
for update in sorted_updates:
entry, applied = self.apply_pts_info(update)
if entry is not None:
@ -420,9 +421,11 @@ class MessageBox:
pts_limit=None,
pts_total_limit=None,
date=int(self.date.timestamp()),
qts=self.map[ENTRY_SECRET].pts
qts=(
self.map[ENTRY_SECRET].pts
if ENTRY_SECRET in self.map
else NO_SEQ,
else NO_SEQ
),
qts_limit=None,
)
if __debug__:
@ -435,12 +438,12 @@ class MessageBox:
self,
diff: abcs.updates.Difference,
chat_hashes: ChatHashCache,
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
if __debug__:
self._trace("applying account difference: %s", diff)
finish: bool
result: Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]
result: Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]
if isinstance(diff, types.updates.DifferenceEmpty):
finish = True
self.date = datetime.datetime.fromtimestamp(
@ -493,7 +496,7 @@ class MessageBox:
self,
diff: types.updates.Difference,
chat_hashes: ChatHashCache,
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
state = diff.state
assert isinstance(state, types.updates.State)
self.map[ENTRY_ACCOUNT].pts = state.pts
@ -556,9 +559,11 @@ class MessageBox:
channel=channel,
filter=types.ChannelMessagesFilterEmpty(),
pts=state.pts,
limit=BOT_CHANNEL_DIFF_LIMIT
limit=(
BOT_CHANNEL_DIFF_LIMIT
if chat_hashes.is_self_bot
else USER_CHANNEL_DIFF_LIMIT,
else USER_CHANNEL_DIFF_LIMIT
),
)
if __debug__:
self._trace("requesting channel difference: %s", gd)
@ -577,7 +582,7 @@ class MessageBox:
channel_id: int,
diff: abcs.updates.ChannelDifference,
chat_hashes: ChatHashCache,
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
entry: Entry = channel_id
if __debug__:
self._trace("applying channel=%r difference: %s", entry, diff)

View File

@ -4,7 +4,7 @@ from .memory import MemorySession
from .storage import Storage
try:
from .sqlite import SqliteSession
from .sqlite import SqliteSession # pyright: ignore [reportAssignmentType]
except ImportError as e:
import_err = e

View File

@ -32,7 +32,7 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
class Reader:
__slots__ = ("_view", "_pos", "_len")
def __init__(self, buffer: bytes | memoryview) -> None:
def __init__(self, buffer: bytes | bytearray | memoryview) -> None:
self._view = (
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
)

View File

@ -9,10 +9,10 @@ class HasSlots(Protocol):
__slots__: Tuple[str, ...]
def obj_repr(obj: HasSlots) -> str:
fields = ((attr, getattr(obj, attr)) for attr in obj.__slots__)
def obj_repr(self: HasSlots) -> str:
fields = ((attr, getattr(self, attr)) for attr in self.__slots__)
params = ", ".join(f"{name}={field!r}" for name, field in fields)
return f"{obj.__class__.__name__}({params})"
return f"{self.__class__.__name__}({params})"
class Serializable(abc.ABC):
@ -36,7 +36,7 @@ class Serializable(abc.ABC):
pass
@classmethod
def from_bytes(cls, blob: bytes) -> Self:
def from_bytes(cls, blob: bytes | bytearray | memoryview) -> Self:
return Reader(blob).read_serializable(cls)
def __bytes__(self) -> bytes:
@ -54,7 +54,7 @@ class Serializable(abc.ABC):
)
def serialize_bytes_to(buffer: bytearray, data: bytes) -> None:
def serialize_bytes_to(buffer: bytearray, data: bytes | bytearray | memoryview) -> None:
length = len(data)
if length < 0xFE:
buffer += struct.pack("<B", length)

View File

@ -65,7 +65,7 @@ def test_key_from_nonce() -> None:
server_nonce = int.from_bytes(bytes(range(16)))
new_nonce = int.from_bytes(bytes(range(32)))
(key, iv) = generate_key_data_from_nonce(server_nonce, new_nonce)
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
assert (
key
== b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6'

View File

@ -7,7 +7,7 @@ from telethon._impl.mtproto import Abridged
class Output(bytearray):
__slots__ = ()
def __call__(self, data: bytes) -> None:
def __call__(self, data: bytes | bytearray | memoryview) -> None:
self += data

View File

@ -7,7 +7,7 @@ from telethon._impl.mtproto import Full
class Output(bytearray):
__slots__ = ()
def __call__(self, data: bytes) -> None:
def __call__(self, data: bytes | bytearray | memoryview) -> None:
self += data
@ -20,7 +20,7 @@ def setup_unpack(n: int) -> Tuple[bytes, Full, bytes, bytearray]:
transport, expected_output, input = setup_pack(n)
transport.pack(expected_output, input)
return expected_output, Full(), input, bytearray()
return expected_output, Full(), bytes(input), bytearray()
def test_pack_empty() -> None:

View File

@ -7,7 +7,7 @@ from telethon._impl.mtproto import Intermediate
class Output(bytearray):
__slots__ = ()
def __call__(self, data: bytes) -> None:
def __call__(self, data: bytes | bytearray | memoryview) -> None:
self += data

View File

@ -46,14 +46,14 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
ignored_types = {"true", "boolTrue", "boolFalse"} # also "compiler built-ins"
abc_namespaces = set()
type_namespaces = set()
function_namespaces = set()
abc_namespaces: Set[str] = set()
type_namespaces: Set[str] = set()
function_namespaces: Set[str] = set()
abc_class_names = set()
type_class_names = set()
function_def_names = set()
generated_type_names = set()
abc_class_names: Set[str] = set()
type_class_names: Set[str] = set()
function_def_names: Set[str] = set()
generated_type_names: Set[str] = set()
for typedef in tl.typedefs:
if typedef.ty.full_name not in generated_types:
@ -67,6 +67,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
abc_path = Path("abcs/_nons.py")
if abc_path not in fs:
fs.write(abc_path, "# pyright: reportUnusedImport=false\n")
fs.write(abc_path, "from abc import ABCMeta\n")
fs.write(abc_path, "from ..core import Serializable\n")
@ -93,11 +94,14 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
writer = fs.open(type_path)
if type_path not in fs:
writer.write(
"# pyright: reportUnusedImport=false, reportConstantRedefinition=false"
)
writer.write("import struct")
writer.write("from typing import List, Optional, Self, Sequence")
writer.write("from typing import Optional, Self, Sequence")
writer.write("from .. import abcs")
writer.write("from ..core import Reader, Serializable, serialize_bytes_to")
writer.write("_bytes = bytes")
writer.write("_bytes = bytes | bytearray | memoryview")
ns = f"{typedef.namespace[0]}." if typedef.namespace else ""
generated_type_names.add(f"{ns}{to_class_name(typedef.name)}")
@ -113,14 +117,13 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
# def constructor_id()
writer.write(" @classmethod")
writer.write(" def constructor_id(_) -> int:")
writer.write(" def constructor_id(cls) -> int:")
writer.write(f" return {hex(typedef.id)}")
# def __init__()
if property_params:
params = "".join(
f", {p.name}: {param_type_fmt(p.ty, immutable=False)}"
for p in property_params
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
)
writer.write(f" def __init__(_s, *{params}) -> None:")
for p in property_params:
@ -159,22 +162,18 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
writer = fs.open(function_path)
if function_path not in fs:
writer.write("# pyright: reportUnusedImport=false")
writer.write("import struct")
writer.write("from typing import List, Optional, Self, Sequence")
writer.write("from typing import Optional, Self, Sequence")
writer.write("from .. import abcs")
writer.write("from ..core import Request, serialize_bytes_to")
writer.write("_bytes = bytes")
writer.write("_bytes = bytes | bytearray | memoryview")
# def name(params, ...)
required_params = [p for p in functiondef.params if not is_computed(p.ty)]
params = "".join(
f", {p.name}: {param_type_fmt(p.ty, immutable=True)}"
for p in required_params
)
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params)
star = "*" if params else ""
return_ty = param_type_fmt(
NormalParameter(ty=functiondef.ty, flag=None), immutable=False
)
return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None))
writer.write(
f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:"
)
@ -189,6 +188,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
)
writer = fs.open(Path("layer.py"))
writer.write("# pyright: reportUnusedImport=false")
writer.write("from . import abcs, types")
writer.write(
"from .core import Serializable, Reader, deserialize_bool, deserialize_i32_list, deserialize_i64_list, deserialize_identity, single_deserializer, list_deserializer"

View File

@ -86,7 +86,7 @@ def inner_type_fmt(ty: Type) -> str:
return f"abcs.{ns}{to_class_name(ty.name)}"
def param_type_fmt(ty: BaseParameter, *, immutable: bool) -> str:
def param_type_fmt(ty: BaseParameter) -> str:
if isinstance(ty, FlagsParameter):
return "int"
elif not isinstance(ty, NormalParameter):
@ -104,10 +104,7 @@ def param_type_fmt(ty: BaseParameter, *, immutable: bool) -> str:
res = "_bytes" if inner_ty.name == "Object" else inner_type_fmt(inner_ty)
if ty.ty.generic_arg:
if immutable:
res = f"Sequence[{res}]"
else:
res = f"List[{res}]"
if ty.flag and ty.ty.name != "true":
res = f"Optional[{res}]"

View File

@ -15,7 +15,8 @@ class ParsedTl:
def load_tl_file(path: Union[str, Path]) -> ParsedTl:
typedefs, functiondefs = [], []
typedefs: List[TypeDef] = []
functiondefs: List[FunctionDef] = []
with open(path, "r", encoding="utf-8") as fd:
contents = fd.read()
@ -31,10 +32,8 @@ def load_tl_file(path: Union[str, Path]) -> ParsedTl:
raise
elif isinstance(definition, TypeDef):
typedefs.append(definition)
elif isinstance(definition, FunctionDef):
functiondefs.append(definition)
else:
raise TypeError(f"unexpected type: {type(definition)}")
functiondefs.append(definition)
return ParsedTl(
layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)

View File

@ -59,8 +59,8 @@ class Definition:
raise ValueError("invalid id")
type_defs: List[str] = []
flag_defs = []
params = []
flag_defs: List[str] = []
params: List[Parameter] = []
for param_str in middle.split():
try:

View File

@ -79,7 +79,7 @@ def test_recursive_vec() -> None:
"""
)
result = gen_py_code(typedefs=definitions)
assert "value: List[abcs.JsonObjectValue]" in result
assert "value: Sequence[abcs.JsonObjectValue]" in result
def test_object_blob_special_case() -> None:

View File

View File

@ -10,6 +10,7 @@ Imports of new definitions and formatting must be added with other tools.
Properties and private methods can use a different parameter name than `self`
to avoid being included.
"""
import ast
import subprocess
import sys
@ -31,6 +32,8 @@ class FunctionMethodsVisitor(ast.NodeVisitor):
match node.args.args:
case [ast.arg(arg="self", annotation=ast.Name(id="Client")), *_]:
self.methods.append(node)
case _:
pass
class MethodVisitor(ast.NodeVisitor):
@ -59,6 +62,8 @@ class MethodVisitor(ast.NodeVisitor):
match node.body:
case [ast.Expr(value=ast.Constant(value=str(doc))), *_]:
self.method_docs[node.name] = doc
case _:
pass
def main() -> None:
@ -81,10 +86,10 @@ def main() -> None:
m_visitor.visit(ast.parse(contents))
class_body = []
class_body: List[ast.stmt] = []
for function in sorted(fm_visitor.methods, key=lambda f: f.name):
function.body = []
function_body: List[ast.stmt] = []
if doc := m_visitor.method_docs.get(function.name):
function.body.append(ast.Expr(value=ast.Constant(value=doc)))
@ -108,7 +113,7 @@ def main() -> None:
case _:
call = ast.Return(value=call)
function.body.append(call)
function.body = function_body
class_body.append(function)
generated = ast.unparse(

View File

@ -2,6 +2,7 @@
Scan the `client/` directory for __init__.py files.
For every depth-1 import, add it, in order, to the __all__ variable.
"""
import ast
import os
import re
@ -25,7 +26,7 @@ def main() -> None:
rf"(tl|mtproto){re.escape(os.path.sep)}(abcs|functions|types)"
)
files = []
files: List[str] = []
for file in impl_root.rglob("__init__.py"):
file_str = str(file)
if autogenerated_re.search(file_str):
@ -52,6 +53,8 @@ def main() -> None:
+ "]\n"
]
break
case _:
pass
with file.open("w", encoding="utf-8", newline="\n") as fd:
fd.writelines(lines)

19
typings/setuptools.pyi Normal file
View File

@ -0,0 +1,19 @@
from typing import Any, Dict, Optional
class build_meta:
@staticmethod
def build_wheel(
wheel_directory: str,
config_settings: Optional[Dict[Any, Any]] = None,
metadata_directory: Optional[str] = None,
) -> str: ...
@staticmethod
def build_sdist(
sdist_directory: str, config_settings: Optional[Dict[Any, Any]] = None
) -> str: ...
@staticmethod
def build_editable(
wheel_directory: str,
config_settings: Optional[Dict[Any, Any]] = None,
metadata_directory: Optional[str] = None,
) -> str: ...