mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-24 18:33:44 +03:00
Make pyright happy
This commit is contained in:
parent
854096e9d3
commit
033b56f1d3
|
@ -25,6 +25,7 @@ def serialize_builtin(value: Any) -> bytes:
|
|||
|
||||
|
||||
def overhead(obj: Obj) -> None:
|
||||
x: Any
|
||||
for v in obj.__dict__.values():
|
||||
for x in v if isinstance(v, list) else [v]:
|
||||
if isinstance(x, Obj):
|
||||
|
@ -34,6 +35,7 @@ def overhead(obj: Obj) -> None:
|
|||
|
||||
|
||||
def strategy_concat(obj: Obj) -> bytes:
|
||||
x: Any
|
||||
res = b""
|
||||
for v in obj.__dict__.values():
|
||||
for x in v if isinstance(v, list) else [v]:
|
||||
|
@ -45,6 +47,7 @@ def strategy_concat(obj: Obj) -> bytes:
|
|||
|
||||
|
||||
def strategy_append(obj: Obj) -> bytes:
|
||||
x: Any
|
||||
res = bytearray()
|
||||
for v in obj.__dict__.values():
|
||||
for x in v if isinstance(v, list) else [v]:
|
||||
|
@ -57,6 +60,7 @@ def strategy_append(obj: Obj) -> bytes:
|
|||
|
||||
def strategy_append_reuse(obj: Obj) -> bytes:
|
||||
def do_append(o: Obj, res: bytearray) -> None:
|
||||
x: Any
|
||||
for v in o.__dict__.values():
|
||||
for x in v if isinstance(v, list) else [v]:
|
||||
if isinstance(x, Obj):
|
||||
|
@ -70,15 +74,18 @@ def strategy_append_reuse(obj: Obj) -> bytes:
|
|||
|
||||
|
||||
def strategy_join(obj: Obj) -> bytes:
|
||||
return b"".join(
|
||||
strategy_join(x) if isinstance(x, Obj) else serialize_builtin(x)
|
||||
for v in obj.__dict__.values()
|
||||
for x in (v if isinstance(v, list) else [v])
|
||||
)
|
||||
def iterator() -> Iterator[bytes]:
|
||||
x: Any
|
||||
for v in obj.__dict__.values():
|
||||
for x in v if isinstance(v, list) else [v]:
|
||||
yield strategy_join(x) if isinstance(x, Obj) else serialize_builtin(x)
|
||||
|
||||
return b"".join(iterator())
|
||||
|
||||
|
||||
def strategy_join_flat(obj: Obj) -> bytes:
|
||||
def flatten(o: Obj) -> Iterator[bytes]:
|
||||
x: Any
|
||||
for v in o.__dict__.values():
|
||||
for x in v if isinstance(v, list) else [v]:
|
||||
if isinstance(x, Obj):
|
||||
|
@ -91,6 +98,7 @@ def strategy_join_flat(obj: Obj) -> bytes:
|
|||
|
||||
def strategy_write(obj: Obj) -> bytes:
|
||||
def do_write(o: Obj, buffer: io.BytesIO) -> None:
|
||||
x: Any
|
||||
for v in o.__dict__.values():
|
||||
for x in v if isinstance(v, list) else [v]:
|
||||
if isinstance(x, Obj):
|
||||
|
|
|
@ -6,7 +6,7 @@ DATA = 42
|
|||
|
||||
|
||||
def overhead(n: int) -> None:
|
||||
n
|
||||
n = n
|
||||
|
||||
|
||||
def strategy_bool(n: int) -> bool:
|
||||
|
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
from setuptools import build_meta as _orig
|
||||
from setuptools.build_meta import * # noqa: F403
|
||||
from setuptools.build_meta import * # noqa: F403 # pyright: ignore [reportWildcardImportFromLibrary]
|
||||
|
||||
|
||||
def gen_types_if_needed() -> None:
|
||||
|
|
|
@ -223,6 +223,7 @@ async def check_password(
|
|||
|
||||
if not two_factor_auth.check_p_and_g(algo.p, algo.g):
|
||||
token = await get_password_information(self)
|
||||
algo = token._password.current_algo
|
||||
if not isinstance(
|
||||
algo,
|
||||
types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow,
|
||||
|
|
|
@ -2059,5 +2059,4 @@ class Client:
|
|||
exc: Optional[BaseException],
|
||||
tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
exc_type, exc, tb
|
||||
await disconnect(self)
|
||||
|
|
|
@ -444,6 +444,7 @@ class FileBytesList(AsyncList[bytes]):
|
|||
|
||||
if result.bytes:
|
||||
self._offset += MAX_CHUNK_SIZE
|
||||
assert isinstance(result.bytes, bytes)
|
||||
self._buffer.append(result.bytes)
|
||||
|
||||
self._done = len(result.bytes) < MAX_CHUNK_SIZE
|
||||
|
|
|
@ -39,11 +39,13 @@ async def send_message(
|
|||
noforwards=not text.can_forward,
|
||||
update_stickersets_order=False,
|
||||
peer=peer,
|
||||
reply_to=types.InputReplyToMessage(
|
||||
reply_to_msg_id=text.replied_message_id, top_msg_id=None
|
||||
)
|
||||
if text.replied_message_id
|
||||
else None,
|
||||
reply_to=(
|
||||
types.InputReplyToMessage(
|
||||
reply_to_msg_id=text.replied_message_id, top_msg_id=None
|
||||
)
|
||||
if text.replied_message_id
|
||||
else None
|
||||
),
|
||||
message=message,
|
||||
random_id=random_id,
|
||||
reply_markup=getattr(text._raw, "reply_markup", None),
|
||||
|
@ -63,11 +65,11 @@ async def send_message(
|
|||
noforwards=False,
|
||||
update_stickersets_order=False,
|
||||
peer=peer,
|
||||
reply_to=types.InputReplyToMessage(
|
||||
reply_to_msg_id=reply_to, top_msg_id=None
|
||||
)
|
||||
if reply_to
|
||||
else None,
|
||||
reply_to=(
|
||||
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
|
||||
if reply_to
|
||||
else None
|
||||
),
|
||||
message=message,
|
||||
random_id=random_id,
|
||||
reply_markup=btns.build_keyboard(buttons),
|
||||
|
@ -83,19 +85,23 @@ async def send_message(
|
|||
{},
|
||||
out=result.out,
|
||||
id=result.id,
|
||||
from_id=types.PeerUser(user_id=self._session.user.id)
|
||||
if self._session.user
|
||||
else None,
|
||||
from_id=(
|
||||
types.PeerUser(user_id=self._session.user.id)
|
||||
if self._session.user
|
||||
else None
|
||||
),
|
||||
peer_id=packed._to_peer(),
|
||||
reply_to=types.MessageReplyHeader(
|
||||
reply_to_scheduled=False,
|
||||
forum_topic=False,
|
||||
reply_to_msg_id=reply_to,
|
||||
reply_to_peer_id=None,
|
||||
reply_to_top_id=None,
|
||||
)
|
||||
if reply_to
|
||||
else None,
|
||||
reply_to=(
|
||||
types.MessageReplyHeader(
|
||||
reply_to_scheduled=False,
|
||||
forum_topic=False,
|
||||
reply_to_msg_id=reply_to,
|
||||
reply_to_peer_id=None,
|
||||
reply_to_top_id=None,
|
||||
)
|
||||
if reply_to
|
||||
else None
|
||||
),
|
||||
date=result.date,
|
||||
message=message,
|
||||
media=result.media,
|
||||
|
@ -593,8 +599,8 @@ def build_message_map(
|
|||
else:
|
||||
return MessageMap(client, peer, {}, {})
|
||||
|
||||
random_id_to_id = {}
|
||||
id_to_message = {}
|
||||
random_id_to_id: Dict[int, int] = {}
|
||||
id_to_message: Dict[int, Message] = {}
|
||||
for update in updates:
|
||||
if isinstance(update, types.UpdateMessageId):
|
||||
random_id_to_id[update.random_id] = update.id
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import (
|
|||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
@ -103,8 +104,8 @@ def process_socket_updates(client: Client, all_updates: List[abcs.Updates]) -> N
|
|||
def extend_update_queue(
|
||||
client: Client,
|
||||
updates: List[abcs.Update],
|
||||
users: List[abcs.User],
|
||||
chats: List[abcs.Chat],
|
||||
users: Sequence[abcs.User],
|
||||
chats: Sequence[abcs.Chat],
|
||||
) -> None:
|
||||
chat_map = build_chat_map(client, users, chats)
|
||||
|
||||
|
|
|
@ -91,21 +91,22 @@ async def get_chats(self: Client, chats: Sequence[ChatLike]) -> List[Chat]:
|
|||
input_channels.append(packed._to_input_channel())
|
||||
|
||||
if input_users:
|
||||
users = await self(functions.users.get_users(id=input_users))
|
||||
ret_users = await self(functions.users.get_users(id=input_users))
|
||||
users = list(ret_users)
|
||||
else:
|
||||
users = []
|
||||
|
||||
if input_chats:
|
||||
ret_chats = await self(functions.messages.get_chats(id=input_chats))
|
||||
assert isinstance(ret_chats, types.messages.Chats)
|
||||
groups = ret_chats.chats
|
||||
groups = list(ret_chats.chats)
|
||||
else:
|
||||
groups = []
|
||||
|
||||
if input_channels:
|
||||
ret_chats = await self(functions.channels.get_channels(id=input_channels))
|
||||
assert isinstance(ret_chats, types.messages.Chats)
|
||||
channels = ret_chats.chats
|
||||
channels = list(ret_chats.chats)
|
||||
else:
|
||||
channels = []
|
||||
|
||||
|
@ -133,7 +134,7 @@ async def resolve_to_packed(
|
|||
ty = PackedType.USER
|
||||
elif isinstance(chat, Group):
|
||||
ty = PackedType.MEGAGROUP if chat.is_megagroup else PackedType.CHAT
|
||||
elif isinstance(chat, Channel):
|
||||
else:
|
||||
ty = PackedType.BROADCAST
|
||||
|
||||
return PackedChat(ty=ty, id=chat.id, access_hash=0)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Self, Union
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Self, Sequence, Union
|
||||
|
||||
from ...tl import abcs, types
|
||||
from ..types import Chat, Message, expand_peer, peer_id
|
||||
|
@ -69,7 +69,7 @@ class MessageDeleted(Event):
|
|||
The chat is only known when the deletion occurs in broadcast channels or supergroups.
|
||||
"""
|
||||
|
||||
def __init__(self, msg_ids: List[int], channel_id: Optional[int]) -> None:
|
||||
def __init__(self, msg_ids: Sequence[int], channel_id: Optional[int]) -> None:
|
||||
self._msg_ids = msg_ids
|
||||
self._channel_id = channel_id
|
||||
|
||||
|
@ -85,7 +85,7 @@ class MessageDeleted(Event):
|
|||
return None
|
||||
|
||||
@property
|
||||
def message_ids(self) -> List[int]:
|
||||
def message_ids(self) -> Sequence[int]:
|
||||
"""
|
||||
The message identifiers of the messages that were deleted.
|
||||
"""
|
||||
|
|
|
@ -40,7 +40,7 @@ class ButtonCallback(Event):
|
|||
|
||||
@property
|
||||
def data(self) -> bytes:
|
||||
assert self._raw.data is not None
|
||||
assert isinstance(self._raw.data, bytes)
|
||||
return self._raw.data
|
||||
|
||||
async def answer(
|
||||
|
|
|
@ -1,7 +1,19 @@
|
|||
from collections import deque
|
||||
from html import escape
|
||||
from html.parser import HTMLParser
|
||||
from typing import Any, Deque, Dict, Iterable, List, Optional, Tuple, Type, cast
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Deque,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from ...tl.abcs import MessageEntity
|
||||
from ...tl.types import (
|
||||
|
@ -30,13 +42,11 @@ class HTMLToTelegramParser(HTMLParser):
|
|||
self._open_tags: Deque[str] = deque()
|
||||
self._open_tags_meta: Deque[Optional[str]] = deque()
|
||||
|
||||
def handle_starttag(
|
||||
self, tag: str, attrs_seq: List[Tuple[str, Optional[str]]]
|
||||
) -> None:
|
||||
def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
|
||||
self._open_tags.appendleft(tag)
|
||||
self._open_tags_meta.appendleft(None)
|
||||
|
||||
attrs = dict(attrs_seq)
|
||||
attributes = dict(attrs)
|
||||
EntityType: Optional[Type[MessageEntity]] = None
|
||||
args = {}
|
||||
if tag == "strong" or tag == "b":
|
||||
|
@ -61,7 +71,7 @@ class HTMLToTelegramParser(HTMLParser):
|
|||
# inside <pre> tags
|
||||
pre = self._building_entities["pre"]
|
||||
assert isinstance(pre, MessageEntityPre)
|
||||
if cls := attrs.get("class"):
|
||||
if cls := attributes.get("class"):
|
||||
pre.language = cls[len("language-") :]
|
||||
except KeyError:
|
||||
EntityType = MessageEntityCode
|
||||
|
@ -69,7 +79,7 @@ class HTMLToTelegramParser(HTMLParser):
|
|||
EntityType = MessageEntityPre
|
||||
args["language"] = ""
|
||||
elif tag == "a":
|
||||
url = attrs.get("href")
|
||||
url = attributes.get("href")
|
||||
if not url:
|
||||
return
|
||||
if url.startswith("mailto:"):
|
||||
|
@ -94,18 +104,18 @@ class HTMLToTelegramParser(HTMLParser):
|
|||
**args,
|
||||
)
|
||||
|
||||
def handle_data(self, text: str) -> None:
|
||||
def handle_data(self, data: str) -> None:
|
||||
previous_tag = self._open_tags[0] if len(self._open_tags) > 0 else ""
|
||||
if previous_tag == "a":
|
||||
url = self._open_tags_meta[0]
|
||||
if url:
|
||||
text = url
|
||||
data = url
|
||||
|
||||
for entity in self._building_entities.values():
|
||||
assert hasattr(entity, "length")
|
||||
entity.length += len(text)
|
||||
setattr(entity, "length", getattr(entity, "length", 0) + len(data))
|
||||
|
||||
self.text += text
|
||||
self.text += data
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
try:
|
||||
|
@ -114,7 +124,7 @@ class HTMLToTelegramParser(HTMLParser):
|
|||
except IndexError:
|
||||
pass
|
||||
entity = self._building_entities.pop(tag, None)
|
||||
if entity and hasattr(entity, "length") and entity.length:
|
||||
if entity and getattr(entity, "length", None):
|
||||
self.entities.append(entity)
|
||||
|
||||
|
||||
|
@ -135,7 +145,9 @@ def parse(html: str) -> Tuple[str, List[MessageEntity]]:
|
|||
return del_surrogate(text), parser.entities
|
||||
|
||||
|
||||
ENTITY_TO_FORMATTER = {
|
||||
ENTITY_TO_FORMATTER: Dict[
|
||||
Type[MessageEntity], Union[Tuple[str, str], Callable[[Any, str], Tuple[str, str]]]
|
||||
] = {
|
||||
MessageEntityBold: ("<strong>", "</strong>"),
|
||||
MessageEntityItalic: ("<em>", "</em>"),
|
||||
MessageEntityCode: ("<code>", "</code>"),
|
||||
|
@ -173,18 +185,20 @@ def unparse(text: str, entities: Iterable[MessageEntity]) -> str:
|
|||
|
||||
text = add_surrogate(text)
|
||||
insert_at: List[Tuple[int, str]] = []
|
||||
for entity in entities:
|
||||
assert hasattr(entity, "offset") and hasattr(entity, "length")
|
||||
s = entity.offset
|
||||
e = entity.offset + entity.length
|
||||
delimiter = ENTITY_TO_FORMATTER.get(type(entity), None)
|
||||
for e in entities:
|
||||
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
|
||||
assert isinstance(offset, int) and isinstance(length, int)
|
||||
|
||||
h = offset
|
||||
t = offset + length
|
||||
delimiter = ENTITY_TO_FORMATTER.get(type(e), None)
|
||||
if delimiter:
|
||||
if callable(delimiter):
|
||||
delim = delimiter(entity, text[s:e])
|
||||
delim = delimiter(e, text[h:t])
|
||||
else:
|
||||
delim = delimiter
|
||||
insert_at.append((s, delim[0]))
|
||||
insert_at.append((e, delim[1]))
|
||||
insert_at.append((h, delim[0]))
|
||||
insert_at.append((t, delim[1]))
|
||||
|
||||
insert_at.sort(key=lambda t: t[0])
|
||||
next_escape_bound = len(text)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import re
|
||||
from typing import Any, Iterator, List, Tuple
|
||||
from typing import Any, Dict, Iterator, List, Tuple, Type
|
||||
|
||||
import markdown_it
|
||||
import markdown_it.token
|
||||
|
@ -19,7 +19,7 @@ from ...tl.types import (
|
|||
from .strings import add_surrogate, del_surrogate, within_surrogate
|
||||
|
||||
MARKDOWN = markdown_it.MarkdownIt().enable("strikethrough")
|
||||
DELIMITERS = {
|
||||
DELIMITERS: Dict[Type[MessageEntity], Tuple[str, str]] = {
|
||||
MessageEntityBlockquote: ("> ", ""),
|
||||
MessageEntityBold: ("**", "**"),
|
||||
MessageEntityCode: ("`", "`"),
|
||||
|
@ -81,7 +81,9 @@ def parse(message: str) -> Tuple[str, List[MessageEntity]]:
|
|||
else:
|
||||
for entity in reversed(entities):
|
||||
if isinstance(entity, ty):
|
||||
entity.length = len(message) - entity.offset
|
||||
setattr(
|
||||
entity, "length", len(message) - getattr(entity, "offset", 0)
|
||||
)
|
||||
break
|
||||
|
||||
parsed = MARKDOWN.parse(add_surrogate(message.strip()))
|
||||
|
@ -156,24 +158,25 @@ def unparse(text: str, entities: List[MessageEntity]) -> str:
|
|||
|
||||
text = add_surrogate(text)
|
||||
insert_at: List[Tuple[int, str]] = []
|
||||
for entity in entities:
|
||||
assert hasattr(entity, "offset")
|
||||
assert hasattr(entity, "length")
|
||||
s = entity.offset
|
||||
e = entity.offset + entity.length
|
||||
delimiter = DELIMITERS.get(type(entity), None)
|
||||
for e in entities:
|
||||
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
|
||||
assert isinstance(offset, int) and isinstance(length, int)
|
||||
|
||||
h = offset
|
||||
t = offset + length
|
||||
delimiter = DELIMITERS.get(type(e), None)
|
||||
if delimiter:
|
||||
insert_at.append((s, delimiter[0]))
|
||||
insert_at.append((e, delimiter[1]))
|
||||
elif isinstance(entity, MessageEntityPre):
|
||||
insert_at.append((s, f"```{entity.language}\n"))
|
||||
insert_at.append((e, "```\n"))
|
||||
elif isinstance(entity, MessageEntityTextUrl):
|
||||
insert_at.append((s, "["))
|
||||
insert_at.append((e, f"]({entity.url})"))
|
||||
elif isinstance(entity, MessageEntityMentionName):
|
||||
insert_at.append((s, "["))
|
||||
insert_at.append((e, f"](tg://user?id={entity.user_id})"))
|
||||
insert_at.append((h, delimiter[0]))
|
||||
insert_at.append((t, delimiter[1]))
|
||||
elif isinstance(e, MessageEntityPre):
|
||||
insert_at.append((h, f"```{e.language}\n"))
|
||||
insert_at.append((t, "```\n"))
|
||||
elif isinstance(e, MessageEntityTextUrl):
|
||||
insert_at.append((h, "["))
|
||||
insert_at.append((t, f"]({e.url})"))
|
||||
elif isinstance(e, MessageEntityMentionName):
|
||||
insert_at.append((h, "["))
|
||||
insert_at.append((t, f"](tg://user?id={e.user_id})"))
|
||||
|
||||
insert_at.sort(key=lambda t: t[0])
|
||||
while insert_at:
|
||||
|
|
|
@ -8,9 +8,11 @@ def add_surrogate(text: str) -> str:
|
|||
return "".join(
|
||||
# SMP -> Surrogate Pairs (Telegram offsets are calculated with these).
|
||||
# See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more.
|
||||
"".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
|
||||
if (0x10000 <= ord(x) <= 0x10FFFF)
|
||||
else x
|
||||
(
|
||||
"".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
|
||||
if (0x10000 <= ord(x) <= 0x10FFFF)
|
||||
else x
|
||||
)
|
||||
for x in text
|
||||
)
|
||||
|
||||
|
@ -43,32 +45,38 @@ def strip_text(text: str, entities: List[MessageEntity]) -> str:
|
|||
if not entities:
|
||||
return text.strip()
|
||||
|
||||
assert all(isinstance(getattr(e, "offset"), int) for e in entities)
|
||||
|
||||
while text and text[-1].isspace():
|
||||
e = entities[-1]
|
||||
assert hasattr(e, "offset") and hasattr(e, "length")
|
||||
if e.offset + e.length == len(text):
|
||||
if e.length == 1:
|
||||
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
|
||||
assert isinstance(offset, int) and isinstance(length, int)
|
||||
|
||||
if offset + length == len(text):
|
||||
if length == 1:
|
||||
del entities[-1]
|
||||
if not entities:
|
||||
return text.strip()
|
||||
else:
|
||||
e.length -= 1
|
||||
length -= 1
|
||||
text = text[:-1]
|
||||
|
||||
while text and text[0].isspace():
|
||||
for i in reversed(range(len(entities))):
|
||||
e = entities[i]
|
||||
assert hasattr(e, "offset") and hasattr(e, "length")
|
||||
if e.offset != 0:
|
||||
e.offset -= 1
|
||||
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
|
||||
assert isinstance(offset, int) and isinstance(length, int)
|
||||
|
||||
if offset != 0:
|
||||
setattr(e, "offset", offset - 1)
|
||||
continue
|
||||
|
||||
if e.length == 1:
|
||||
if length == 1:
|
||||
del entities[0]
|
||||
if not entities:
|
||||
return text.lstrip()
|
||||
else:
|
||||
e.length -= 1
|
||||
setattr(e, "length", length - 1)
|
||||
|
||||
text = text[1:]
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ class Callback(InlineButton):
|
|||
This data will be received by :class:`telethon.events.ButtonCallback` when the button is pressed.
|
||||
"""
|
||||
assert isinstance(self._raw, types.KeyboardButtonCallback)
|
||||
assert isinstance(self._raw.data, bytes)
|
||||
return self._raw.data
|
||||
|
||||
@data.setter
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import itertools
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from ....session import PackedChat
|
||||
from ....tl import abcs, types
|
||||
|
@ -19,13 +19,15 @@ ChatLike = Union[Chat, PackedChat, int, str]
|
|||
|
||||
|
||||
def build_chat_map(
|
||||
client: Client, users: List[abcs.User], chats: List[abcs.Chat]
|
||||
client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]
|
||||
) -> Dict[int, Chat]:
|
||||
users_iter = (User._from_raw(u) for u in users)
|
||||
chats_iter = (
|
||||
Channel._from_raw(c)
|
||||
if isinstance(c, (types.Channel, types.ChannelForbidden)) and c.broadcast
|
||||
else Group._from_raw(client, c)
|
||||
(
|
||||
Channel._from_raw(c)
|
||||
if isinstance(c, (types.Channel, types.ChannelForbidden)) and c.broadcast
|
||||
else Group._from_raw(client, c)
|
||||
)
|
||||
for c in chats
|
||||
)
|
||||
|
||||
|
|
|
@ -5,7 +5,17 @@ import urllib.parse
|
|||
from inspect import isawaitable
|
||||
from io import BufferedWriter
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Protocol, Self, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Coroutine,
|
||||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Self,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
from ...tl import abcs, types
|
||||
from .meta import NoPublicConstructor
|
||||
|
@ -43,7 +53,7 @@ stripped_size_header = bytes.fromhex(
|
|||
stripped_size_footer = bytes.fromhex("FFD9")
|
||||
|
||||
|
||||
def expand_stripped_size(data: bytes) -> bytes:
|
||||
def expand_stripped_size(data: bytes | bytearray | memoryview) -> bytes:
|
||||
header = bytearray(stripped_size_header)
|
||||
header[164] = data[1]
|
||||
header[166] = data[2]
|
||||
|
@ -87,13 +97,14 @@ class InFileLike(Protocol):
|
|||
It's only used in function parameters.
|
||||
"""
|
||||
|
||||
def read(self, n: int) -> Union[bytes, Coroutine[Any, Any, bytes]]:
|
||||
def read(self, n: int, /) -> Union[bytes, Coroutine[Any, Any, bytes]]:
|
||||
"""
|
||||
Read from the file or buffer.
|
||||
|
||||
:param n:
|
||||
Maximum amount of bytes that should be returned.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class OutFileLike(Protocol):
|
||||
|
@ -115,18 +126,21 @@ class OutFileLike(Protocol):
|
|||
|
||||
|
||||
class OutWrapper:
|
||||
__slots__ = ("_fd", "_owned")
|
||||
__slots__ = ("_fd", "_owned_fd")
|
||||
|
||||
_fd: Union[OutFileLike, BufferedWriter]
|
||||
_owned_fd: Optional[BufferedWriter]
|
||||
|
||||
def __init__(self, file: Union[str, Path, OutFileLike]):
|
||||
if isinstance(file, str):
|
||||
file = Path(file)
|
||||
|
||||
if isinstance(file, Path):
|
||||
self._fd: Union[OutFileLike, BufferedWriter] = file.open("wb")
|
||||
self._owned = True
|
||||
self._fd = file.open("wb")
|
||||
self._owned_fd = self._fd
|
||||
else:
|
||||
self._fd = file
|
||||
self._owned = False
|
||||
self._owned_fd = None
|
||||
|
||||
async def write(self, chunk: bytes) -> None:
|
||||
ret = self._fd.write(chunk)
|
||||
|
@ -134,9 +148,9 @@ class OutWrapper:
|
|||
await ret
|
||||
|
||||
def close(self) -> None:
|
||||
if self._owned:
|
||||
assert hasattr(self._fd, "close")
|
||||
self._fd.close()
|
||||
if self._owned_fd is not None:
|
||||
assert hasattr(self._owned_fd, "close")
|
||||
self._owned_fd.close()
|
||||
|
||||
|
||||
class File(metaclass=NoPublicConstructor):
|
||||
|
@ -150,7 +164,7 @@ class File(metaclass=NoPublicConstructor):
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
attributes: List[abcs.DocumentAttribute],
|
||||
attributes: Sequence[abcs.DocumentAttribute],
|
||||
size: int,
|
||||
name: str,
|
||||
mime: str,
|
||||
|
@ -158,7 +172,7 @@ class File(metaclass=NoPublicConstructor):
|
|||
muted: bool,
|
||||
input_media: abcs.InputMedia,
|
||||
thumb: Optional[abcs.PhotoSize],
|
||||
thumbs: Optional[List[abcs.PhotoSize]],
|
||||
thumbs: Optional[Sequence[abcs.PhotoSize]],
|
||||
raw: Optional[Union[abcs.MessageMedia, abcs.Photo, abcs.Document]],
|
||||
client: Optional[Client],
|
||||
):
|
||||
|
@ -405,9 +419,9 @@ class File(metaclass=NoPublicConstructor):
|
|||
id=self._input_media.id.id,
|
||||
access_hash=self._input_media.id.access_hash,
|
||||
file_reference=self._input_media.id.file_reference,
|
||||
thumb_size=self._thumb.type
|
||||
if isinstance(self._thumb, thumb_types)
|
||||
else "",
|
||||
thumb_size=(
|
||||
self._thumb.type if isinstance(self._thumb, thumb_types) else ""
|
||||
),
|
||||
)
|
||||
elif isinstance(self._input_media, types.InputMediaPhoto):
|
||||
assert isinstance(self._input_media.id, types.InputPhoto)
|
||||
|
|
|
@ -2,7 +2,17 @@ from __future__ import annotations
|
|||
|
||||
import datetime
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Self, Tuple, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Self,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from ...tl import abcs, types
|
||||
from ..parsers import (
|
||||
|
@ -502,7 +512,7 @@ class Message(metaclass=NoPublicConstructor):
|
|||
|
||||
|
||||
def build_msg_map(
|
||||
client: Client, messages: List[abcs.Message], chat_map: Dict[int, Chat]
|
||||
client: Client, messages: Sequence[abcs.Message], chat_map: Dict[int, Chat]
|
||||
) -> Dict[int, Message]:
|
||||
return {
|
||||
msg.id: msg
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
"""
|
||||
Class definitions stolen from `trio`, with some modifications.
|
||||
"""
|
||||
|
||||
import abc
|
||||
from typing import Type, TypeVar
|
||||
from typing import Any, Type, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -28,7 +29,7 @@ class Final(abc.ABCMeta):
|
|||
|
||||
|
||||
class NoPublicConstructor(Final):
|
||||
def __call__(cls) -> None:
|
||||
def __call__(cls, *args: Any, **kwds: Any) -> Any:
|
||||
raise TypeError(
|
||||
f"{cls.__module__}.{cls.__qualname__} has no public constructor"
|
||||
)
|
||||
|
|
|
@ -100,12 +100,8 @@ class Participant(metaclass=NoPublicConstructor):
|
|||
),
|
||||
):
|
||||
return self._raw.user_id
|
||||
elif isinstance(
|
||||
self._raw, (types.ChannelParticipantBanned, types.ChannelParticipantLeft)
|
||||
):
|
||||
return peer_id(self._raw.peer)
|
||||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
return peer_id(self._raw.peer)
|
||||
|
||||
@property
|
||||
def user(self) -> Optional[User]:
|
||||
|
|
|
@ -1,67 +1,16 @@
|
|||
import pyaes
|
||||
|
||||
|
||||
def ige_encrypt(plaintext: bytes, key: bytes, iv: bytes) -> bytes:
|
||||
assert len(plaintext) % 16 == 0
|
||||
assert len(iv) == 32
|
||||
|
||||
aes = pyaes.AES(key)
|
||||
iv1 = iv[:16]
|
||||
iv2 = iv[16:]
|
||||
|
||||
ciphertext = bytearray()
|
||||
|
||||
for block_offset in range(0, len(plaintext), 16):
|
||||
plaintext_block = plaintext[block_offset : block_offset + 16]
|
||||
ciphertext_block = bytes(
|
||||
a ^ b
|
||||
for a, b in zip(
|
||||
aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2
|
||||
)
|
||||
)
|
||||
iv1 = ciphertext_block
|
||||
iv2 = plaintext_block
|
||||
|
||||
ciphertext += ciphertext_block
|
||||
|
||||
return bytes(ciphertext)
|
||||
|
||||
|
||||
def ige_decrypt(ciphertext: bytes, key: bytes, iv: bytes) -> bytes:
|
||||
assert len(ciphertext) % 16 == 0
|
||||
assert len(iv) == 32
|
||||
|
||||
aes = pyaes.AES(key)
|
||||
iv1 = iv[:16]
|
||||
iv2 = iv[16:]
|
||||
|
||||
plaintext = bytearray()
|
||||
|
||||
for block_offset in range(0, len(ciphertext), 16):
|
||||
ciphertext_block = ciphertext[block_offset : block_offset + 16]
|
||||
plaintext_block = bytes(
|
||||
a ^ b
|
||||
for a, b in zip(
|
||||
aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1
|
||||
)
|
||||
)
|
||||
iv1 = ciphertext_block
|
||||
iv2 = plaintext_block
|
||||
|
||||
plaintext += plaintext_block
|
||||
|
||||
return bytes(plaintext)
|
||||
|
||||
|
||||
try:
|
||||
import cryptg
|
||||
|
||||
def ige_encrypt(plaintext: bytes, key: bytes, iv: bytes) -> bytes: # noqa: F811
|
||||
def ige_encrypt(
|
||||
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||
) -> bytes: # noqa: F811
|
||||
return cryptg.encrypt_ige(
|
||||
bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv
|
||||
)
|
||||
|
||||
def ige_decrypt(ciphertext: bytes, key: bytes, iv: bytes) -> bytes: # noqa: F811
|
||||
def ige_decrypt(
|
||||
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||
) -> bytes: # noqa: F811
|
||||
return cryptg.decrypt_ige(
|
||||
bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext,
|
||||
key,
|
||||
|
@ -69,4 +18,58 @@ try:
|
|||
)
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
import pyaes
|
||||
|
||||
def ige_encrypt(
|
||||
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||
) -> bytes:
|
||||
assert len(plaintext) % 16 == 0
|
||||
assert len(iv) == 32
|
||||
|
||||
aes = pyaes.AES(key)
|
||||
iv1 = iv[:16]
|
||||
iv2 = iv[16:]
|
||||
|
||||
ciphertext = bytearray()
|
||||
|
||||
for block_offset in range(0, len(plaintext), 16):
|
||||
plaintext_block = plaintext[block_offset : block_offset + 16]
|
||||
ciphertext_block = bytes(
|
||||
a ^ b
|
||||
for a, b in zip(
|
||||
aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2
|
||||
)
|
||||
)
|
||||
iv1 = ciphertext_block
|
||||
iv2 = plaintext_block
|
||||
|
||||
ciphertext += ciphertext_block
|
||||
|
||||
return bytes(ciphertext)
|
||||
|
||||
def ige_decrypt(
|
||||
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||
) -> bytes:
|
||||
assert len(ciphertext) % 16 == 0
|
||||
assert len(iv) == 32
|
||||
|
||||
aes = pyaes.AES(key)
|
||||
iv1 = iv[:16]
|
||||
iv2 = iv[16:]
|
||||
|
||||
plaintext = bytearray()
|
||||
|
||||
for block_offset in range(0, len(ciphertext), 16):
|
||||
ciphertext_block = ciphertext[block_offset : block_offset + 16]
|
||||
plaintext_block = bytes(
|
||||
a ^ b
|
||||
for a, b in zip(
|
||||
aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1
|
||||
)
|
||||
)
|
||||
iv1 = ciphertext_block
|
||||
iv2 = plaintext_block
|
||||
|
||||
plaintext += plaintext_block
|
||||
|
||||
return bytes(plaintext)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
from collections import namedtuple
|
||||
from enum import IntEnum
|
||||
from hashlib import sha1, sha256
|
||||
from typing import NamedTuple
|
||||
|
||||
from .aes import ige_decrypt, ige_encrypt
|
||||
from .auth_key import AuthKey
|
||||
|
@ -13,15 +13,19 @@ class Side(IntEnum):
|
|||
SERVER = 8
|
||||
|
||||
|
||||
CalcKey = namedtuple("CalcKey", ("key", "iv"))
|
||||
class CalcKey(NamedTuple):
|
||||
key: bytes
|
||||
iv: bytes
|
||||
|
||||
|
||||
# https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
|
||||
def calc_key(auth_key: AuthKey, msg_key: bytes, side: Side) -> CalcKey:
|
||||
def calc_key(
|
||||
auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side
|
||||
) -> CalcKey:
|
||||
x = int(side)
|
||||
|
||||
# sha256_a = SHA256 (msg_key + substr (auth_key, x, 36))
|
||||
sha256_a = sha256(msg_key + auth_key.data[x : x + 36]).digest()
|
||||
sha256_a = sha256(bytes(msg_key) + auth_key.data[x : x + 36]).digest()
|
||||
|
||||
# sha256_b = SHA256 (substr (auth_key, 40+x, 36) + msg_key)
|
||||
sha256_b = sha256(auth_key.data[x + 40 : x + 76] + msg_key).digest()
|
||||
|
@ -66,7 +70,9 @@ def encrypt_data_v2(plaintext: bytes, auth_key: AuthKey) -> bytes:
|
|||
return _do_encrypt_data_v2(plaintext, auth_key, random_padding)
|
||||
|
||||
|
||||
def decrypt_data_v2(ciphertext: bytes, auth_key: AuthKey) -> bytes:
|
||||
def decrypt_data_v2(
|
||||
ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey
|
||||
) -> bytes:
|
||||
side = Side.SERVER
|
||||
x = int(side)
|
||||
|
||||
|
|
|
@ -1,45 +1,58 @@
|
|||
# Ported from https://github.com/Lonami/grammers/blob/d91dc82/lib/grammers-crypto/src/two_factor_auth.rs
|
||||
from collections import namedtuple
|
||||
from hashlib import pbkdf2_hmac, sha256
|
||||
from typing import NamedTuple
|
||||
|
||||
from .factorize import factorize
|
||||
|
||||
TwoFactorAuth = namedtuple("TwoFactorAuth", ("m1", "g_a"))
|
||||
|
||||
class TwoFactorAuth(NamedTuple):
|
||||
m1: bytes
|
||||
g_a: bytes
|
||||
|
||||
|
||||
def pad_to_256(data: bytes) -> bytes:
|
||||
def pad_to_256(data: bytes | bytearray | memoryview) -> bytes:
|
||||
return bytes(256 - len(data)) + data
|
||||
|
||||
|
||||
# H(data) := sha256(data)
|
||||
def h(*data: bytes) -> bytes:
|
||||
def h(*data: bytes | bytearray | memoryview) -> bytes:
|
||||
return sha256(b"".join(data)).digest()
|
||||
|
||||
|
||||
# SH(data, salt) := H(salt | data | salt)
|
||||
def sh(data: bytes, salt: bytes) -> bytes:
|
||||
def sh(
|
||||
data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview
|
||||
) -> bytes:
|
||||
return h(salt, data, salt)
|
||||
|
||||
|
||||
# PH1(password, salt1, salt2) := SH(SH(password, salt1), salt2)
|
||||
def ph1(password: bytes, salt1: bytes, salt2: bytes) -> bytes:
|
||||
def ph1(
|
||||
password: bytes | bytearray | memoryview,
|
||||
salt1: bytes | bytearray | memoryview,
|
||||
salt2: bytes | bytearray | memoryview,
|
||||
) -> bytes:
|
||||
return sh(sh(password, salt1), salt2)
|
||||
|
||||
|
||||
# PH2(password, salt1, salt2) := SH(pbkdf2(sha512, PH1(password, salt1, salt2), salt1, 100000), salt2)
|
||||
def ph2(password: bytes, salt1: bytes, salt2: bytes) -> bytes:
|
||||
def ph2(
|
||||
password: bytes | bytearray | memoryview,
|
||||
salt1: bytes | bytearray | memoryview,
|
||||
salt2: bytes | bytearray | memoryview,
|
||||
) -> bytes:
|
||||
return sh(pbkdf2_hmac("sha512", ph1(password, salt1, salt2), salt1, 100000), salt2)
|
||||
|
||||
|
||||
# https://core.telegram.org/api/srp
|
||||
def calculate_2fa(
|
||||
*,
|
||||
salt1: bytes,
|
||||
salt2: bytes,
|
||||
salt1: bytes | bytearray | memoryview,
|
||||
salt2: bytes | bytearray | memoryview,
|
||||
g: int,
|
||||
p: bytes,
|
||||
g_b: bytes,
|
||||
a: bytes,
|
||||
p: bytes | bytearray | memoryview,
|
||||
g_b: bytes | bytearray | memoryview,
|
||||
a: bytes | bytearray | memoryview,
|
||||
password: bytes,
|
||||
) -> TwoFactorAuth:
|
||||
big_p = int.from_bytes(p)
|
||||
|
@ -100,16 +113,16 @@ def calculate_2fa(
|
|||
return TwoFactorAuth(m1, g_a)
|
||||
|
||||
|
||||
def check_p_len(p: bytes) -> bool:
|
||||
def check_p_len(p: bytes | bytearray | memoryview) -> bool:
|
||||
return len(p) == 256
|
||||
|
||||
|
||||
def check_known_prime(p: bytes, g: int) -> bool:
|
||||
def check_known_prime(p: bytes | bytearray | memoryview, g: int) -> bool:
|
||||
good_prime = b"\xc7\x1c\xae\xb9\xc6\xb1\xc9\x04\x8elR/p\xf1?s\x98\r@#\x8e>!\xc1I4\xd07V=\x93\x0fH\x19\x8a\n\xa7\xc1@X\"\x94\x93\xd2%0\xf4\xdb\xfa3on\n\xc9%\x13\x95C\xae\xd4L\xce|7 \xfdQ\xf6\x94XpZ\xc6\x8c\xd4\xfekk\x13\xab\xdc\x97FQ)i2\x84T\xf1\x8f\xaf\x8cY_d$w\xfe\x96\xbb*\x94\x1d[\xcd\x1dJ\xc8\xccI\x88\x07\x08\xfa\x9b7\x8e<O:\x90`\xbe\xe6|\xf9\xa4\xa4\xa6\x95\x81\x10Q\x90~\x16'S\xb5k\x0fkA\r\xbat\xd8\xa8K*\x14\xb3\x14N\x0e\xf1(GT\xfd\x17\xed\x95\rYe\xb4\xb9\xddFX-\xb1\x17\x8d\x16\x9ck\xc4e\xb0\xd6\xff\x9c\xa3\x92\x8f\xef[\x9a\xe4\xe4\x18\xfc\x15\xe8>\xbe\xa0\xf8\x7f\xa9\xff^\xedp\x05\r\xed(I\xf4{\xf9Y\xd9V\x85\x0c\xe9)\x85\x1f\r\x81\x15\xf65\xb1\x05\xee.N\x15\xd0K$T\xbfoO\xad\xf04\xb1\x04\x03\x11\x9c\xd8\xe3\xb9/\xcc["
|
||||
return p == good_prime and g in (3, 4, 5, 7)
|
||||
|
||||
|
||||
def check_p_prime_and_subgroup(p: bytes, g: int) -> bool:
|
||||
def check_p_prime_and_subgroup(p: bytes | bytearray | memoryview, g: int) -> bool:
|
||||
if check_known_prime(p, g):
|
||||
return True
|
||||
|
||||
|
@ -133,7 +146,7 @@ def check_p_prime_and_subgroup(p: bytes, g: int) -> bool:
|
|||
return candidate and factorize((big_p - 1) // 2)[0] == 1
|
||||
|
||||
|
||||
def check_p_and_g(p: bytes, g: int) -> bool:
|
||||
def check_p_and_g(p: bytes | bytearray | memoryview, g: int) -> bool:
|
||||
if not check_p_len(p):
|
||||
return False
|
||||
|
||||
|
|
|
@ -164,6 +164,7 @@ def _do_step3(
|
|||
)
|
||||
|
||||
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
|
||||
assert isinstance(server_dh_params.encrypted_answer, bytes)
|
||||
plain_text_answer = decrypt_ige(server_dh_params.encrypted_answer, key, iv)
|
||||
|
||||
got_answer_hash = plain_text_answer[:20]
|
||||
|
|
|
@ -260,7 +260,7 @@ class Encrypted(Mtp):
|
|||
self._store_own_updates(result)
|
||||
self._deserialization.append(RpcResult(msg_id, result))
|
||||
|
||||
def _store_own_updates(self, body: bytes) -> None:
|
||||
def _store_own_updates(self, body: bytes | bytearray | memoryview) -> None:
|
||||
constructor_id = struct.unpack_from("I", body)[0]
|
||||
if constructor_id in UPDATE_IDS:
|
||||
self._deserialization.append(Update(body))
|
||||
|
@ -332,7 +332,7 @@ class Encrypted(Mtp):
|
|||
)
|
||||
|
||||
self._start_salt_time = (salts.now, self._adjusted_now())
|
||||
self._salts = salts.salts
|
||||
self._salts = list(salts.salts)
|
||||
self._salts.sort(key=lambda salt: -salt.valid_since)
|
||||
|
||||
def _handle_future_salt(self, message: Message) -> None:
|
||||
|
@ -439,7 +439,9 @@ class Encrypted(Mtp):
|
|||
msg_id, buffer = result
|
||||
return msg_id, encrypt_data_v2(buffer, self._auth_key)
|
||||
|
||||
def deserialize(self, payload: bytes) -> List[Deserialization]:
|
||||
def deserialize(
|
||||
self, payload: bytes | bytearray | memoryview
|
||||
) -> List[Deserialization]:
|
||||
check_message_buffer(payload)
|
||||
|
||||
plaintext = decrypt_data_v2(payload, self._auth_key)
|
||||
|
|
|
@ -31,7 +31,9 @@ class Plain(Mtp):
|
|||
self._buffer.clear()
|
||||
return MsgId(0), result
|
||||
|
||||
def deserialize(self, payload: bytes) -> List[Deserialization]:
|
||||
def deserialize(
|
||||
self, payload: bytes | bytearray | memoryview
|
||||
) -> List[Deserialization]:
|
||||
check_message_buffer(payload)
|
||||
|
||||
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
|
||||
|
|
|
@ -15,7 +15,7 @@ class Update:
|
|||
|
||||
__slots__ = ("body",)
|
||||
|
||||
def __init__(self, body: bytes):
|
||||
def __init__(self, body: bytes | bytearray | memoryview):
|
||||
self.body = body
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ class RpcResult:
|
|||
|
||||
__slots__ = ("msg_id", "body")
|
||||
|
||||
def __init__(self, msg_id: MsgId, body: bytes):
|
||||
def __init__(self, msg_id: MsgId, body: bytes | bytearray | memoryview):
|
||||
self.msg_id = msg_id
|
||||
self.body = body
|
||||
|
||||
|
@ -201,7 +201,9 @@ class Mtp(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def deserialize(self, payload: bytes) -> List[Deserialization]:
|
||||
def deserialize(
|
||||
self, payload: bytes | bytearray | memoryview
|
||||
) -> List[Deserialization]:
|
||||
"""
|
||||
Deserialize incoming buffer payload.
|
||||
"""
|
||||
|
|
|
@ -12,7 +12,7 @@ class Transport(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unpack(self, input: bytes, output: bytearray) -> int:
|
||||
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class Abridged(Transport):
|
|||
write(struct.pack("<i", 0x7F | (length << 8)))
|
||||
write(input)
|
||||
|
||||
def unpack(self, input: bytes, output: bytearray) -> int:
|
||||
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
|
||||
if not input:
|
||||
raise MissingBytes(expected=1, got=0)
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ class Full(Transport):
|
|||
write(struct.pack("<I", crc32(tmp)))
|
||||
self._send_seq += 1
|
||||
|
||||
def unpack(self, input: bytes, output: bytearray) -> int:
|
||||
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
|
||||
if len(input) < 4:
|
||||
raise MissingBytes(expected=4, got=len(input))
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class Intermediate(Transport):
|
|||
write(struct.pack("<i", len(input)))
|
||||
write(input)
|
||||
|
||||
def unpack(self, input: bytes, output: bytearray) -> int:
|
||||
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
|
||||
if len(input) < 4:
|
||||
raise MissingBytes(expected=4, got=len(input))
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ CONTAINER_MAX_LENGTH = 100
|
|||
MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
|
||||
|
||||
|
||||
def check_message_buffer(message: bytes) -> None:
|
||||
def check_message_buffer(message: bytes | bytearray | memoryview) -> None:
|
||||
if len(message) < 20:
|
||||
raise ValueError(
|
||||
f"server payload is too small to be a valid message: {message.hex()}"
|
||||
|
|
|
@ -70,6 +70,7 @@ class AsyncReader(Protocol):
|
|||
:param n:
|
||||
Amount of bytes to read at most.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncWriter(Protocol):
|
||||
|
@ -77,7 +78,7 @@ class AsyncWriter(Protocol):
|
|||
A :class:`asyncio.StreamWriter`-like class.
|
||||
"""
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
def write(self, data: bytes | bytearray | memoryview) -> None:
|
||||
"""
|
||||
Must behave like :meth:`asyncio.StreamWriter.write`.
|
||||
|
||||
|
@ -127,7 +128,7 @@ class Connector(Protocol):
|
|||
"""
|
||||
|
||||
async def __call__(self, ip: str, port: int) -> Tuple[AsyncReader, AsyncWriter]:
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RequestState(ABC):
|
||||
|
@ -340,12 +341,12 @@ class Sender:
|
|||
self._process_result(result)
|
||||
elif isinstance(result, RpcError):
|
||||
self._process_error(result)
|
||||
elif isinstance(result, BadMessage):
|
||||
self._process_bad_message(result)
|
||||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
self._process_bad_message(result)
|
||||
|
||||
def _process_update(self, updates: List[Updates], update: bytes) -> None:
|
||||
def _process_update(
|
||||
self, updates: List[Updates], update: bytes | bytearray | memoryview
|
||||
) -> None:
|
||||
try:
|
||||
updates.append(Updates.from_bytes(update))
|
||||
except ValueError:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple
|
||||
|
||||
from ...tl import abcs, types
|
||||
from .packed import PackedChat, PackedType
|
||||
|
@ -131,7 +131,7 @@ class ChatHashCache:
|
|||
else:
|
||||
raise RuntimeError("unexpected case")
|
||||
|
||||
def extend(self, users: List[abcs.User], chats: List[abcs.Chat]) -> bool:
|
||||
def extend(self, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]) -> bool:
|
||||
# See https://core.telegram.org/api/min for "issues" with "min constructors".
|
||||
success = True
|
||||
|
||||
|
@ -299,6 +299,7 @@ class ChatHashCache:
|
|||
success &= self._has_notify_peer(peer)
|
||||
|
||||
for field in ("users",):
|
||||
user: Any
|
||||
users = getattr(message.action, field, None)
|
||||
if isinstance(users, list):
|
||||
for user in users:
|
||||
|
|
|
@ -4,11 +4,35 @@ from typing import List, Literal, Union
|
|||
|
||||
from ...tl import abcs
|
||||
|
||||
NO_DATE = 0 # used on adapted messages.affected* from lower layers
|
||||
NO_SEQ = 0
|
||||
|
||||
NO_PTS = 0
|
||||
|
||||
# https://core.telegram.org/method/updates.getChannelDifference
|
||||
BOT_CHANNEL_DIFF_LIMIT = 100000
|
||||
USER_CHANNEL_DIFF_LIMIT = 100
|
||||
|
||||
POSSIBLE_GAP_TIMEOUT = 0.5
|
||||
|
||||
# https://core.telegram.org/api/updates
|
||||
NO_UPDATES_TIMEOUT = 15 * 60
|
||||
|
||||
ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT"
|
||||
ENTRY_SECRET: Literal["SECRET"] = "SECRET"
|
||||
Entry = Union[Literal["ACCOUNT", "SECRET"], int]
|
||||
|
||||
# Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET.
|
||||
# We don't define a name for this as libraries shouldn't do that though.
|
||||
LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2
|
||||
|
||||
|
||||
class PtsInfo:
|
||||
__slots__ = ("pts", "pts_count", "entry")
|
||||
|
||||
def __init__(self, entry: "Entry", pts: int, pts_count: int) -> None:
|
||||
entry: Entry # pyright needs this or it infers int | str
|
||||
|
||||
def __init__(self, entry: Entry, pts: int, pts_count: int) -> None:
|
||||
self.pts = pts
|
||||
self.pts_count = pts_count
|
||||
self.entry = entry
|
||||
|
@ -59,26 +83,3 @@ class PrematureEndReason(Enum):
|
|||
class Gap(ValueError):
|
||||
def __repr__(self) -> str:
|
||||
return "Gap()"
|
||||
|
||||
|
||||
NO_DATE = 0 # used on adapted messages.affected* from lower layers
|
||||
NO_SEQ = 0
|
||||
|
||||
NO_PTS = 0
|
||||
|
||||
# https://core.telegram.org/method/updates.getChannelDifference
|
||||
BOT_CHANNEL_DIFF_LIMIT = 100000
|
||||
USER_CHANNEL_DIFF_LIMIT = 100
|
||||
|
||||
POSSIBLE_GAP_TIMEOUT = 0.5
|
||||
|
||||
# https://core.telegram.org/api/updates
|
||||
NO_UPDATES_TIMEOUT = 15 * 60
|
||||
|
||||
ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT"
|
||||
ENTRY_SECRET: Literal["SECRET"] = "SECRET"
|
||||
Entry = Union[Literal["ACCOUNT", "SECRET"], int]
|
||||
|
||||
# Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET.
|
||||
# We don't define a name for this as libraries shouldn't do that though.
|
||||
LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2
|
||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Dict, List, Optional, Sequence, Set, Tuple
|
||||
|
||||
from ...tl import Request, abcs, functions, types
|
||||
from ..chat import ChatHashCache
|
||||
|
@ -168,6 +168,7 @@ class MessageBox:
|
|||
if not entries:
|
||||
return
|
||||
|
||||
entry: Entry = ENTRY_ACCOUNT # for pyright to know it's not unbound
|
||||
for entry in entries:
|
||||
if entry not in self.map:
|
||||
raise RuntimeError(
|
||||
|
@ -258,7 +259,7 @@ class MessageBox:
|
|||
self,
|
||||
updates: abcs.Updates,
|
||||
chat_hashes: ChatHashCache,
|
||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
|
||||
result: List[abcs.Update] = []
|
||||
combined = adapt(updates, chat_hashes)
|
||||
|
||||
|
@ -289,7 +290,7 @@ class MessageBox:
|
|||
sorted_updates = list(sorted(combined.updates, key=update_sort_key))
|
||||
|
||||
any_pts_applied = False
|
||||
reset_deadlines_for = set()
|
||||
reset_deadlines_for: Set[Entry] = set()
|
||||
for update in sorted_updates:
|
||||
entry, applied = self.apply_pts_info(update)
|
||||
if entry is not None:
|
||||
|
@ -420,9 +421,11 @@ class MessageBox:
|
|||
pts_limit=None,
|
||||
pts_total_limit=None,
|
||||
date=int(self.date.timestamp()),
|
||||
qts=self.map[ENTRY_SECRET].pts
|
||||
if ENTRY_SECRET in self.map
|
||||
else NO_SEQ,
|
||||
qts=(
|
||||
self.map[ENTRY_SECRET].pts
|
||||
if ENTRY_SECRET in self.map
|
||||
else NO_SEQ
|
||||
),
|
||||
qts_limit=None,
|
||||
)
|
||||
if __debug__:
|
||||
|
@ -435,12 +438,12 @@ class MessageBox:
|
|||
self,
|
||||
diff: abcs.updates.Difference,
|
||||
chat_hashes: ChatHashCache,
|
||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
|
||||
if __debug__:
|
||||
self._trace("applying account difference: %s", diff)
|
||||
|
||||
finish: bool
|
||||
result: Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]
|
||||
result: Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]
|
||||
if isinstance(diff, types.updates.DifferenceEmpty):
|
||||
finish = True
|
||||
self.date = datetime.datetime.fromtimestamp(
|
||||
|
@ -493,7 +496,7 @@ class MessageBox:
|
|||
self,
|
||||
diff: types.updates.Difference,
|
||||
chat_hashes: ChatHashCache,
|
||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
|
||||
state = diff.state
|
||||
assert isinstance(state, types.updates.State)
|
||||
self.map[ENTRY_ACCOUNT].pts = state.pts
|
||||
|
@ -556,9 +559,11 @@ class MessageBox:
|
|||
channel=channel,
|
||||
filter=types.ChannelMessagesFilterEmpty(),
|
||||
pts=state.pts,
|
||||
limit=BOT_CHANNEL_DIFF_LIMIT
|
||||
if chat_hashes.is_self_bot
|
||||
else USER_CHANNEL_DIFF_LIMIT,
|
||||
limit=(
|
||||
BOT_CHANNEL_DIFF_LIMIT
|
||||
if chat_hashes.is_self_bot
|
||||
else USER_CHANNEL_DIFF_LIMIT
|
||||
),
|
||||
)
|
||||
if __debug__:
|
||||
self._trace("requesting channel difference: %s", gd)
|
||||
|
@ -577,7 +582,7 @@ class MessageBox:
|
|||
channel_id: int,
|
||||
diff: abcs.updates.ChannelDifference,
|
||||
chat_hashes: ChatHashCache,
|
||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
||||
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
|
||||
entry: Entry = channel_id
|
||||
if __debug__:
|
||||
self._trace("applying channel=%r difference: %s", entry, diff)
|
||||
|
|
|
@ -4,7 +4,7 @@ from .memory import MemorySession
|
|||
from .storage import Storage
|
||||
|
||||
try:
|
||||
from .sqlite import SqliteSession
|
||||
from .sqlite import SqliteSession # pyright: ignore [reportAssignmentType]
|
||||
except ImportError as e:
|
||||
import_err = e
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
|
|||
class Reader:
|
||||
__slots__ = ("_view", "_pos", "_len")
|
||||
|
||||
def __init__(self, buffer: bytes | memoryview) -> None:
|
||||
def __init__(self, buffer: bytes | bytearray | memoryview) -> None:
|
||||
self._view = (
|
||||
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
|
||||
)
|
||||
|
|
|
@ -9,10 +9,10 @@ class HasSlots(Protocol):
|
|||
__slots__: Tuple[str, ...]
|
||||
|
||||
|
||||
def obj_repr(obj: HasSlots) -> str:
|
||||
fields = ((attr, getattr(obj, attr)) for attr in obj.__slots__)
|
||||
def obj_repr(self: HasSlots) -> str:
|
||||
fields = ((attr, getattr(self, attr)) for attr in self.__slots__)
|
||||
params = ", ".join(f"{name}={field!r}" for name, field in fields)
|
||||
return f"{obj.__class__.__name__}({params})"
|
||||
return f"{self.__class__.__name__}({params})"
|
||||
|
||||
|
||||
class Serializable(abc.ABC):
|
||||
|
@ -36,7 +36,7 @@ class Serializable(abc.ABC):
|
|||
pass
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, blob: bytes) -> Self:
|
||||
def from_bytes(cls, blob: bytes | bytearray | memoryview) -> Self:
|
||||
return Reader(blob).read_serializable(cls)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
|
@ -54,7 +54,7 @@ class Serializable(abc.ABC):
|
|||
)
|
||||
|
||||
|
||||
def serialize_bytes_to(buffer: bytearray, data: bytes) -> None:
|
||||
def serialize_bytes_to(buffer: bytearray, data: bytes | bytearray | memoryview) -> None:
|
||||
length = len(data)
|
||||
if length < 0xFE:
|
||||
buffer += struct.pack("<B", length)
|
||||
|
|
|
@ -65,7 +65,7 @@ def test_key_from_nonce() -> None:
|
|||
server_nonce = int.from_bytes(bytes(range(16)))
|
||||
new_nonce = int.from_bytes(bytes(range(32)))
|
||||
|
||||
(key, iv) = generate_key_data_from_nonce(server_nonce, new_nonce)
|
||||
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
|
||||
assert (
|
||||
key
|
||||
== b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6'
|
||||
|
|
|
@ -7,7 +7,7 @@ from telethon._impl.mtproto import Abridged
|
|||
class Output(bytearray):
|
||||
__slots__ = ()
|
||||
|
||||
def __call__(self, data: bytes) -> None:
|
||||
def __call__(self, data: bytes | bytearray | memoryview) -> None:
|
||||
self += data
|
||||
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from telethon._impl.mtproto import Full
|
|||
class Output(bytearray):
|
||||
__slots__ = ()
|
||||
|
||||
def __call__(self, data: bytes) -> None:
|
||||
def __call__(self, data: bytes | bytearray | memoryview) -> None:
|
||||
self += data
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ def setup_unpack(n: int) -> Tuple[bytes, Full, bytes, bytearray]:
|
|||
transport, expected_output, input = setup_pack(n)
|
||||
transport.pack(expected_output, input)
|
||||
|
||||
return expected_output, Full(), input, bytearray()
|
||||
return expected_output, Full(), bytes(input), bytearray()
|
||||
|
||||
|
||||
def test_pack_empty() -> None:
|
||||
|
|
|
@ -7,7 +7,7 @@ from telethon._impl.mtproto import Intermediate
|
|||
class Output(bytearray):
|
||||
__slots__ = ()
|
||||
|
||||
def __call__(self, data: bytes) -> None:
|
||||
def __call__(self, data: bytes | bytearray | memoryview) -> None:
|
||||
self += data
|
||||
|
||||
|
||||
|
|
|
@ -46,14 +46,14 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
|
||||
ignored_types = {"true", "boolTrue", "boolFalse"} # also "compiler built-ins"
|
||||
|
||||
abc_namespaces = set()
|
||||
type_namespaces = set()
|
||||
function_namespaces = set()
|
||||
abc_namespaces: Set[str] = set()
|
||||
type_namespaces: Set[str] = set()
|
||||
function_namespaces: Set[str] = set()
|
||||
|
||||
abc_class_names = set()
|
||||
type_class_names = set()
|
||||
function_def_names = set()
|
||||
generated_type_names = set()
|
||||
abc_class_names: Set[str] = set()
|
||||
type_class_names: Set[str] = set()
|
||||
function_def_names: Set[str] = set()
|
||||
generated_type_names: Set[str] = set()
|
||||
|
||||
for typedef in tl.typedefs:
|
||||
if typedef.ty.full_name not in generated_types:
|
||||
|
@ -67,6 +67,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
abc_path = Path("abcs/_nons.py")
|
||||
|
||||
if abc_path not in fs:
|
||||
fs.write(abc_path, "# pyright: reportUnusedImport=false\n")
|
||||
fs.write(abc_path, "from abc import ABCMeta\n")
|
||||
fs.write(abc_path, "from ..core import Serializable\n")
|
||||
|
||||
|
@ -93,11 +94,14 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
writer = fs.open(type_path)
|
||||
|
||||
if type_path not in fs:
|
||||
writer.write(
|
||||
"# pyright: reportUnusedImport=false, reportConstantRedefinition=false"
|
||||
)
|
||||
writer.write("import struct")
|
||||
writer.write("from typing import List, Optional, Self, Sequence")
|
||||
writer.write("from typing import Optional, Self, Sequence")
|
||||
writer.write("from .. import abcs")
|
||||
writer.write("from ..core import Reader, Serializable, serialize_bytes_to")
|
||||
writer.write("_bytes = bytes")
|
||||
writer.write("_bytes = bytes | bytearray | memoryview")
|
||||
|
||||
ns = f"{typedef.namespace[0]}." if typedef.namespace else ""
|
||||
generated_type_names.add(f"{ns}{to_class_name(typedef.name)}")
|
||||
|
@ -113,14 +117,13 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
|
||||
# def constructor_id()
|
||||
writer.write(" @classmethod")
|
||||
writer.write(" def constructor_id(_) -> int:")
|
||||
writer.write(" def constructor_id(cls) -> int:")
|
||||
writer.write(f" return {hex(typedef.id)}")
|
||||
|
||||
# def __init__()
|
||||
if property_params:
|
||||
params = "".join(
|
||||
f", {p.name}: {param_type_fmt(p.ty, immutable=False)}"
|
||||
for p in property_params
|
||||
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
|
||||
)
|
||||
writer.write(f" def __init__(_s, *{params}) -> None:")
|
||||
for p in property_params:
|
||||
|
@ -159,22 +162,18 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
writer = fs.open(function_path)
|
||||
|
||||
if function_path not in fs:
|
||||
writer.write("# pyright: reportUnusedImport=false")
|
||||
writer.write("import struct")
|
||||
writer.write("from typing import List, Optional, Self, Sequence")
|
||||
writer.write("from typing import Optional, Self, Sequence")
|
||||
writer.write("from .. import abcs")
|
||||
writer.write("from ..core import Request, serialize_bytes_to")
|
||||
writer.write("_bytes = bytes")
|
||||
writer.write("_bytes = bytes | bytearray | memoryview")
|
||||
|
||||
# def name(params, ...)
|
||||
required_params = [p for p in functiondef.params if not is_computed(p.ty)]
|
||||
params = "".join(
|
||||
f", {p.name}: {param_type_fmt(p.ty, immutable=True)}"
|
||||
for p in required_params
|
||||
)
|
||||
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params)
|
||||
star = "*" if params else ""
|
||||
return_ty = param_type_fmt(
|
||||
NormalParameter(ty=functiondef.ty, flag=None), immutable=False
|
||||
)
|
||||
return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None))
|
||||
writer.write(
|
||||
f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:"
|
||||
)
|
||||
|
@ -189,6 +188,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
|||
)
|
||||
|
||||
writer = fs.open(Path("layer.py"))
|
||||
writer.write("# pyright: reportUnusedImport=false")
|
||||
writer.write("from . import abcs, types")
|
||||
writer.write(
|
||||
"from .core import Serializable, Reader, deserialize_bool, deserialize_i32_list, deserialize_i64_list, deserialize_identity, single_deserializer, list_deserializer"
|
||||
|
|
|
@ -86,7 +86,7 @@ def inner_type_fmt(ty: Type) -> str:
|
|||
return f"abcs.{ns}{to_class_name(ty.name)}"
|
||||
|
||||
|
||||
def param_type_fmt(ty: BaseParameter, *, immutable: bool) -> str:
|
||||
def param_type_fmt(ty: BaseParameter) -> str:
|
||||
if isinstance(ty, FlagsParameter):
|
||||
return "int"
|
||||
elif not isinstance(ty, NormalParameter):
|
||||
|
@ -104,10 +104,7 @@ def param_type_fmt(ty: BaseParameter, *, immutable: bool) -> str:
|
|||
res = "_bytes" if inner_ty.name == "Object" else inner_type_fmt(inner_ty)
|
||||
|
||||
if ty.ty.generic_arg:
|
||||
if immutable:
|
||||
res = f"Sequence[{res}]"
|
||||
else:
|
||||
res = f"List[{res}]"
|
||||
res = f"Sequence[{res}]"
|
||||
|
||||
if ty.flag and ty.ty.name != "true":
|
||||
res = f"Optional[{res}]"
|
||||
|
|
|
@ -15,7 +15,8 @@ class ParsedTl:
|
|||
|
||||
|
||||
def load_tl_file(path: Union[str, Path]) -> ParsedTl:
|
||||
typedefs, functiondefs = [], []
|
||||
typedefs: List[TypeDef] = []
|
||||
functiondefs: List[FunctionDef] = []
|
||||
with open(path, "r", encoding="utf-8") as fd:
|
||||
contents = fd.read()
|
||||
|
||||
|
@ -31,10 +32,8 @@ def load_tl_file(path: Union[str, Path]) -> ParsedTl:
|
|||
raise
|
||||
elif isinstance(definition, TypeDef):
|
||||
typedefs.append(definition)
|
||||
elif isinstance(definition, FunctionDef):
|
||||
functiondefs.append(definition)
|
||||
else:
|
||||
raise TypeError(f"unexpected type: {type(definition)}")
|
||||
functiondefs.append(definition)
|
||||
|
||||
return ParsedTl(
|
||||
layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)
|
||||
|
|
|
@ -59,8 +59,8 @@ class Definition:
|
|||
raise ValueError("invalid id")
|
||||
|
||||
type_defs: List[str] = []
|
||||
flag_defs = []
|
||||
params = []
|
||||
flag_defs: List[str] = []
|
||||
params: List[Parameter] = []
|
||||
|
||||
for param_str in middle.split():
|
||||
try:
|
||||
|
|
|
@ -79,7 +79,7 @@ def test_recursive_vec() -> None:
|
|||
"""
|
||||
)
|
||||
result = gen_py_code(typedefs=definitions)
|
||||
assert "value: List[abcs.JsonObjectValue]" in result
|
||||
assert "value: Sequence[abcs.JsonObjectValue]" in result
|
||||
|
||||
|
||||
def test_object_blob_special_case() -> None:
|
||||
|
|
|
@ -10,6 +10,7 @@ Imports of new definitions and formatting must be added with other tools.
|
|||
Properties and private methods can use a different parameter name than `self`
|
||||
to avoid being included.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -31,6 +32,8 @@ class FunctionMethodsVisitor(ast.NodeVisitor):
|
|||
match node.args.args:
|
||||
case [ast.arg(arg="self", annotation=ast.Name(id="Client")), *_]:
|
||||
self.methods.append(node)
|
||||
case _:
|
||||
pass
|
||||
|
||||
|
||||
class MethodVisitor(ast.NodeVisitor):
|
||||
|
@ -59,6 +62,8 @@ class MethodVisitor(ast.NodeVisitor):
|
|||
match node.body:
|
||||
case [ast.Expr(value=ast.Constant(value=str(doc))), *_]:
|
||||
self.method_docs[node.name] = doc
|
||||
case _:
|
||||
pass
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
@ -81,10 +86,10 @@ def main() -> None:
|
|||
|
||||
m_visitor.visit(ast.parse(contents))
|
||||
|
||||
class_body = []
|
||||
class_body: List[ast.stmt] = []
|
||||
|
||||
for function in sorted(fm_visitor.methods, key=lambda f: f.name):
|
||||
function.body = []
|
||||
function_body: List[ast.stmt] = []
|
||||
if doc := m_visitor.method_docs.get(function.name):
|
||||
function.body.append(ast.Expr(value=ast.Constant(value=doc)))
|
||||
|
||||
|
@ -108,7 +113,7 @@ def main() -> None:
|
|||
case _:
|
||||
call = ast.Return(value=call)
|
||||
|
||||
function.body.append(call)
|
||||
function.body = function_body
|
||||
class_body.append(function)
|
||||
|
||||
generated = ast.unparse(
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
Scan the `client/` directory for __init__.py files.
|
||||
For every depth-1 import, add it, in order, to the __all__ variable.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import os
|
||||
import re
|
||||
|
@ -25,7 +26,7 @@ def main() -> None:
|
|||
rf"(tl|mtproto){re.escape(os.path.sep)}(abcs|functions|types)"
|
||||
)
|
||||
|
||||
files = []
|
||||
files: List[str] = []
|
||||
for file in impl_root.rglob("__init__.py"):
|
||||
file_str = str(file)
|
||||
if autogenerated_re.search(file_str):
|
||||
|
@ -52,6 +53,8 @@ def main() -> None:
|
|||
+ "]\n"
|
||||
]
|
||||
break
|
||||
case _:
|
||||
pass
|
||||
|
||||
with file.open("w", encoding="utf-8", newline="\n") as fd:
|
||||
fd.writelines(lines)
|
||||
|
|
19
typings/setuptools.pyi
Normal file
19
typings/setuptools.pyi
Normal file
|
@ -0,0 +1,19 @@
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
class build_meta:
|
||||
@staticmethod
|
||||
def build_wheel(
|
||||
wheel_directory: str,
|
||||
config_settings: Optional[Dict[Any, Any]] = None,
|
||||
metadata_directory: Optional[str] = None,
|
||||
) -> str: ...
|
||||
@staticmethod
|
||||
def build_sdist(
|
||||
sdist_directory: str, config_settings: Optional[Dict[Any, Any]] = None
|
||||
) -> str: ...
|
||||
@staticmethod
|
||||
def build_editable(
|
||||
wheel_directory: str,
|
||||
config_settings: Optional[Dict[Any, Any]] = None,
|
||||
metadata_directory: Optional[str] = None,
|
||||
) -> str: ...
|
Loading…
Reference in New Issue
Block a user