mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-28 12:23:45 +03:00
Make pyright happy
This commit is contained in:
parent
854096e9d3
commit
033b56f1d3
|
@ -25,6 +25,7 @@ def serialize_builtin(value: Any) -> bytes:
|
||||||
|
|
||||||
|
|
||||||
def overhead(obj: Obj) -> None:
|
def overhead(obj: Obj) -> None:
|
||||||
|
x: Any
|
||||||
for v in obj.__dict__.values():
|
for v in obj.__dict__.values():
|
||||||
for x in v if isinstance(v, list) else [v]:
|
for x in v if isinstance(v, list) else [v]:
|
||||||
if isinstance(x, Obj):
|
if isinstance(x, Obj):
|
||||||
|
@ -34,6 +35,7 @@ def overhead(obj: Obj) -> None:
|
||||||
|
|
||||||
|
|
||||||
def strategy_concat(obj: Obj) -> bytes:
|
def strategy_concat(obj: Obj) -> bytes:
|
||||||
|
x: Any
|
||||||
res = b""
|
res = b""
|
||||||
for v in obj.__dict__.values():
|
for v in obj.__dict__.values():
|
||||||
for x in v if isinstance(v, list) else [v]:
|
for x in v if isinstance(v, list) else [v]:
|
||||||
|
@ -45,6 +47,7 @@ def strategy_concat(obj: Obj) -> bytes:
|
||||||
|
|
||||||
|
|
||||||
def strategy_append(obj: Obj) -> bytes:
|
def strategy_append(obj: Obj) -> bytes:
|
||||||
|
x: Any
|
||||||
res = bytearray()
|
res = bytearray()
|
||||||
for v in obj.__dict__.values():
|
for v in obj.__dict__.values():
|
||||||
for x in v if isinstance(v, list) else [v]:
|
for x in v if isinstance(v, list) else [v]:
|
||||||
|
@ -57,6 +60,7 @@ def strategy_append(obj: Obj) -> bytes:
|
||||||
|
|
||||||
def strategy_append_reuse(obj: Obj) -> bytes:
|
def strategy_append_reuse(obj: Obj) -> bytes:
|
||||||
def do_append(o: Obj, res: bytearray) -> None:
|
def do_append(o: Obj, res: bytearray) -> None:
|
||||||
|
x: Any
|
||||||
for v in o.__dict__.values():
|
for v in o.__dict__.values():
|
||||||
for x in v if isinstance(v, list) else [v]:
|
for x in v if isinstance(v, list) else [v]:
|
||||||
if isinstance(x, Obj):
|
if isinstance(x, Obj):
|
||||||
|
@ -70,15 +74,18 @@ def strategy_append_reuse(obj: Obj) -> bytes:
|
||||||
|
|
||||||
|
|
||||||
def strategy_join(obj: Obj) -> bytes:
|
def strategy_join(obj: Obj) -> bytes:
|
||||||
return b"".join(
|
def iterator() -> Iterator[bytes]:
|
||||||
strategy_join(x) if isinstance(x, Obj) else serialize_builtin(x)
|
x: Any
|
||||||
for v in obj.__dict__.values()
|
for v in obj.__dict__.values():
|
||||||
for x in (v if isinstance(v, list) else [v])
|
for x in v if isinstance(v, list) else [v]:
|
||||||
)
|
yield strategy_join(x) if isinstance(x, Obj) else serialize_builtin(x)
|
||||||
|
|
||||||
|
return b"".join(iterator())
|
||||||
|
|
||||||
|
|
||||||
def strategy_join_flat(obj: Obj) -> bytes:
|
def strategy_join_flat(obj: Obj) -> bytes:
|
||||||
def flatten(o: Obj) -> Iterator[bytes]:
|
def flatten(o: Obj) -> Iterator[bytes]:
|
||||||
|
x: Any
|
||||||
for v in o.__dict__.values():
|
for v in o.__dict__.values():
|
||||||
for x in v if isinstance(v, list) else [v]:
|
for x in v if isinstance(v, list) else [v]:
|
||||||
if isinstance(x, Obj):
|
if isinstance(x, Obj):
|
||||||
|
@ -91,6 +98,7 @@ def strategy_join_flat(obj: Obj) -> bytes:
|
||||||
|
|
||||||
def strategy_write(obj: Obj) -> bytes:
|
def strategy_write(obj: Obj) -> bytes:
|
||||||
def do_write(o: Obj, buffer: io.BytesIO) -> None:
|
def do_write(o: Obj, buffer: io.BytesIO) -> None:
|
||||||
|
x: Any
|
||||||
for v in o.__dict__.values():
|
for v in o.__dict__.values():
|
||||||
for x in v if isinstance(v, list) else [v]:
|
for x in v if isinstance(v, list) else [v]:
|
||||||
if isinstance(x, Obj):
|
if isinstance(x, Obj):
|
||||||
|
|
|
@ -6,7 +6,7 @@ DATA = 42
|
||||||
|
|
||||||
|
|
||||||
def overhead(n: int) -> None:
|
def overhead(n: int) -> None:
|
||||||
n
|
n = n
|
||||||
|
|
||||||
|
|
||||||
def strategy_bool(n: int) -> bool:
|
def strategy_bool(n: int) -> bool:
|
||||||
|
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from setuptools import build_meta as _orig
|
from setuptools import build_meta as _orig
|
||||||
from setuptools.build_meta import * # noqa: F403
|
from setuptools.build_meta import * # noqa: F403 # pyright: ignore [reportWildcardImportFromLibrary]
|
||||||
|
|
||||||
|
|
||||||
def gen_types_if_needed() -> None:
|
def gen_types_if_needed() -> None:
|
||||||
|
|
|
@ -223,6 +223,7 @@ async def check_password(
|
||||||
|
|
||||||
if not two_factor_auth.check_p_and_g(algo.p, algo.g):
|
if not two_factor_auth.check_p_and_g(algo.p, algo.g):
|
||||||
token = await get_password_information(self)
|
token = await get_password_information(self)
|
||||||
|
algo = token._password.current_algo
|
||||||
if not isinstance(
|
if not isinstance(
|
||||||
algo,
|
algo,
|
||||||
types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow,
|
types.PasswordKdfAlgoSha256Sha256Pbkdf2HmacshA512Iter100000Sha256ModPow,
|
||||||
|
|
|
@ -2059,5 +2059,4 @@ class Client:
|
||||||
exc: Optional[BaseException],
|
exc: Optional[BaseException],
|
||||||
tb: Optional[TracebackType],
|
tb: Optional[TracebackType],
|
||||||
) -> None:
|
) -> None:
|
||||||
exc_type, exc, tb
|
|
||||||
await disconnect(self)
|
await disconnect(self)
|
||||||
|
|
|
@ -444,6 +444,7 @@ class FileBytesList(AsyncList[bytes]):
|
||||||
|
|
||||||
if result.bytes:
|
if result.bytes:
|
||||||
self._offset += MAX_CHUNK_SIZE
|
self._offset += MAX_CHUNK_SIZE
|
||||||
|
assert isinstance(result.bytes, bytes)
|
||||||
self._buffer.append(result.bytes)
|
self._buffer.append(result.bytes)
|
||||||
|
|
||||||
self._done = len(result.bytes) < MAX_CHUNK_SIZE
|
self._done = len(result.bytes) < MAX_CHUNK_SIZE
|
||||||
|
|
|
@ -39,11 +39,13 @@ async def send_message(
|
||||||
noforwards=not text.can_forward,
|
noforwards=not text.can_forward,
|
||||||
update_stickersets_order=False,
|
update_stickersets_order=False,
|
||||||
peer=peer,
|
peer=peer,
|
||||||
reply_to=types.InputReplyToMessage(
|
reply_to=(
|
||||||
reply_to_msg_id=text.replied_message_id, top_msg_id=None
|
types.InputReplyToMessage(
|
||||||
)
|
reply_to_msg_id=text.replied_message_id, top_msg_id=None
|
||||||
if text.replied_message_id
|
)
|
||||||
else None,
|
if text.replied_message_id
|
||||||
|
else None
|
||||||
|
),
|
||||||
message=message,
|
message=message,
|
||||||
random_id=random_id,
|
random_id=random_id,
|
||||||
reply_markup=getattr(text._raw, "reply_markup", None),
|
reply_markup=getattr(text._raw, "reply_markup", None),
|
||||||
|
@ -63,11 +65,11 @@ async def send_message(
|
||||||
noforwards=False,
|
noforwards=False,
|
||||||
update_stickersets_order=False,
|
update_stickersets_order=False,
|
||||||
peer=peer,
|
peer=peer,
|
||||||
reply_to=types.InputReplyToMessage(
|
reply_to=(
|
||||||
reply_to_msg_id=reply_to, top_msg_id=None
|
types.InputReplyToMessage(reply_to_msg_id=reply_to, top_msg_id=None)
|
||||||
)
|
if reply_to
|
||||||
if reply_to
|
else None
|
||||||
else None,
|
),
|
||||||
message=message,
|
message=message,
|
||||||
random_id=random_id,
|
random_id=random_id,
|
||||||
reply_markup=btns.build_keyboard(buttons),
|
reply_markup=btns.build_keyboard(buttons),
|
||||||
|
@ -83,19 +85,23 @@ async def send_message(
|
||||||
{},
|
{},
|
||||||
out=result.out,
|
out=result.out,
|
||||||
id=result.id,
|
id=result.id,
|
||||||
from_id=types.PeerUser(user_id=self._session.user.id)
|
from_id=(
|
||||||
if self._session.user
|
types.PeerUser(user_id=self._session.user.id)
|
||||||
else None,
|
if self._session.user
|
||||||
|
else None
|
||||||
|
),
|
||||||
peer_id=packed._to_peer(),
|
peer_id=packed._to_peer(),
|
||||||
reply_to=types.MessageReplyHeader(
|
reply_to=(
|
||||||
reply_to_scheduled=False,
|
types.MessageReplyHeader(
|
||||||
forum_topic=False,
|
reply_to_scheduled=False,
|
||||||
reply_to_msg_id=reply_to,
|
forum_topic=False,
|
||||||
reply_to_peer_id=None,
|
reply_to_msg_id=reply_to,
|
||||||
reply_to_top_id=None,
|
reply_to_peer_id=None,
|
||||||
)
|
reply_to_top_id=None,
|
||||||
if reply_to
|
)
|
||||||
else None,
|
if reply_to
|
||||||
|
else None
|
||||||
|
),
|
||||||
date=result.date,
|
date=result.date,
|
||||||
message=message,
|
message=message,
|
||||||
media=result.media,
|
media=result.media,
|
||||||
|
@ -593,8 +599,8 @@ def build_message_map(
|
||||||
else:
|
else:
|
||||||
return MessageMap(client, peer, {}, {})
|
return MessageMap(client, peer, {}, {})
|
||||||
|
|
||||||
random_id_to_id = {}
|
random_id_to_id: Dict[int, int] = {}
|
||||||
id_to_message = {}
|
id_to_message: Dict[int, Message] = {}
|
||||||
for update in updates:
|
for update in updates:
|
||||||
if isinstance(update, types.UpdateMessageId):
|
if isinstance(update, types.UpdateMessageId):
|
||||||
random_id_to_id[update.random_id] = update.id
|
random_id_to_id[update.random_id] = update.id
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import (
|
||||||
Callable,
|
Callable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
Sequence,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
)
|
)
|
||||||
|
@ -103,8 +104,8 @@ def process_socket_updates(client: Client, all_updates: List[abcs.Updates]) -> N
|
||||||
def extend_update_queue(
|
def extend_update_queue(
|
||||||
client: Client,
|
client: Client,
|
||||||
updates: List[abcs.Update],
|
updates: List[abcs.Update],
|
||||||
users: List[abcs.User],
|
users: Sequence[abcs.User],
|
||||||
chats: List[abcs.Chat],
|
chats: Sequence[abcs.Chat],
|
||||||
) -> None:
|
) -> None:
|
||||||
chat_map = build_chat_map(client, users, chats)
|
chat_map = build_chat_map(client, users, chats)
|
||||||
|
|
||||||
|
|
|
@ -91,21 +91,22 @@ async def get_chats(self: Client, chats: Sequence[ChatLike]) -> List[Chat]:
|
||||||
input_channels.append(packed._to_input_channel())
|
input_channels.append(packed._to_input_channel())
|
||||||
|
|
||||||
if input_users:
|
if input_users:
|
||||||
users = await self(functions.users.get_users(id=input_users))
|
ret_users = await self(functions.users.get_users(id=input_users))
|
||||||
|
users = list(ret_users)
|
||||||
else:
|
else:
|
||||||
users = []
|
users = []
|
||||||
|
|
||||||
if input_chats:
|
if input_chats:
|
||||||
ret_chats = await self(functions.messages.get_chats(id=input_chats))
|
ret_chats = await self(functions.messages.get_chats(id=input_chats))
|
||||||
assert isinstance(ret_chats, types.messages.Chats)
|
assert isinstance(ret_chats, types.messages.Chats)
|
||||||
groups = ret_chats.chats
|
groups = list(ret_chats.chats)
|
||||||
else:
|
else:
|
||||||
groups = []
|
groups = []
|
||||||
|
|
||||||
if input_channels:
|
if input_channels:
|
||||||
ret_chats = await self(functions.channels.get_channels(id=input_channels))
|
ret_chats = await self(functions.channels.get_channels(id=input_channels))
|
||||||
assert isinstance(ret_chats, types.messages.Chats)
|
assert isinstance(ret_chats, types.messages.Chats)
|
||||||
channels = ret_chats.chats
|
channels = list(ret_chats.chats)
|
||||||
else:
|
else:
|
||||||
channels = []
|
channels = []
|
||||||
|
|
||||||
|
@ -133,7 +134,7 @@ async def resolve_to_packed(
|
||||||
ty = PackedType.USER
|
ty = PackedType.USER
|
||||||
elif isinstance(chat, Group):
|
elif isinstance(chat, Group):
|
||||||
ty = PackedType.MEGAGROUP if chat.is_megagroup else PackedType.CHAT
|
ty = PackedType.MEGAGROUP if chat.is_megagroup else PackedType.CHAT
|
||||||
elif isinstance(chat, Channel):
|
else:
|
||||||
ty = PackedType.BROADCAST
|
ty = PackedType.BROADCAST
|
||||||
|
|
||||||
return PackedChat(ty=ty, id=chat.id, access_hash=0)
|
return PackedChat(ty=ty, id=chat.id, access_hash=0)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Self, Union
|
from typing import TYPE_CHECKING, Dict, Optional, Self, Sequence, Union
|
||||||
|
|
||||||
from ...tl import abcs, types
|
from ...tl import abcs, types
|
||||||
from ..types import Chat, Message, expand_peer, peer_id
|
from ..types import Chat, Message, expand_peer, peer_id
|
||||||
|
@ -69,7 +69,7 @@ class MessageDeleted(Event):
|
||||||
The chat is only known when the deletion occurs in broadcast channels or supergroups.
|
The chat is only known when the deletion occurs in broadcast channels or supergroups.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, msg_ids: List[int], channel_id: Optional[int]) -> None:
|
def __init__(self, msg_ids: Sequence[int], channel_id: Optional[int]) -> None:
|
||||||
self._msg_ids = msg_ids
|
self._msg_ids = msg_ids
|
||||||
self._channel_id = channel_id
|
self._channel_id = channel_id
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ class MessageDeleted(Event):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def message_ids(self) -> List[int]:
|
def message_ids(self) -> Sequence[int]:
|
||||||
"""
|
"""
|
||||||
The message identifiers of the messages that were deleted.
|
The message identifiers of the messages that were deleted.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -40,7 +40,7 @@ class ButtonCallback(Event):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self) -> bytes:
|
def data(self) -> bytes:
|
||||||
assert self._raw.data is not None
|
assert isinstance(self._raw.data, bytes)
|
||||||
return self._raw.data
|
return self._raw.data
|
||||||
|
|
||||||
async def answer(
|
async def answer(
|
||||||
|
|
|
@ -1,7 +1,19 @@
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from html import escape
|
from html import escape
|
||||||
from html.parser import HTMLParser
|
from html.parser import HTMLParser
|
||||||
from typing import Any, Deque, Dict, Iterable, List, Optional, Tuple, Type, cast
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Deque,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from ...tl.abcs import MessageEntity
|
from ...tl.abcs import MessageEntity
|
||||||
from ...tl.types import (
|
from ...tl.types import (
|
||||||
|
@ -30,13 +42,11 @@ class HTMLToTelegramParser(HTMLParser):
|
||||||
self._open_tags: Deque[str] = deque()
|
self._open_tags: Deque[str] = deque()
|
||||||
self._open_tags_meta: Deque[Optional[str]] = deque()
|
self._open_tags_meta: Deque[Optional[str]] = deque()
|
||||||
|
|
||||||
def handle_starttag(
|
def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
|
||||||
self, tag: str, attrs_seq: List[Tuple[str, Optional[str]]]
|
|
||||||
) -> None:
|
|
||||||
self._open_tags.appendleft(tag)
|
self._open_tags.appendleft(tag)
|
||||||
self._open_tags_meta.appendleft(None)
|
self._open_tags_meta.appendleft(None)
|
||||||
|
|
||||||
attrs = dict(attrs_seq)
|
attributes = dict(attrs)
|
||||||
EntityType: Optional[Type[MessageEntity]] = None
|
EntityType: Optional[Type[MessageEntity]] = None
|
||||||
args = {}
|
args = {}
|
||||||
if tag == "strong" or tag == "b":
|
if tag == "strong" or tag == "b":
|
||||||
|
@ -61,7 +71,7 @@ class HTMLToTelegramParser(HTMLParser):
|
||||||
# inside <pre> tags
|
# inside <pre> tags
|
||||||
pre = self._building_entities["pre"]
|
pre = self._building_entities["pre"]
|
||||||
assert isinstance(pre, MessageEntityPre)
|
assert isinstance(pre, MessageEntityPre)
|
||||||
if cls := attrs.get("class"):
|
if cls := attributes.get("class"):
|
||||||
pre.language = cls[len("language-") :]
|
pre.language = cls[len("language-") :]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
EntityType = MessageEntityCode
|
EntityType = MessageEntityCode
|
||||||
|
@ -69,7 +79,7 @@ class HTMLToTelegramParser(HTMLParser):
|
||||||
EntityType = MessageEntityPre
|
EntityType = MessageEntityPre
|
||||||
args["language"] = ""
|
args["language"] = ""
|
||||||
elif tag == "a":
|
elif tag == "a":
|
||||||
url = attrs.get("href")
|
url = attributes.get("href")
|
||||||
if not url:
|
if not url:
|
||||||
return
|
return
|
||||||
if url.startswith("mailto:"):
|
if url.startswith("mailto:"):
|
||||||
|
@ -94,18 +104,18 @@ class HTMLToTelegramParser(HTMLParser):
|
||||||
**args,
|
**args,
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_data(self, text: str) -> None:
|
def handle_data(self, data: str) -> None:
|
||||||
previous_tag = self._open_tags[0] if len(self._open_tags) > 0 else ""
|
previous_tag = self._open_tags[0] if len(self._open_tags) > 0 else ""
|
||||||
if previous_tag == "a":
|
if previous_tag == "a":
|
||||||
url = self._open_tags_meta[0]
|
url = self._open_tags_meta[0]
|
||||||
if url:
|
if url:
|
||||||
text = url
|
data = url
|
||||||
|
|
||||||
for entity in self._building_entities.values():
|
for entity in self._building_entities.values():
|
||||||
assert hasattr(entity, "length")
|
assert hasattr(entity, "length")
|
||||||
entity.length += len(text)
|
setattr(entity, "length", getattr(entity, "length", 0) + len(data))
|
||||||
|
|
||||||
self.text += text
|
self.text += data
|
||||||
|
|
||||||
def handle_endtag(self, tag: str) -> None:
|
def handle_endtag(self, tag: str) -> None:
|
||||||
try:
|
try:
|
||||||
|
@ -114,7 +124,7 @@ class HTMLToTelegramParser(HTMLParser):
|
||||||
except IndexError:
|
except IndexError:
|
||||||
pass
|
pass
|
||||||
entity = self._building_entities.pop(tag, None)
|
entity = self._building_entities.pop(tag, None)
|
||||||
if entity and hasattr(entity, "length") and entity.length:
|
if entity and getattr(entity, "length", None):
|
||||||
self.entities.append(entity)
|
self.entities.append(entity)
|
||||||
|
|
||||||
|
|
||||||
|
@ -135,7 +145,9 @@ def parse(html: str) -> Tuple[str, List[MessageEntity]]:
|
||||||
return del_surrogate(text), parser.entities
|
return del_surrogate(text), parser.entities
|
||||||
|
|
||||||
|
|
||||||
ENTITY_TO_FORMATTER = {
|
ENTITY_TO_FORMATTER: Dict[
|
||||||
|
Type[MessageEntity], Union[Tuple[str, str], Callable[[Any, str], Tuple[str, str]]]
|
||||||
|
] = {
|
||||||
MessageEntityBold: ("<strong>", "</strong>"),
|
MessageEntityBold: ("<strong>", "</strong>"),
|
||||||
MessageEntityItalic: ("<em>", "</em>"),
|
MessageEntityItalic: ("<em>", "</em>"),
|
||||||
MessageEntityCode: ("<code>", "</code>"),
|
MessageEntityCode: ("<code>", "</code>"),
|
||||||
|
@ -173,18 +185,20 @@ def unparse(text: str, entities: Iterable[MessageEntity]) -> str:
|
||||||
|
|
||||||
text = add_surrogate(text)
|
text = add_surrogate(text)
|
||||||
insert_at: List[Tuple[int, str]] = []
|
insert_at: List[Tuple[int, str]] = []
|
||||||
for entity in entities:
|
for e in entities:
|
||||||
assert hasattr(entity, "offset") and hasattr(entity, "length")
|
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
|
||||||
s = entity.offset
|
assert isinstance(offset, int) and isinstance(length, int)
|
||||||
e = entity.offset + entity.length
|
|
||||||
delimiter = ENTITY_TO_FORMATTER.get(type(entity), None)
|
h = offset
|
||||||
|
t = offset + length
|
||||||
|
delimiter = ENTITY_TO_FORMATTER.get(type(e), None)
|
||||||
if delimiter:
|
if delimiter:
|
||||||
if callable(delimiter):
|
if callable(delimiter):
|
||||||
delim = delimiter(entity, text[s:e])
|
delim = delimiter(e, text[h:t])
|
||||||
else:
|
else:
|
||||||
delim = delimiter
|
delim = delimiter
|
||||||
insert_at.append((s, delim[0]))
|
insert_at.append((h, delim[0]))
|
||||||
insert_at.append((e, delim[1]))
|
insert_at.append((t, delim[1]))
|
||||||
|
|
||||||
insert_at.sort(key=lambda t: t[0])
|
insert_at.sort(key=lambda t: t[0])
|
||||||
next_escape_bound = len(text)
|
next_escape_bound = len(text)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import re
|
import re
|
||||||
from typing import Any, Iterator, List, Tuple
|
from typing import Any, Dict, Iterator, List, Tuple, Type
|
||||||
|
|
||||||
import markdown_it
|
import markdown_it
|
||||||
import markdown_it.token
|
import markdown_it.token
|
||||||
|
@ -19,7 +19,7 @@ from ...tl.types import (
|
||||||
from .strings import add_surrogate, del_surrogate, within_surrogate
|
from .strings import add_surrogate, del_surrogate, within_surrogate
|
||||||
|
|
||||||
MARKDOWN = markdown_it.MarkdownIt().enable("strikethrough")
|
MARKDOWN = markdown_it.MarkdownIt().enable("strikethrough")
|
||||||
DELIMITERS = {
|
DELIMITERS: Dict[Type[MessageEntity], Tuple[str, str]] = {
|
||||||
MessageEntityBlockquote: ("> ", ""),
|
MessageEntityBlockquote: ("> ", ""),
|
||||||
MessageEntityBold: ("**", "**"),
|
MessageEntityBold: ("**", "**"),
|
||||||
MessageEntityCode: ("`", "`"),
|
MessageEntityCode: ("`", "`"),
|
||||||
|
@ -81,7 +81,9 @@ def parse(message: str) -> Tuple[str, List[MessageEntity]]:
|
||||||
else:
|
else:
|
||||||
for entity in reversed(entities):
|
for entity in reversed(entities):
|
||||||
if isinstance(entity, ty):
|
if isinstance(entity, ty):
|
||||||
entity.length = len(message) - entity.offset
|
setattr(
|
||||||
|
entity, "length", len(message) - getattr(entity, "offset", 0)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
parsed = MARKDOWN.parse(add_surrogate(message.strip()))
|
parsed = MARKDOWN.parse(add_surrogate(message.strip()))
|
||||||
|
@ -156,24 +158,25 @@ def unparse(text: str, entities: List[MessageEntity]) -> str:
|
||||||
|
|
||||||
text = add_surrogate(text)
|
text = add_surrogate(text)
|
||||||
insert_at: List[Tuple[int, str]] = []
|
insert_at: List[Tuple[int, str]] = []
|
||||||
for entity in entities:
|
for e in entities:
|
||||||
assert hasattr(entity, "offset")
|
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
|
||||||
assert hasattr(entity, "length")
|
assert isinstance(offset, int) and isinstance(length, int)
|
||||||
s = entity.offset
|
|
||||||
e = entity.offset + entity.length
|
h = offset
|
||||||
delimiter = DELIMITERS.get(type(entity), None)
|
t = offset + length
|
||||||
|
delimiter = DELIMITERS.get(type(e), None)
|
||||||
if delimiter:
|
if delimiter:
|
||||||
insert_at.append((s, delimiter[0]))
|
insert_at.append((h, delimiter[0]))
|
||||||
insert_at.append((e, delimiter[1]))
|
insert_at.append((t, delimiter[1]))
|
||||||
elif isinstance(entity, MessageEntityPre):
|
elif isinstance(e, MessageEntityPre):
|
||||||
insert_at.append((s, f"```{entity.language}\n"))
|
insert_at.append((h, f"```{e.language}\n"))
|
||||||
insert_at.append((e, "```\n"))
|
insert_at.append((t, "```\n"))
|
||||||
elif isinstance(entity, MessageEntityTextUrl):
|
elif isinstance(e, MessageEntityTextUrl):
|
||||||
insert_at.append((s, "["))
|
insert_at.append((h, "["))
|
||||||
insert_at.append((e, f"]({entity.url})"))
|
insert_at.append((t, f"]({e.url})"))
|
||||||
elif isinstance(entity, MessageEntityMentionName):
|
elif isinstance(e, MessageEntityMentionName):
|
||||||
insert_at.append((s, "["))
|
insert_at.append((h, "["))
|
||||||
insert_at.append((e, f"](tg://user?id={entity.user_id})"))
|
insert_at.append((t, f"](tg://user?id={e.user_id})"))
|
||||||
|
|
||||||
insert_at.sort(key=lambda t: t[0])
|
insert_at.sort(key=lambda t: t[0])
|
||||||
while insert_at:
|
while insert_at:
|
||||||
|
|
|
@ -8,9 +8,11 @@ def add_surrogate(text: str) -> str:
|
||||||
return "".join(
|
return "".join(
|
||||||
# SMP -> Surrogate Pairs (Telegram offsets are calculated with these).
|
# SMP -> Surrogate Pairs (Telegram offsets are calculated with these).
|
||||||
# See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more.
|
# See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more.
|
||||||
"".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
|
(
|
||||||
if (0x10000 <= ord(x) <= 0x10FFFF)
|
"".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
|
||||||
else x
|
if (0x10000 <= ord(x) <= 0x10FFFF)
|
||||||
|
else x
|
||||||
|
)
|
||||||
for x in text
|
for x in text
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,32 +45,38 @@ def strip_text(text: str, entities: List[MessageEntity]) -> str:
|
||||||
if not entities:
|
if not entities:
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
|
assert all(isinstance(getattr(e, "offset"), int) for e in entities)
|
||||||
|
|
||||||
while text and text[-1].isspace():
|
while text and text[-1].isspace():
|
||||||
e = entities[-1]
|
e = entities[-1]
|
||||||
assert hasattr(e, "offset") and hasattr(e, "length")
|
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
|
||||||
if e.offset + e.length == len(text):
|
assert isinstance(offset, int) and isinstance(length, int)
|
||||||
if e.length == 1:
|
|
||||||
|
if offset + length == len(text):
|
||||||
|
if length == 1:
|
||||||
del entities[-1]
|
del entities[-1]
|
||||||
if not entities:
|
if not entities:
|
||||||
return text.strip()
|
return text.strip()
|
||||||
else:
|
else:
|
||||||
e.length -= 1
|
length -= 1
|
||||||
text = text[:-1]
|
text = text[:-1]
|
||||||
|
|
||||||
while text and text[0].isspace():
|
while text and text[0].isspace():
|
||||||
for i in reversed(range(len(entities))):
|
for i in reversed(range(len(entities))):
|
||||||
e = entities[i]
|
e = entities[i]
|
||||||
assert hasattr(e, "offset") and hasattr(e, "length")
|
offset, length = getattr(e, "offset", None), getattr(e, "length", None)
|
||||||
if e.offset != 0:
|
assert isinstance(offset, int) and isinstance(length, int)
|
||||||
e.offset -= 1
|
|
||||||
|
if offset != 0:
|
||||||
|
setattr(e, "offset", offset - 1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if e.length == 1:
|
if length == 1:
|
||||||
del entities[0]
|
del entities[0]
|
||||||
if not entities:
|
if not entities:
|
||||||
return text.lstrip()
|
return text.lstrip()
|
||||||
else:
|
else:
|
||||||
e.length -= 1
|
setattr(e, "length", length - 1)
|
||||||
|
|
||||||
text = text[1:]
|
text = text[1:]
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,7 @@ class Callback(InlineButton):
|
||||||
This data will be received by :class:`telethon.events.ButtonCallback` when the button is pressed.
|
This data will be received by :class:`telethon.events.ButtonCallback` when the button is pressed.
|
||||||
"""
|
"""
|
||||||
assert isinstance(self._raw, types.KeyboardButtonCallback)
|
assert isinstance(self._raw, types.KeyboardButtonCallback)
|
||||||
|
assert isinstance(self._raw.data, bytes)
|
||||||
return self._raw.data
|
return self._raw.data
|
||||||
|
|
||||||
@data.setter
|
@data.setter
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
import itertools
|
import itertools
|
||||||
import sys
|
import sys
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from ....session import PackedChat
|
from ....session import PackedChat
|
||||||
from ....tl import abcs, types
|
from ....tl import abcs, types
|
||||||
|
@ -19,13 +19,15 @@ ChatLike = Union[Chat, PackedChat, int, str]
|
||||||
|
|
||||||
|
|
||||||
def build_chat_map(
|
def build_chat_map(
|
||||||
client: Client, users: List[abcs.User], chats: List[abcs.Chat]
|
client: Client, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]
|
||||||
) -> Dict[int, Chat]:
|
) -> Dict[int, Chat]:
|
||||||
users_iter = (User._from_raw(u) for u in users)
|
users_iter = (User._from_raw(u) for u in users)
|
||||||
chats_iter = (
|
chats_iter = (
|
||||||
Channel._from_raw(c)
|
(
|
||||||
if isinstance(c, (types.Channel, types.ChannelForbidden)) and c.broadcast
|
Channel._from_raw(c)
|
||||||
else Group._from_raw(client, c)
|
if isinstance(c, (types.Channel, types.ChannelForbidden)) and c.broadcast
|
||||||
|
else Group._from_raw(client, c)
|
||||||
|
)
|
||||||
for c in chats
|
for c in chats
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,17 @@ import urllib.parse
|
||||||
from inspect import isawaitable
|
from inspect import isawaitable
|
||||||
from io import BufferedWriter
|
from io import BufferedWriter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Protocol, Self, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Coroutine,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Protocol,
|
||||||
|
Self,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from ...tl import abcs, types
|
from ...tl import abcs, types
|
||||||
from .meta import NoPublicConstructor
|
from .meta import NoPublicConstructor
|
||||||
|
@ -43,7 +53,7 @@ stripped_size_header = bytes.fromhex(
|
||||||
stripped_size_footer = bytes.fromhex("FFD9")
|
stripped_size_footer = bytes.fromhex("FFD9")
|
||||||
|
|
||||||
|
|
||||||
def expand_stripped_size(data: bytes) -> bytes:
|
def expand_stripped_size(data: bytes | bytearray | memoryview) -> bytes:
|
||||||
header = bytearray(stripped_size_header)
|
header = bytearray(stripped_size_header)
|
||||||
header[164] = data[1]
|
header[164] = data[1]
|
||||||
header[166] = data[2]
|
header[166] = data[2]
|
||||||
|
@ -87,13 +97,14 @@ class InFileLike(Protocol):
|
||||||
It's only used in function parameters.
|
It's only used in function parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def read(self, n: int) -> Union[bytes, Coroutine[Any, Any, bytes]]:
|
def read(self, n: int, /) -> Union[bytes, Coroutine[Any, Any, bytes]]:
|
||||||
"""
|
"""
|
||||||
Read from the file or buffer.
|
Read from the file or buffer.
|
||||||
|
|
||||||
:param n:
|
:param n:
|
||||||
Maximum amount of bytes that should be returned.
|
Maximum amount of bytes that should be returned.
|
||||||
"""
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class OutFileLike(Protocol):
|
class OutFileLike(Protocol):
|
||||||
|
@ -115,18 +126,21 @@ class OutFileLike(Protocol):
|
||||||
|
|
||||||
|
|
||||||
class OutWrapper:
|
class OutWrapper:
|
||||||
__slots__ = ("_fd", "_owned")
|
__slots__ = ("_fd", "_owned_fd")
|
||||||
|
|
||||||
|
_fd: Union[OutFileLike, BufferedWriter]
|
||||||
|
_owned_fd: Optional[BufferedWriter]
|
||||||
|
|
||||||
def __init__(self, file: Union[str, Path, OutFileLike]):
|
def __init__(self, file: Union[str, Path, OutFileLike]):
|
||||||
if isinstance(file, str):
|
if isinstance(file, str):
|
||||||
file = Path(file)
|
file = Path(file)
|
||||||
|
|
||||||
if isinstance(file, Path):
|
if isinstance(file, Path):
|
||||||
self._fd: Union[OutFileLike, BufferedWriter] = file.open("wb")
|
self._fd = file.open("wb")
|
||||||
self._owned = True
|
self._owned_fd = self._fd
|
||||||
else:
|
else:
|
||||||
self._fd = file
|
self._fd = file
|
||||||
self._owned = False
|
self._owned_fd = None
|
||||||
|
|
||||||
async def write(self, chunk: bytes) -> None:
|
async def write(self, chunk: bytes) -> None:
|
||||||
ret = self._fd.write(chunk)
|
ret = self._fd.write(chunk)
|
||||||
|
@ -134,9 +148,9 @@ class OutWrapper:
|
||||||
await ret
|
await ret
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
if self._owned:
|
if self._owned_fd is not None:
|
||||||
assert hasattr(self._fd, "close")
|
assert hasattr(self._owned_fd, "close")
|
||||||
self._fd.close()
|
self._owned_fd.close()
|
||||||
|
|
||||||
|
|
||||||
class File(metaclass=NoPublicConstructor):
|
class File(metaclass=NoPublicConstructor):
|
||||||
|
@ -150,7 +164,7 @@ class File(metaclass=NoPublicConstructor):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
attributes: List[abcs.DocumentAttribute],
|
attributes: Sequence[abcs.DocumentAttribute],
|
||||||
size: int,
|
size: int,
|
||||||
name: str,
|
name: str,
|
||||||
mime: str,
|
mime: str,
|
||||||
|
@ -158,7 +172,7 @@ class File(metaclass=NoPublicConstructor):
|
||||||
muted: bool,
|
muted: bool,
|
||||||
input_media: abcs.InputMedia,
|
input_media: abcs.InputMedia,
|
||||||
thumb: Optional[abcs.PhotoSize],
|
thumb: Optional[abcs.PhotoSize],
|
||||||
thumbs: Optional[List[abcs.PhotoSize]],
|
thumbs: Optional[Sequence[abcs.PhotoSize]],
|
||||||
raw: Optional[Union[abcs.MessageMedia, abcs.Photo, abcs.Document]],
|
raw: Optional[Union[abcs.MessageMedia, abcs.Photo, abcs.Document]],
|
||||||
client: Optional[Client],
|
client: Optional[Client],
|
||||||
):
|
):
|
||||||
|
@ -405,9 +419,9 @@ class File(metaclass=NoPublicConstructor):
|
||||||
id=self._input_media.id.id,
|
id=self._input_media.id.id,
|
||||||
access_hash=self._input_media.id.access_hash,
|
access_hash=self._input_media.id.access_hash,
|
||||||
file_reference=self._input_media.id.file_reference,
|
file_reference=self._input_media.id.file_reference,
|
||||||
thumb_size=self._thumb.type
|
thumb_size=(
|
||||||
if isinstance(self._thumb, thumb_types)
|
self._thumb.type if isinstance(self._thumb, thumb_types) else ""
|
||||||
else "",
|
),
|
||||||
)
|
)
|
||||||
elif isinstance(self._input_media, types.InputMediaPhoto):
|
elif isinstance(self._input_media, types.InputMediaPhoto):
|
||||||
assert isinstance(self._input_media.id, types.InputPhoto)
|
assert isinstance(self._input_media.id, types.InputPhoto)
|
||||||
|
|
|
@ -2,7 +2,17 @@ from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Self, Tuple, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Self,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from ...tl import abcs, types
|
from ...tl import abcs, types
|
||||||
from ..parsers import (
|
from ..parsers import (
|
||||||
|
@ -502,7 +512,7 @@ class Message(metaclass=NoPublicConstructor):
|
||||||
|
|
||||||
|
|
||||||
def build_msg_map(
|
def build_msg_map(
|
||||||
client: Client, messages: List[abcs.Message], chat_map: Dict[int, Chat]
|
client: Client, messages: Sequence[abcs.Message], chat_map: Dict[int, Chat]
|
||||||
) -> Dict[int, Message]:
|
) -> Dict[int, Message]:
|
||||||
return {
|
return {
|
||||||
msg.id: msg
|
msg.id: msg
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
"""
|
"""
|
||||||
Class definitions stolen from `trio`, with some modifications.
|
Class definitions stolen from `trio`, with some modifications.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from typing import Type, TypeVar
|
from typing import Any, Type, TypeVar
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
@ -28,7 +29,7 @@ class Final(abc.ABCMeta):
|
||||||
|
|
||||||
|
|
||||||
class NoPublicConstructor(Final):
|
class NoPublicConstructor(Final):
|
||||||
def __call__(cls) -> None:
|
def __call__(cls, *args: Any, **kwds: Any) -> Any:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"{cls.__module__}.{cls.__qualname__} has no public constructor"
|
f"{cls.__module__}.{cls.__qualname__} has no public constructor"
|
||||||
)
|
)
|
||||||
|
|
|
@ -100,12 +100,8 @@ class Participant(metaclass=NoPublicConstructor):
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
return self._raw.user_id
|
return self._raw.user_id
|
||||||
elif isinstance(
|
|
||||||
self._raw, (types.ChannelParticipantBanned, types.ChannelParticipantLeft)
|
|
||||||
):
|
|
||||||
return peer_id(self._raw.peer)
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("unexpected case")
|
return peer_id(self._raw.peer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user(self) -> Optional[User]:
|
def user(self) -> Optional[User]:
|
||||||
|
|
|
@ -1,67 +1,16 @@
|
||||||
import pyaes
|
|
||||||
|
|
||||||
|
|
||||||
def ige_encrypt(plaintext: bytes, key: bytes, iv: bytes) -> bytes:
|
|
||||||
assert len(plaintext) % 16 == 0
|
|
||||||
assert len(iv) == 32
|
|
||||||
|
|
||||||
aes = pyaes.AES(key)
|
|
||||||
iv1 = iv[:16]
|
|
||||||
iv2 = iv[16:]
|
|
||||||
|
|
||||||
ciphertext = bytearray()
|
|
||||||
|
|
||||||
for block_offset in range(0, len(plaintext), 16):
|
|
||||||
plaintext_block = plaintext[block_offset : block_offset + 16]
|
|
||||||
ciphertext_block = bytes(
|
|
||||||
a ^ b
|
|
||||||
for a, b in zip(
|
|
||||||
aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2
|
|
||||||
)
|
|
||||||
)
|
|
||||||
iv1 = ciphertext_block
|
|
||||||
iv2 = plaintext_block
|
|
||||||
|
|
||||||
ciphertext += ciphertext_block
|
|
||||||
|
|
||||||
return bytes(ciphertext)
|
|
||||||
|
|
||||||
|
|
||||||
def ige_decrypt(ciphertext: bytes, key: bytes, iv: bytes) -> bytes:
|
|
||||||
assert len(ciphertext) % 16 == 0
|
|
||||||
assert len(iv) == 32
|
|
||||||
|
|
||||||
aes = pyaes.AES(key)
|
|
||||||
iv1 = iv[:16]
|
|
||||||
iv2 = iv[16:]
|
|
||||||
|
|
||||||
plaintext = bytearray()
|
|
||||||
|
|
||||||
for block_offset in range(0, len(ciphertext), 16):
|
|
||||||
ciphertext_block = ciphertext[block_offset : block_offset + 16]
|
|
||||||
plaintext_block = bytes(
|
|
||||||
a ^ b
|
|
||||||
for a, b in zip(
|
|
||||||
aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
iv1 = ciphertext_block
|
|
||||||
iv2 = plaintext_block
|
|
||||||
|
|
||||||
plaintext += plaintext_block
|
|
||||||
|
|
||||||
return bytes(plaintext)
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cryptg
|
import cryptg
|
||||||
|
|
||||||
def ige_encrypt(plaintext: bytes, key: bytes, iv: bytes) -> bytes: # noqa: F811
|
def ige_encrypt(
|
||||||
|
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||||
|
) -> bytes: # noqa: F811
|
||||||
return cryptg.encrypt_ige(
|
return cryptg.encrypt_ige(
|
||||||
bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv
|
bytes(plaintext) if not isinstance(plaintext, bytes) else plaintext, key, iv
|
||||||
)
|
)
|
||||||
|
|
||||||
def ige_decrypt(ciphertext: bytes, key: bytes, iv: bytes) -> bytes: # noqa: F811
|
def ige_decrypt(
|
||||||
|
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||||
|
) -> bytes: # noqa: F811
|
||||||
return cryptg.decrypt_ige(
|
return cryptg.decrypt_ige(
|
||||||
bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext,
|
bytes(ciphertext) if not isinstance(ciphertext, bytes) else ciphertext,
|
||||||
key,
|
key,
|
||||||
|
@ -69,4 +18,58 @@ try:
|
||||||
)
|
)
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
import pyaes
|
||||||
|
|
||||||
|
def ige_encrypt(
|
||||||
|
plaintext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||||
|
) -> bytes:
|
||||||
|
assert len(plaintext) % 16 == 0
|
||||||
|
assert len(iv) == 32
|
||||||
|
|
||||||
|
aes = pyaes.AES(key)
|
||||||
|
iv1 = iv[:16]
|
||||||
|
iv2 = iv[16:]
|
||||||
|
|
||||||
|
ciphertext = bytearray()
|
||||||
|
|
||||||
|
for block_offset in range(0, len(plaintext), 16):
|
||||||
|
plaintext_block = plaintext[block_offset : block_offset + 16]
|
||||||
|
ciphertext_block = bytes(
|
||||||
|
a ^ b
|
||||||
|
for a, b in zip(
|
||||||
|
aes.encrypt([a ^ b for a, b in zip(plaintext_block, iv1)]), iv2
|
||||||
|
)
|
||||||
|
)
|
||||||
|
iv1 = ciphertext_block
|
||||||
|
iv2 = plaintext_block
|
||||||
|
|
||||||
|
ciphertext += ciphertext_block
|
||||||
|
|
||||||
|
return bytes(ciphertext)
|
||||||
|
|
||||||
|
def ige_decrypt(
|
||||||
|
ciphertext: bytes | bytearray | memoryview, key: bytes, iv: bytes
|
||||||
|
) -> bytes:
|
||||||
|
assert len(ciphertext) % 16 == 0
|
||||||
|
assert len(iv) == 32
|
||||||
|
|
||||||
|
aes = pyaes.AES(key)
|
||||||
|
iv1 = iv[:16]
|
||||||
|
iv2 = iv[16:]
|
||||||
|
|
||||||
|
plaintext = bytearray()
|
||||||
|
|
||||||
|
for block_offset in range(0, len(ciphertext), 16):
|
||||||
|
ciphertext_block = ciphertext[block_offset : block_offset + 16]
|
||||||
|
plaintext_block = bytes(
|
||||||
|
a ^ b
|
||||||
|
for a, b in zip(
|
||||||
|
aes.decrypt([a ^ b for a, b in zip(ciphertext_block, iv2)]), iv1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
iv1 = ciphertext_block
|
||||||
|
iv2 = plaintext_block
|
||||||
|
|
||||||
|
plaintext += plaintext_block
|
||||||
|
|
||||||
|
return bytes(plaintext)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
from collections import namedtuple
|
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from hashlib import sha1, sha256
|
from hashlib import sha1, sha256
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
from .aes import ige_decrypt, ige_encrypt
|
from .aes import ige_decrypt, ige_encrypt
|
||||||
from .auth_key import AuthKey
|
from .auth_key import AuthKey
|
||||||
|
@ -13,15 +13,19 @@ class Side(IntEnum):
|
||||||
SERVER = 8
|
SERVER = 8
|
||||||
|
|
||||||
|
|
||||||
CalcKey = namedtuple("CalcKey", ("key", "iv"))
|
class CalcKey(NamedTuple):
|
||||||
|
key: bytes
|
||||||
|
iv: bytes
|
||||||
|
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
|
# https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
|
||||||
def calc_key(auth_key: AuthKey, msg_key: bytes, side: Side) -> CalcKey:
|
def calc_key(
|
||||||
|
auth_key: AuthKey, msg_key: bytes | bytearray | memoryview, side: Side
|
||||||
|
) -> CalcKey:
|
||||||
x = int(side)
|
x = int(side)
|
||||||
|
|
||||||
# sha256_a = SHA256 (msg_key + substr (auth_key, x, 36))
|
# sha256_a = SHA256 (msg_key + substr (auth_key, x, 36))
|
||||||
sha256_a = sha256(msg_key + auth_key.data[x : x + 36]).digest()
|
sha256_a = sha256(bytes(msg_key) + auth_key.data[x : x + 36]).digest()
|
||||||
|
|
||||||
# sha256_b = SHA256 (substr (auth_key, 40+x, 36) + msg_key)
|
# sha256_b = SHA256 (substr (auth_key, 40+x, 36) + msg_key)
|
||||||
sha256_b = sha256(auth_key.data[x + 40 : x + 76] + msg_key).digest()
|
sha256_b = sha256(auth_key.data[x + 40 : x + 76] + msg_key).digest()
|
||||||
|
@ -66,7 +70,9 @@ def encrypt_data_v2(plaintext: bytes, auth_key: AuthKey) -> bytes:
|
||||||
return _do_encrypt_data_v2(plaintext, auth_key, random_padding)
|
return _do_encrypt_data_v2(plaintext, auth_key, random_padding)
|
||||||
|
|
||||||
|
|
||||||
def decrypt_data_v2(ciphertext: bytes, auth_key: AuthKey) -> bytes:
|
def decrypt_data_v2(
|
||||||
|
ciphertext: bytes | bytearray | memoryview, auth_key: AuthKey
|
||||||
|
) -> bytes:
|
||||||
side = Side.SERVER
|
side = Side.SERVER
|
||||||
x = int(side)
|
x = int(side)
|
||||||
|
|
||||||
|
|
|
@ -1,45 +1,58 @@
|
||||||
# Ported from https://github.com/Lonami/grammers/blob/d91dc82/lib/grammers-crypto/src/two_factor_auth.rs
|
# Ported from https://github.com/Lonami/grammers/blob/d91dc82/lib/grammers-crypto/src/two_factor_auth.rs
|
||||||
from collections import namedtuple
|
|
||||||
from hashlib import pbkdf2_hmac, sha256
|
from hashlib import pbkdf2_hmac, sha256
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
from .factorize import factorize
|
from .factorize import factorize
|
||||||
|
|
||||||
TwoFactorAuth = namedtuple("TwoFactorAuth", ("m1", "g_a"))
|
|
||||||
|
class TwoFactorAuth(NamedTuple):
|
||||||
|
m1: bytes
|
||||||
|
g_a: bytes
|
||||||
|
|
||||||
|
|
||||||
def pad_to_256(data: bytes) -> bytes:
|
def pad_to_256(data: bytes | bytearray | memoryview) -> bytes:
|
||||||
return bytes(256 - len(data)) + data
|
return bytes(256 - len(data)) + data
|
||||||
|
|
||||||
|
|
||||||
# H(data) := sha256(data)
|
# H(data) := sha256(data)
|
||||||
def h(*data: bytes) -> bytes:
|
def h(*data: bytes | bytearray | memoryview) -> bytes:
|
||||||
return sha256(b"".join(data)).digest()
|
return sha256(b"".join(data)).digest()
|
||||||
|
|
||||||
|
|
||||||
# SH(data, salt) := H(salt | data | salt)
|
# SH(data, salt) := H(salt | data | salt)
|
||||||
def sh(data: bytes, salt: bytes) -> bytes:
|
def sh(
|
||||||
|
data: bytes | bytearray | memoryview, salt: bytes | bytearray | memoryview
|
||||||
|
) -> bytes:
|
||||||
return h(salt, data, salt)
|
return h(salt, data, salt)
|
||||||
|
|
||||||
|
|
||||||
# PH1(password, salt1, salt2) := SH(SH(password, salt1), salt2)
|
# PH1(password, salt1, salt2) := SH(SH(password, salt1), salt2)
|
||||||
def ph1(password: bytes, salt1: bytes, salt2: bytes) -> bytes:
|
def ph1(
|
||||||
|
password: bytes | bytearray | memoryview,
|
||||||
|
salt1: bytes | bytearray | memoryview,
|
||||||
|
salt2: bytes | bytearray | memoryview,
|
||||||
|
) -> bytes:
|
||||||
return sh(sh(password, salt1), salt2)
|
return sh(sh(password, salt1), salt2)
|
||||||
|
|
||||||
|
|
||||||
# PH2(password, salt1, salt2) := SH(pbkdf2(sha512, PH1(password, salt1, salt2), salt1, 100000), salt2)
|
# PH2(password, salt1, salt2) := SH(pbkdf2(sha512, PH1(password, salt1, salt2), salt1, 100000), salt2)
|
||||||
def ph2(password: bytes, salt1: bytes, salt2: bytes) -> bytes:
|
def ph2(
|
||||||
|
password: bytes | bytearray | memoryview,
|
||||||
|
salt1: bytes | bytearray | memoryview,
|
||||||
|
salt2: bytes | bytearray | memoryview,
|
||||||
|
) -> bytes:
|
||||||
return sh(pbkdf2_hmac("sha512", ph1(password, salt1, salt2), salt1, 100000), salt2)
|
return sh(pbkdf2_hmac("sha512", ph1(password, salt1, salt2), salt1, 100000), salt2)
|
||||||
|
|
||||||
|
|
||||||
# https://core.telegram.org/api/srp
|
# https://core.telegram.org/api/srp
|
||||||
def calculate_2fa(
|
def calculate_2fa(
|
||||||
*,
|
*,
|
||||||
salt1: bytes,
|
salt1: bytes | bytearray | memoryview,
|
||||||
salt2: bytes,
|
salt2: bytes | bytearray | memoryview,
|
||||||
g: int,
|
g: int,
|
||||||
p: bytes,
|
p: bytes | bytearray | memoryview,
|
||||||
g_b: bytes,
|
g_b: bytes | bytearray | memoryview,
|
||||||
a: bytes,
|
a: bytes | bytearray | memoryview,
|
||||||
password: bytes,
|
password: bytes,
|
||||||
) -> TwoFactorAuth:
|
) -> TwoFactorAuth:
|
||||||
big_p = int.from_bytes(p)
|
big_p = int.from_bytes(p)
|
||||||
|
@ -100,16 +113,16 @@ def calculate_2fa(
|
||||||
return TwoFactorAuth(m1, g_a)
|
return TwoFactorAuth(m1, g_a)
|
||||||
|
|
||||||
|
|
||||||
def check_p_len(p: bytes) -> bool:
|
def check_p_len(p: bytes | bytearray | memoryview) -> bool:
|
||||||
return len(p) == 256
|
return len(p) == 256
|
||||||
|
|
||||||
|
|
||||||
def check_known_prime(p: bytes, g: int) -> bool:
|
def check_known_prime(p: bytes | bytearray | memoryview, g: int) -> bool:
|
||||||
good_prime = b"\xc7\x1c\xae\xb9\xc6\xb1\xc9\x04\x8elR/p\xf1?s\x98\r@#\x8e>!\xc1I4\xd07V=\x93\x0fH\x19\x8a\n\xa7\xc1@X\"\x94\x93\xd2%0\xf4\xdb\xfa3on\n\xc9%\x13\x95C\xae\xd4L\xce|7 \xfdQ\xf6\x94XpZ\xc6\x8c\xd4\xfekk\x13\xab\xdc\x97FQ)i2\x84T\xf1\x8f\xaf\x8cY_d$w\xfe\x96\xbb*\x94\x1d[\xcd\x1dJ\xc8\xccI\x88\x07\x08\xfa\x9b7\x8e<O:\x90`\xbe\xe6|\xf9\xa4\xa4\xa6\x95\x81\x10Q\x90~\x16'S\xb5k\x0fkA\r\xbat\xd8\xa8K*\x14\xb3\x14N\x0e\xf1(GT\xfd\x17\xed\x95\rYe\xb4\xb9\xddFX-\xb1\x17\x8d\x16\x9ck\xc4e\xb0\xd6\xff\x9c\xa3\x92\x8f\xef[\x9a\xe4\xe4\x18\xfc\x15\xe8>\xbe\xa0\xf8\x7f\xa9\xff^\xedp\x05\r\xed(I\xf4{\xf9Y\xd9V\x85\x0c\xe9)\x85\x1f\r\x81\x15\xf65\xb1\x05\xee.N\x15\xd0K$T\xbfoO\xad\xf04\xb1\x04\x03\x11\x9c\xd8\xe3\xb9/\xcc["
|
good_prime = b"\xc7\x1c\xae\xb9\xc6\xb1\xc9\x04\x8elR/p\xf1?s\x98\r@#\x8e>!\xc1I4\xd07V=\x93\x0fH\x19\x8a\n\xa7\xc1@X\"\x94\x93\xd2%0\xf4\xdb\xfa3on\n\xc9%\x13\x95C\xae\xd4L\xce|7 \xfdQ\xf6\x94XpZ\xc6\x8c\xd4\xfekk\x13\xab\xdc\x97FQ)i2\x84T\xf1\x8f\xaf\x8cY_d$w\xfe\x96\xbb*\x94\x1d[\xcd\x1dJ\xc8\xccI\x88\x07\x08\xfa\x9b7\x8e<O:\x90`\xbe\xe6|\xf9\xa4\xa4\xa6\x95\x81\x10Q\x90~\x16'S\xb5k\x0fkA\r\xbat\xd8\xa8K*\x14\xb3\x14N\x0e\xf1(GT\xfd\x17\xed\x95\rYe\xb4\xb9\xddFX-\xb1\x17\x8d\x16\x9ck\xc4e\xb0\xd6\xff\x9c\xa3\x92\x8f\xef[\x9a\xe4\xe4\x18\xfc\x15\xe8>\xbe\xa0\xf8\x7f\xa9\xff^\xedp\x05\r\xed(I\xf4{\xf9Y\xd9V\x85\x0c\xe9)\x85\x1f\r\x81\x15\xf65\xb1\x05\xee.N\x15\xd0K$T\xbfoO\xad\xf04\xb1\x04\x03\x11\x9c\xd8\xe3\xb9/\xcc["
|
||||||
return p == good_prime and g in (3, 4, 5, 7)
|
return p == good_prime and g in (3, 4, 5, 7)
|
||||||
|
|
||||||
|
|
||||||
def check_p_prime_and_subgroup(p: bytes, g: int) -> bool:
|
def check_p_prime_and_subgroup(p: bytes | bytearray | memoryview, g: int) -> bool:
|
||||||
if check_known_prime(p, g):
|
if check_known_prime(p, g):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -133,7 +146,7 @@ def check_p_prime_and_subgroup(p: bytes, g: int) -> bool:
|
||||||
return candidate and factorize((big_p - 1) // 2)[0] == 1
|
return candidate and factorize((big_p - 1) // 2)[0] == 1
|
||||||
|
|
||||||
|
|
||||||
def check_p_and_g(p: bytes, g: int) -> bool:
|
def check_p_and_g(p: bytes | bytearray | memoryview, g: int) -> bool:
|
||||||
if not check_p_len(p):
|
if not check_p_len(p):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
@ -164,6 +164,7 @@ def _do_step3(
|
||||||
)
|
)
|
||||||
|
|
||||||
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
|
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
|
||||||
|
assert isinstance(server_dh_params.encrypted_answer, bytes)
|
||||||
plain_text_answer = decrypt_ige(server_dh_params.encrypted_answer, key, iv)
|
plain_text_answer = decrypt_ige(server_dh_params.encrypted_answer, key, iv)
|
||||||
|
|
||||||
got_answer_hash = plain_text_answer[:20]
|
got_answer_hash = plain_text_answer[:20]
|
||||||
|
|
|
@ -260,7 +260,7 @@ class Encrypted(Mtp):
|
||||||
self._store_own_updates(result)
|
self._store_own_updates(result)
|
||||||
self._deserialization.append(RpcResult(msg_id, result))
|
self._deserialization.append(RpcResult(msg_id, result))
|
||||||
|
|
||||||
def _store_own_updates(self, body: bytes) -> None:
|
def _store_own_updates(self, body: bytes | bytearray | memoryview) -> None:
|
||||||
constructor_id = struct.unpack_from("I", body)[0]
|
constructor_id = struct.unpack_from("I", body)[0]
|
||||||
if constructor_id in UPDATE_IDS:
|
if constructor_id in UPDATE_IDS:
|
||||||
self._deserialization.append(Update(body))
|
self._deserialization.append(Update(body))
|
||||||
|
@ -332,7 +332,7 @@ class Encrypted(Mtp):
|
||||||
)
|
)
|
||||||
|
|
||||||
self._start_salt_time = (salts.now, self._adjusted_now())
|
self._start_salt_time = (salts.now, self._adjusted_now())
|
||||||
self._salts = salts.salts
|
self._salts = list(salts.salts)
|
||||||
self._salts.sort(key=lambda salt: -salt.valid_since)
|
self._salts.sort(key=lambda salt: -salt.valid_since)
|
||||||
|
|
||||||
def _handle_future_salt(self, message: Message) -> None:
|
def _handle_future_salt(self, message: Message) -> None:
|
||||||
|
@ -439,7 +439,9 @@ class Encrypted(Mtp):
|
||||||
msg_id, buffer = result
|
msg_id, buffer = result
|
||||||
return msg_id, encrypt_data_v2(buffer, self._auth_key)
|
return msg_id, encrypt_data_v2(buffer, self._auth_key)
|
||||||
|
|
||||||
def deserialize(self, payload: bytes) -> List[Deserialization]:
|
def deserialize(
|
||||||
|
self, payload: bytes | bytearray | memoryview
|
||||||
|
) -> List[Deserialization]:
|
||||||
check_message_buffer(payload)
|
check_message_buffer(payload)
|
||||||
|
|
||||||
plaintext = decrypt_data_v2(payload, self._auth_key)
|
plaintext = decrypt_data_v2(payload, self._auth_key)
|
||||||
|
|
|
@ -31,7 +31,9 @@ class Plain(Mtp):
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
return MsgId(0), result
|
return MsgId(0), result
|
||||||
|
|
||||||
def deserialize(self, payload: bytes) -> List[Deserialization]:
|
def deserialize(
|
||||||
|
self, payload: bytes | bytearray | memoryview
|
||||||
|
) -> List[Deserialization]:
|
||||||
check_message_buffer(payload)
|
check_message_buffer(payload)
|
||||||
|
|
||||||
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
|
auth_key_id, msg_id, length = struct.unpack_from("<qqi", payload)
|
||||||
|
|
|
@ -15,7 +15,7 @@ class Update:
|
||||||
|
|
||||||
__slots__ = ("body",)
|
__slots__ = ("body",)
|
||||||
|
|
||||||
def __init__(self, body: bytes):
|
def __init__(self, body: bytes | bytearray | memoryview):
|
||||||
self.body = body
|
self.body = body
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class RpcResult:
|
||||||
|
|
||||||
__slots__ = ("msg_id", "body")
|
__slots__ = ("msg_id", "body")
|
||||||
|
|
||||||
def __init__(self, msg_id: MsgId, body: bytes):
|
def __init__(self, msg_id: MsgId, body: bytes | bytearray | memoryview):
|
||||||
self.msg_id = msg_id
|
self.msg_id = msg_id
|
||||||
self.body = body
|
self.body = body
|
||||||
|
|
||||||
|
@ -201,7 +201,9 @@ class Mtp(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def deserialize(self, payload: bytes) -> List[Deserialization]:
|
def deserialize(
|
||||||
|
self, payload: bytes | bytearray | memoryview
|
||||||
|
) -> List[Deserialization]:
|
||||||
"""
|
"""
|
||||||
Deserialize incoming buffer payload.
|
Deserialize incoming buffer payload.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -12,7 +12,7 @@ class Transport(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def unpack(self, input: bytes, output: bytearray) -> int:
|
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ class Abridged(Transport):
|
||||||
write(struct.pack("<i", 0x7F | (length << 8)))
|
write(struct.pack("<i", 0x7F | (length << 8)))
|
||||||
write(input)
|
write(input)
|
||||||
|
|
||||||
def unpack(self, input: bytes, output: bytearray) -> int:
|
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
|
||||||
if not input:
|
if not input:
|
||||||
raise MissingBytes(expected=1, got=0)
|
raise MissingBytes(expected=1, got=0)
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ class Full(Transport):
|
||||||
write(struct.pack("<I", crc32(tmp)))
|
write(struct.pack("<I", crc32(tmp)))
|
||||||
self._send_seq += 1
|
self._send_seq += 1
|
||||||
|
|
||||||
def unpack(self, input: bytes, output: bytearray) -> int:
|
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
|
||||||
if len(input) < 4:
|
if len(input) < 4:
|
||||||
raise MissingBytes(expected=4, got=len(input))
|
raise MissingBytes(expected=4, got=len(input))
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ class Intermediate(Transport):
|
||||||
write(struct.pack("<i", len(input)))
|
write(struct.pack("<i", len(input)))
|
||||||
write(input)
|
write(input)
|
||||||
|
|
||||||
def unpack(self, input: bytes, output: bytearray) -> int:
|
def unpack(self, input: bytes | bytearray | memoryview, output: bytearray) -> int:
|
||||||
if len(input) < 4:
|
if len(input) < 4:
|
||||||
raise MissingBytes(expected=4, got=len(input))
|
raise MissingBytes(expected=4, got=len(input))
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ CONTAINER_MAX_LENGTH = 100
|
||||||
MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
|
MESSAGE_SIZE_OVERHEAD = 8 + 4 + 4 # msg_id, seq_no, bytes
|
||||||
|
|
||||||
|
|
||||||
def check_message_buffer(message: bytes) -> None:
|
def check_message_buffer(message: bytes | bytearray | memoryview) -> None:
|
||||||
if len(message) < 20:
|
if len(message) < 20:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"server payload is too small to be a valid message: {message.hex()}"
|
f"server payload is too small to be a valid message: {message.hex()}"
|
||||||
|
|
|
@ -70,6 +70,7 @@ class AsyncReader(Protocol):
|
||||||
:param n:
|
:param n:
|
||||||
Amount of bytes to read at most.
|
Amount of bytes to read at most.
|
||||||
"""
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class AsyncWriter(Protocol):
|
class AsyncWriter(Protocol):
|
||||||
|
@ -77,7 +78,7 @@ class AsyncWriter(Protocol):
|
||||||
A :class:`asyncio.StreamWriter`-like class.
|
A :class:`asyncio.StreamWriter`-like class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def write(self, data: bytes) -> None:
|
def write(self, data: bytes | bytearray | memoryview) -> None:
|
||||||
"""
|
"""
|
||||||
Must behave like :meth:`asyncio.StreamWriter.write`.
|
Must behave like :meth:`asyncio.StreamWriter.write`.
|
||||||
|
|
||||||
|
@ -127,7 +128,7 @@ class Connector(Protocol):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def __call__(self, ip: str, port: int) -> Tuple[AsyncReader, AsyncWriter]:
|
async def __call__(self, ip: str, port: int) -> Tuple[AsyncReader, AsyncWriter]:
|
||||||
pass
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class RequestState(ABC):
|
class RequestState(ABC):
|
||||||
|
@ -340,12 +341,12 @@ class Sender:
|
||||||
self._process_result(result)
|
self._process_result(result)
|
||||||
elif isinstance(result, RpcError):
|
elif isinstance(result, RpcError):
|
||||||
self._process_error(result)
|
self._process_error(result)
|
||||||
elif isinstance(result, BadMessage):
|
|
||||||
self._process_bad_message(result)
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("unexpected case")
|
self._process_bad_message(result)
|
||||||
|
|
||||||
def _process_update(self, updates: List[Updates], update: bytes) -> None:
|
def _process_update(
|
||||||
|
self, updates: List[Updates], update: bytes | bytearray | memoryview
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
updates.append(Updates.from_bytes(update))
|
updates.append(Updates.from_bytes(update))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Any, Dict, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from ...tl import abcs, types
|
from ...tl import abcs, types
|
||||||
from .packed import PackedChat, PackedType
|
from .packed import PackedChat, PackedType
|
||||||
|
@ -131,7 +131,7 @@ class ChatHashCache:
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("unexpected case")
|
raise RuntimeError("unexpected case")
|
||||||
|
|
||||||
def extend(self, users: List[abcs.User], chats: List[abcs.Chat]) -> bool:
|
def extend(self, users: Sequence[abcs.User], chats: Sequence[abcs.Chat]) -> bool:
|
||||||
# See https://core.telegram.org/api/min for "issues" with "min constructors".
|
# See https://core.telegram.org/api/min for "issues" with "min constructors".
|
||||||
success = True
|
success = True
|
||||||
|
|
||||||
|
@ -299,6 +299,7 @@ class ChatHashCache:
|
||||||
success &= self._has_notify_peer(peer)
|
success &= self._has_notify_peer(peer)
|
||||||
|
|
||||||
for field in ("users",):
|
for field in ("users",):
|
||||||
|
user: Any
|
||||||
users = getattr(message.action, field, None)
|
users = getattr(message.action, field, None)
|
||||||
if isinstance(users, list):
|
if isinstance(users, list):
|
||||||
for user in users:
|
for user in users:
|
||||||
|
|
|
@ -4,11 +4,35 @@ from typing import List, Literal, Union
|
||||||
|
|
||||||
from ...tl import abcs
|
from ...tl import abcs
|
||||||
|
|
||||||
|
NO_DATE = 0 # used on adapted messages.affected* from lower layers
|
||||||
|
NO_SEQ = 0
|
||||||
|
|
||||||
|
NO_PTS = 0
|
||||||
|
|
||||||
|
# https://core.telegram.org/method/updates.getChannelDifference
|
||||||
|
BOT_CHANNEL_DIFF_LIMIT = 100000
|
||||||
|
USER_CHANNEL_DIFF_LIMIT = 100
|
||||||
|
|
||||||
|
POSSIBLE_GAP_TIMEOUT = 0.5
|
||||||
|
|
||||||
|
# https://core.telegram.org/api/updates
|
||||||
|
NO_UPDATES_TIMEOUT = 15 * 60
|
||||||
|
|
||||||
|
ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT"
|
||||||
|
ENTRY_SECRET: Literal["SECRET"] = "SECRET"
|
||||||
|
Entry = Union[Literal["ACCOUNT", "SECRET"], int]
|
||||||
|
|
||||||
|
# Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET.
|
||||||
|
# We don't define a name for this as libraries shouldn't do that though.
|
||||||
|
LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2
|
||||||
|
|
||||||
|
|
||||||
class PtsInfo:
|
class PtsInfo:
|
||||||
__slots__ = ("pts", "pts_count", "entry")
|
__slots__ = ("pts", "pts_count", "entry")
|
||||||
|
|
||||||
def __init__(self, entry: "Entry", pts: int, pts_count: int) -> None:
|
entry: Entry # pyright needs this or it infers int | str
|
||||||
|
|
||||||
|
def __init__(self, entry: Entry, pts: int, pts_count: int) -> None:
|
||||||
self.pts = pts
|
self.pts = pts
|
||||||
self.pts_count = pts_count
|
self.pts_count = pts_count
|
||||||
self.entry = entry
|
self.entry = entry
|
||||||
|
@ -59,26 +83,3 @@ class PrematureEndReason(Enum):
|
||||||
class Gap(ValueError):
|
class Gap(ValueError):
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "Gap()"
|
return "Gap()"
|
||||||
|
|
||||||
|
|
||||||
NO_DATE = 0 # used on adapted messages.affected* from lower layers
|
|
||||||
NO_SEQ = 0
|
|
||||||
|
|
||||||
NO_PTS = 0
|
|
||||||
|
|
||||||
# https://core.telegram.org/method/updates.getChannelDifference
|
|
||||||
BOT_CHANNEL_DIFF_LIMIT = 100000
|
|
||||||
USER_CHANNEL_DIFF_LIMIT = 100
|
|
||||||
|
|
||||||
POSSIBLE_GAP_TIMEOUT = 0.5
|
|
||||||
|
|
||||||
# https://core.telegram.org/api/updates
|
|
||||||
NO_UPDATES_TIMEOUT = 15 * 60
|
|
||||||
|
|
||||||
ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT"
|
|
||||||
ENTRY_SECRET: Literal["SECRET"] = "SECRET"
|
|
||||||
Entry = Union[Literal["ACCOUNT", "SECRET"], int]
|
|
||||||
|
|
||||||
# Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET.
|
|
||||||
# We don't define a name for this as libraries shouldn't do that though.
|
|
||||||
LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Set, Tuple
|
from typing import Dict, List, Optional, Sequence, Set, Tuple
|
||||||
|
|
||||||
from ...tl import Request, abcs, functions, types
|
from ...tl import Request, abcs, functions, types
|
||||||
from ..chat import ChatHashCache
|
from ..chat import ChatHashCache
|
||||||
|
@ -168,6 +168,7 @@ class MessageBox:
|
||||||
if not entries:
|
if not entries:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
entry: Entry = ENTRY_ACCOUNT # for pyright to know it's not unbound
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
if entry not in self.map:
|
if entry not in self.map:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
@ -258,7 +259,7 @@ class MessageBox:
|
||||||
self,
|
self,
|
||||||
updates: abcs.Updates,
|
updates: abcs.Updates,
|
||||||
chat_hashes: ChatHashCache,
|
chat_hashes: ChatHashCache,
|
||||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
|
||||||
result: List[abcs.Update] = []
|
result: List[abcs.Update] = []
|
||||||
combined = adapt(updates, chat_hashes)
|
combined = adapt(updates, chat_hashes)
|
||||||
|
|
||||||
|
@ -289,7 +290,7 @@ class MessageBox:
|
||||||
sorted_updates = list(sorted(combined.updates, key=update_sort_key))
|
sorted_updates = list(sorted(combined.updates, key=update_sort_key))
|
||||||
|
|
||||||
any_pts_applied = False
|
any_pts_applied = False
|
||||||
reset_deadlines_for = set()
|
reset_deadlines_for: Set[Entry] = set()
|
||||||
for update in sorted_updates:
|
for update in sorted_updates:
|
||||||
entry, applied = self.apply_pts_info(update)
|
entry, applied = self.apply_pts_info(update)
|
||||||
if entry is not None:
|
if entry is not None:
|
||||||
|
@ -420,9 +421,11 @@ class MessageBox:
|
||||||
pts_limit=None,
|
pts_limit=None,
|
||||||
pts_total_limit=None,
|
pts_total_limit=None,
|
||||||
date=int(self.date.timestamp()),
|
date=int(self.date.timestamp()),
|
||||||
qts=self.map[ENTRY_SECRET].pts
|
qts=(
|
||||||
if ENTRY_SECRET in self.map
|
self.map[ENTRY_SECRET].pts
|
||||||
else NO_SEQ,
|
if ENTRY_SECRET in self.map
|
||||||
|
else NO_SEQ
|
||||||
|
),
|
||||||
qts_limit=None,
|
qts_limit=None,
|
||||||
)
|
)
|
||||||
if __debug__:
|
if __debug__:
|
||||||
|
@ -435,12 +438,12 @@ class MessageBox:
|
||||||
self,
|
self,
|
||||||
diff: abcs.updates.Difference,
|
diff: abcs.updates.Difference,
|
||||||
chat_hashes: ChatHashCache,
|
chat_hashes: ChatHashCache,
|
||||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
self._trace("applying account difference: %s", diff)
|
self._trace("applying account difference: %s", diff)
|
||||||
|
|
||||||
finish: bool
|
finish: bool
|
||||||
result: Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]
|
result: Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]
|
||||||
if isinstance(diff, types.updates.DifferenceEmpty):
|
if isinstance(diff, types.updates.DifferenceEmpty):
|
||||||
finish = True
|
finish = True
|
||||||
self.date = datetime.datetime.fromtimestamp(
|
self.date = datetime.datetime.fromtimestamp(
|
||||||
|
@ -493,7 +496,7 @@ class MessageBox:
|
||||||
self,
|
self,
|
||||||
diff: types.updates.Difference,
|
diff: types.updates.Difference,
|
||||||
chat_hashes: ChatHashCache,
|
chat_hashes: ChatHashCache,
|
||||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
|
||||||
state = diff.state
|
state = diff.state
|
||||||
assert isinstance(state, types.updates.State)
|
assert isinstance(state, types.updates.State)
|
||||||
self.map[ENTRY_ACCOUNT].pts = state.pts
|
self.map[ENTRY_ACCOUNT].pts = state.pts
|
||||||
|
@ -556,9 +559,11 @@ class MessageBox:
|
||||||
channel=channel,
|
channel=channel,
|
||||||
filter=types.ChannelMessagesFilterEmpty(),
|
filter=types.ChannelMessagesFilterEmpty(),
|
||||||
pts=state.pts,
|
pts=state.pts,
|
||||||
limit=BOT_CHANNEL_DIFF_LIMIT
|
limit=(
|
||||||
if chat_hashes.is_self_bot
|
BOT_CHANNEL_DIFF_LIMIT
|
||||||
else USER_CHANNEL_DIFF_LIMIT,
|
if chat_hashes.is_self_bot
|
||||||
|
else USER_CHANNEL_DIFF_LIMIT
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if __debug__:
|
if __debug__:
|
||||||
self._trace("requesting channel difference: %s", gd)
|
self._trace("requesting channel difference: %s", gd)
|
||||||
|
@ -577,7 +582,7 @@ class MessageBox:
|
||||||
channel_id: int,
|
channel_id: int,
|
||||||
diff: abcs.updates.ChannelDifference,
|
diff: abcs.updates.ChannelDifference,
|
||||||
chat_hashes: ChatHashCache,
|
chat_hashes: ChatHashCache,
|
||||||
) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]:
|
) -> Tuple[List[abcs.Update], Sequence[abcs.User], Sequence[abcs.Chat]]:
|
||||||
entry: Entry = channel_id
|
entry: Entry = channel_id
|
||||||
if __debug__:
|
if __debug__:
|
||||||
self._trace("applying channel=%r difference: %s", entry, diff)
|
self._trace("applying channel=%r difference: %s", entry, diff)
|
||||||
|
|
|
@ -4,7 +4,7 @@ from .memory import MemorySession
|
||||||
from .storage import Storage
|
from .storage import Storage
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .sqlite import SqliteSession
|
from .sqlite import SqliteSession # pyright: ignore [reportAssignmentType]
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
import_err = e
|
import_err = e
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ def _bootstrap_get_ty(constructor_id: int) -> Optional[Type["Serializable"]]:
|
||||||
class Reader:
|
class Reader:
|
||||||
__slots__ = ("_view", "_pos", "_len")
|
__slots__ = ("_view", "_pos", "_len")
|
||||||
|
|
||||||
def __init__(self, buffer: bytes | memoryview) -> None:
|
def __init__(self, buffer: bytes | bytearray | memoryview) -> None:
|
||||||
self._view = (
|
self._view = (
|
||||||
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
|
memoryview(buffer) if not isinstance(buffer, memoryview) else buffer
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,10 +9,10 @@ class HasSlots(Protocol):
|
||||||
__slots__: Tuple[str, ...]
|
__slots__: Tuple[str, ...]
|
||||||
|
|
||||||
|
|
||||||
def obj_repr(obj: HasSlots) -> str:
|
def obj_repr(self: HasSlots) -> str:
|
||||||
fields = ((attr, getattr(obj, attr)) for attr in obj.__slots__)
|
fields = ((attr, getattr(self, attr)) for attr in self.__slots__)
|
||||||
params = ", ".join(f"{name}={field!r}" for name, field in fields)
|
params = ", ".join(f"{name}={field!r}" for name, field in fields)
|
||||||
return f"{obj.__class__.__name__}({params})"
|
return f"{self.__class__.__name__}({params})"
|
||||||
|
|
||||||
|
|
||||||
class Serializable(abc.ABC):
|
class Serializable(abc.ABC):
|
||||||
|
@ -36,7 +36,7 @@ class Serializable(abc.ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes(cls, blob: bytes) -> Self:
|
def from_bytes(cls, blob: bytes | bytearray | memoryview) -> Self:
|
||||||
return Reader(blob).read_serializable(cls)
|
return Reader(blob).read_serializable(cls)
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
def __bytes__(self) -> bytes:
|
||||||
|
@ -54,7 +54,7 @@ class Serializable(abc.ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def serialize_bytes_to(buffer: bytearray, data: bytes) -> None:
|
def serialize_bytes_to(buffer: bytearray, data: bytes | bytearray | memoryview) -> None:
|
||||||
length = len(data)
|
length = len(data)
|
||||||
if length < 0xFE:
|
if length < 0xFE:
|
||||||
buffer += struct.pack("<B", length)
|
buffer += struct.pack("<B", length)
|
||||||
|
|
|
@ -65,7 +65,7 @@ def test_key_from_nonce() -> None:
|
||||||
server_nonce = int.from_bytes(bytes(range(16)))
|
server_nonce = int.from_bytes(bytes(range(16)))
|
||||||
new_nonce = int.from_bytes(bytes(range(32)))
|
new_nonce = int.from_bytes(bytes(range(32)))
|
||||||
|
|
||||||
(key, iv) = generate_key_data_from_nonce(server_nonce, new_nonce)
|
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
|
||||||
assert (
|
assert (
|
||||||
key
|
key
|
||||||
== b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6'
|
== b'\x07X\xf1S;a]$\xf6\xe8\xa9Jo\xcb\xee\nU\xea\xab"\x17\xd7)\\\xa9!=\x1a-}\x16\xa6'
|
||||||
|
|
|
@ -7,7 +7,7 @@ from telethon._impl.mtproto import Abridged
|
||||||
class Output(bytearray):
|
class Output(bytearray):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
def __call__(self, data: bytes) -> None:
|
def __call__(self, data: bytes | bytearray | memoryview) -> None:
|
||||||
self += data
|
self += data
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ from telethon._impl.mtproto import Full
|
||||||
class Output(bytearray):
|
class Output(bytearray):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
def __call__(self, data: bytes) -> None:
|
def __call__(self, data: bytes | bytearray | memoryview) -> None:
|
||||||
self += data
|
self += data
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ def setup_unpack(n: int) -> Tuple[bytes, Full, bytes, bytearray]:
|
||||||
transport, expected_output, input = setup_pack(n)
|
transport, expected_output, input = setup_pack(n)
|
||||||
transport.pack(expected_output, input)
|
transport.pack(expected_output, input)
|
||||||
|
|
||||||
return expected_output, Full(), input, bytearray()
|
return expected_output, Full(), bytes(input), bytearray()
|
||||||
|
|
||||||
|
|
||||||
def test_pack_empty() -> None:
|
def test_pack_empty() -> None:
|
||||||
|
|
|
@ -7,7 +7,7 @@ from telethon._impl.mtproto import Intermediate
|
||||||
class Output(bytearray):
|
class Output(bytearray):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
def __call__(self, data: bytes) -> None:
|
def __call__(self, data: bytes | bytearray | memoryview) -> None:
|
||||||
self += data
|
self += data
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -46,14 +46,14 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
||||||
|
|
||||||
ignored_types = {"true", "boolTrue", "boolFalse"} # also "compiler built-ins"
|
ignored_types = {"true", "boolTrue", "boolFalse"} # also "compiler built-ins"
|
||||||
|
|
||||||
abc_namespaces = set()
|
abc_namespaces: Set[str] = set()
|
||||||
type_namespaces = set()
|
type_namespaces: Set[str] = set()
|
||||||
function_namespaces = set()
|
function_namespaces: Set[str] = set()
|
||||||
|
|
||||||
abc_class_names = set()
|
abc_class_names: Set[str] = set()
|
||||||
type_class_names = set()
|
type_class_names: Set[str] = set()
|
||||||
function_def_names = set()
|
function_def_names: Set[str] = set()
|
||||||
generated_type_names = set()
|
generated_type_names: Set[str] = set()
|
||||||
|
|
||||||
for typedef in tl.typedefs:
|
for typedef in tl.typedefs:
|
||||||
if typedef.ty.full_name not in generated_types:
|
if typedef.ty.full_name not in generated_types:
|
||||||
|
@ -67,6 +67,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
||||||
abc_path = Path("abcs/_nons.py")
|
abc_path = Path("abcs/_nons.py")
|
||||||
|
|
||||||
if abc_path not in fs:
|
if abc_path not in fs:
|
||||||
|
fs.write(abc_path, "# pyright: reportUnusedImport=false\n")
|
||||||
fs.write(abc_path, "from abc import ABCMeta\n")
|
fs.write(abc_path, "from abc import ABCMeta\n")
|
||||||
fs.write(abc_path, "from ..core import Serializable\n")
|
fs.write(abc_path, "from ..core import Serializable\n")
|
||||||
|
|
||||||
|
@ -93,11 +94,14 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
||||||
writer = fs.open(type_path)
|
writer = fs.open(type_path)
|
||||||
|
|
||||||
if type_path not in fs:
|
if type_path not in fs:
|
||||||
|
writer.write(
|
||||||
|
"# pyright: reportUnusedImport=false, reportConstantRedefinition=false"
|
||||||
|
)
|
||||||
writer.write("import struct")
|
writer.write("import struct")
|
||||||
writer.write("from typing import List, Optional, Self, Sequence")
|
writer.write("from typing import Optional, Self, Sequence")
|
||||||
writer.write("from .. import abcs")
|
writer.write("from .. import abcs")
|
||||||
writer.write("from ..core import Reader, Serializable, serialize_bytes_to")
|
writer.write("from ..core import Reader, Serializable, serialize_bytes_to")
|
||||||
writer.write("_bytes = bytes")
|
writer.write("_bytes = bytes | bytearray | memoryview")
|
||||||
|
|
||||||
ns = f"{typedef.namespace[0]}." if typedef.namespace else ""
|
ns = f"{typedef.namespace[0]}." if typedef.namespace else ""
|
||||||
generated_type_names.add(f"{ns}{to_class_name(typedef.name)}")
|
generated_type_names.add(f"{ns}{to_class_name(typedef.name)}")
|
||||||
|
@ -113,14 +117,13 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
||||||
|
|
||||||
# def constructor_id()
|
# def constructor_id()
|
||||||
writer.write(" @classmethod")
|
writer.write(" @classmethod")
|
||||||
writer.write(" def constructor_id(_) -> int:")
|
writer.write(" def constructor_id(cls) -> int:")
|
||||||
writer.write(f" return {hex(typedef.id)}")
|
writer.write(f" return {hex(typedef.id)}")
|
||||||
|
|
||||||
# def __init__()
|
# def __init__()
|
||||||
if property_params:
|
if property_params:
|
||||||
params = "".join(
|
params = "".join(
|
||||||
f", {p.name}: {param_type_fmt(p.ty, immutable=False)}"
|
f", {p.name}: {param_type_fmt(p.ty)}" for p in property_params
|
||||||
for p in property_params
|
|
||||||
)
|
)
|
||||||
writer.write(f" def __init__(_s, *{params}) -> None:")
|
writer.write(f" def __init__(_s, *{params}) -> None:")
|
||||||
for p in property_params:
|
for p in property_params:
|
||||||
|
@ -159,22 +162,18 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
||||||
writer = fs.open(function_path)
|
writer = fs.open(function_path)
|
||||||
|
|
||||||
if function_path not in fs:
|
if function_path not in fs:
|
||||||
|
writer.write("# pyright: reportUnusedImport=false")
|
||||||
writer.write("import struct")
|
writer.write("import struct")
|
||||||
writer.write("from typing import List, Optional, Self, Sequence")
|
writer.write("from typing import Optional, Self, Sequence")
|
||||||
writer.write("from .. import abcs")
|
writer.write("from .. import abcs")
|
||||||
writer.write("from ..core import Request, serialize_bytes_to")
|
writer.write("from ..core import Request, serialize_bytes_to")
|
||||||
writer.write("_bytes = bytes")
|
writer.write("_bytes = bytes | bytearray | memoryview")
|
||||||
|
|
||||||
# def name(params, ...)
|
# def name(params, ...)
|
||||||
required_params = [p for p in functiondef.params if not is_computed(p.ty)]
|
required_params = [p for p in functiondef.params if not is_computed(p.ty)]
|
||||||
params = "".join(
|
params = "".join(f", {p.name}: {param_type_fmt(p.ty)}" for p in required_params)
|
||||||
f", {p.name}: {param_type_fmt(p.ty, immutable=True)}"
|
|
||||||
for p in required_params
|
|
||||||
)
|
|
||||||
star = "*" if params else ""
|
star = "*" if params else ""
|
||||||
return_ty = param_type_fmt(
|
return_ty = param_type_fmt(NormalParameter(ty=functiondef.ty, flag=None))
|
||||||
NormalParameter(ty=functiondef.ty, flag=None), immutable=False
|
|
||||||
)
|
|
||||||
writer.write(
|
writer.write(
|
||||||
f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:"
|
f"def {to_method_name(functiondef.name)}({star}{params}) -> Request[{return_ty}]:"
|
||||||
)
|
)
|
||||||
|
@ -189,6 +188,7 @@ def generate(fs: FakeFs, tl: ParsedTl) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
writer = fs.open(Path("layer.py"))
|
writer = fs.open(Path("layer.py"))
|
||||||
|
writer.write("# pyright: reportUnusedImport=false")
|
||||||
writer.write("from . import abcs, types")
|
writer.write("from . import abcs, types")
|
||||||
writer.write(
|
writer.write(
|
||||||
"from .core import Serializable, Reader, deserialize_bool, deserialize_i32_list, deserialize_i64_list, deserialize_identity, single_deserializer, list_deserializer"
|
"from .core import Serializable, Reader, deserialize_bool, deserialize_i32_list, deserialize_i64_list, deserialize_identity, single_deserializer, list_deserializer"
|
||||||
|
|
|
@ -86,7 +86,7 @@ def inner_type_fmt(ty: Type) -> str:
|
||||||
return f"abcs.{ns}{to_class_name(ty.name)}"
|
return f"abcs.{ns}{to_class_name(ty.name)}"
|
||||||
|
|
||||||
|
|
||||||
def param_type_fmt(ty: BaseParameter, *, immutable: bool) -> str:
|
def param_type_fmt(ty: BaseParameter) -> str:
|
||||||
if isinstance(ty, FlagsParameter):
|
if isinstance(ty, FlagsParameter):
|
||||||
return "int"
|
return "int"
|
||||||
elif not isinstance(ty, NormalParameter):
|
elif not isinstance(ty, NormalParameter):
|
||||||
|
@ -104,10 +104,7 @@ def param_type_fmt(ty: BaseParameter, *, immutable: bool) -> str:
|
||||||
res = "_bytes" if inner_ty.name == "Object" else inner_type_fmt(inner_ty)
|
res = "_bytes" if inner_ty.name == "Object" else inner_type_fmt(inner_ty)
|
||||||
|
|
||||||
if ty.ty.generic_arg:
|
if ty.ty.generic_arg:
|
||||||
if immutable:
|
res = f"Sequence[{res}]"
|
||||||
res = f"Sequence[{res}]"
|
|
||||||
else:
|
|
||||||
res = f"List[{res}]"
|
|
||||||
|
|
||||||
if ty.flag and ty.ty.name != "true":
|
if ty.flag and ty.ty.name != "true":
|
||||||
res = f"Optional[{res}]"
|
res = f"Optional[{res}]"
|
||||||
|
|
|
@ -15,7 +15,8 @@ class ParsedTl:
|
||||||
|
|
||||||
|
|
||||||
def load_tl_file(path: Union[str, Path]) -> ParsedTl:
|
def load_tl_file(path: Union[str, Path]) -> ParsedTl:
|
||||||
typedefs, functiondefs = [], []
|
typedefs: List[TypeDef] = []
|
||||||
|
functiondefs: List[FunctionDef] = []
|
||||||
with open(path, "r", encoding="utf-8") as fd:
|
with open(path, "r", encoding="utf-8") as fd:
|
||||||
contents = fd.read()
|
contents = fd.read()
|
||||||
|
|
||||||
|
@ -31,10 +32,8 @@ def load_tl_file(path: Union[str, Path]) -> ParsedTl:
|
||||||
raise
|
raise
|
||||||
elif isinstance(definition, TypeDef):
|
elif isinstance(definition, TypeDef):
|
||||||
typedefs.append(definition)
|
typedefs.append(definition)
|
||||||
elif isinstance(definition, FunctionDef):
|
|
||||||
functiondefs.append(definition)
|
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"unexpected type: {type(definition)}")
|
functiondefs.append(definition)
|
||||||
|
|
||||||
return ParsedTl(
|
return ParsedTl(
|
||||||
layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)
|
layer=layer, typedefs=list(typedefs), functiondefs=list(functiondefs)
|
||||||
|
|
|
@ -59,8 +59,8 @@ class Definition:
|
||||||
raise ValueError("invalid id")
|
raise ValueError("invalid id")
|
||||||
|
|
||||||
type_defs: List[str] = []
|
type_defs: List[str] = []
|
||||||
flag_defs = []
|
flag_defs: List[str] = []
|
||||||
params = []
|
params: List[Parameter] = []
|
||||||
|
|
||||||
for param_str in middle.split():
|
for param_str in middle.split():
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -79,7 +79,7 @@ def test_recursive_vec() -> None:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
result = gen_py_code(typedefs=definitions)
|
result = gen_py_code(typedefs=definitions)
|
||||||
assert "value: List[abcs.JsonObjectValue]" in result
|
assert "value: Sequence[abcs.JsonObjectValue]" in result
|
||||||
|
|
||||||
|
|
||||||
def test_object_blob_special_case() -> None:
|
def test_object_blob_special_case() -> None:
|
||||||
|
|
|
@ -10,6 +10,7 @@ Imports of new definitions and formatting must be added with other tools.
|
||||||
Properties and private methods can use a different parameter name than `self`
|
Properties and private methods can use a different parameter name than `self`
|
||||||
to avoid being included.
|
to avoid being included.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
@ -31,6 +32,8 @@ class FunctionMethodsVisitor(ast.NodeVisitor):
|
||||||
match node.args.args:
|
match node.args.args:
|
||||||
case [ast.arg(arg="self", annotation=ast.Name(id="Client")), *_]:
|
case [ast.arg(arg="self", annotation=ast.Name(id="Client")), *_]:
|
||||||
self.methods.append(node)
|
self.methods.append(node)
|
||||||
|
case _:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MethodVisitor(ast.NodeVisitor):
|
class MethodVisitor(ast.NodeVisitor):
|
||||||
|
@ -59,6 +62,8 @@ class MethodVisitor(ast.NodeVisitor):
|
||||||
match node.body:
|
match node.body:
|
||||||
case [ast.Expr(value=ast.Constant(value=str(doc))), *_]:
|
case [ast.Expr(value=ast.Constant(value=str(doc))), *_]:
|
||||||
self.method_docs[node.name] = doc
|
self.method_docs[node.name] = doc
|
||||||
|
case _:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
@ -81,10 +86,10 @@ def main() -> None:
|
||||||
|
|
||||||
m_visitor.visit(ast.parse(contents))
|
m_visitor.visit(ast.parse(contents))
|
||||||
|
|
||||||
class_body = []
|
class_body: List[ast.stmt] = []
|
||||||
|
|
||||||
for function in sorted(fm_visitor.methods, key=lambda f: f.name):
|
for function in sorted(fm_visitor.methods, key=lambda f: f.name):
|
||||||
function.body = []
|
function_body: List[ast.stmt] = []
|
||||||
if doc := m_visitor.method_docs.get(function.name):
|
if doc := m_visitor.method_docs.get(function.name):
|
||||||
function.body.append(ast.Expr(value=ast.Constant(value=doc)))
|
function.body.append(ast.Expr(value=ast.Constant(value=doc)))
|
||||||
|
|
||||||
|
@ -108,7 +113,7 @@ def main() -> None:
|
||||||
case _:
|
case _:
|
||||||
call = ast.Return(value=call)
|
call = ast.Return(value=call)
|
||||||
|
|
||||||
function.body.append(call)
|
function.body = function_body
|
||||||
class_body.append(function)
|
class_body.append(function)
|
||||||
|
|
||||||
generated = ast.unparse(
|
generated = ast.unparse(
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
Scan the `client/` directory for __init__.py files.
|
Scan the `client/` directory for __init__.py files.
|
||||||
For every depth-1 import, add it, in order, to the __all__ variable.
|
For every depth-1 import, add it, in order, to the __all__ variable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
@ -25,7 +26,7 @@ def main() -> None:
|
||||||
rf"(tl|mtproto){re.escape(os.path.sep)}(abcs|functions|types)"
|
rf"(tl|mtproto){re.escape(os.path.sep)}(abcs|functions|types)"
|
||||||
)
|
)
|
||||||
|
|
||||||
files = []
|
files: List[str] = []
|
||||||
for file in impl_root.rglob("__init__.py"):
|
for file in impl_root.rglob("__init__.py"):
|
||||||
file_str = str(file)
|
file_str = str(file)
|
||||||
if autogenerated_re.search(file_str):
|
if autogenerated_re.search(file_str):
|
||||||
|
@ -52,6 +53,8 @@ def main() -> None:
|
||||||
+ "]\n"
|
+ "]\n"
|
||||||
]
|
]
|
||||||
break
|
break
|
||||||
|
case _:
|
||||||
|
pass
|
||||||
|
|
||||||
with file.open("w", encoding="utf-8", newline="\n") as fd:
|
with file.open("w", encoding="utf-8", newline="\n") as fd:
|
||||||
fd.writelines(lines)
|
fd.writelines(lines)
|
||||||
|
|
19
typings/setuptools.pyi
Normal file
19
typings/setuptools.pyi
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
class build_meta:
|
||||||
|
@staticmethod
|
||||||
|
def build_wheel(
|
||||||
|
wheel_directory: str,
|
||||||
|
config_settings: Optional[Dict[Any, Any]] = None,
|
||||||
|
metadata_directory: Optional[str] = None,
|
||||||
|
) -> str: ...
|
||||||
|
@staticmethod
|
||||||
|
def build_sdist(
|
||||||
|
sdist_directory: str, config_settings: Optional[Dict[Any, Any]] = None
|
||||||
|
) -> str: ...
|
||||||
|
@staticmethod
|
||||||
|
def build_editable(
|
||||||
|
wheel_directory: str,
|
||||||
|
config_settings: Optional[Dict[Any, Any]] = None,
|
||||||
|
metadata_directory: Optional[str] = None,
|
||||||
|
) -> str: ...
|
Loading…
Reference in New Issue
Block a user