Automatically generate Client methods

This commit is contained in:
Lonami Exo 2023-09-02 21:09:40 +02:00
parent 938126691c
commit f75acee7e8
8 changed files with 510 additions and 253 deletions

View File

@ -25,35 +25,35 @@ async def is_authorized(self: Client) -> bool:
raise raise
async def complete_login(self: Client, auth: abcs.auth.Authorization) -> User: async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
assert isinstance(auth, types.auth.Authorization) assert isinstance(auth, types.auth.Authorization)
assert isinstance(auth.user, types.User) assert isinstance(auth.user, types.User)
user = User._from_raw(auth.user) user = User._from_raw(auth.user)
self._config.session.user = SessionUser( client._config.session.user = SessionUser(
id=user.id, id=user.id,
dc=self._dc_id, dc=client._dc_id,
bot=user.bot, bot=user.bot,
) )
packed = user.pack() packed = user.pack()
assert packed is not None assert packed is not None
self._chat_hashes.set_self_user(packed) client._chat_hashes.set_self_user(packed)
try: try:
state = await self(functions.updates.get_state()) state = await client(functions.updates.get_state())
self._message_box.set_state(state) client._message_box.set_state(state)
except Exception: except Exception:
pass pass
return user return user
async def handle_migrate(self: Client, dc_id: Optional[int]) -> None: async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
assert dc_id is not None assert dc_id is not None
sender = await connect_sender(dc_id, self._config) sender = await connect_sender(dc_id, client._config)
async with self._sender_lock: async with client._sender_lock:
self._sender = sender client._sender = sender
self._dc_id = dc_id client._dc_id = dc_id
async def bot_sign_in(self: Client, token: str) -> User: async def bot_sign_in(self: Client, token: str) -> User:
@ -127,8 +127,8 @@ async def sign_in(
return await complete_login(self, result) return await complete_login(self, result)
async def get_password_information(self: Client) -> PasswordToken: async def get_password_information(client: Client) -> PasswordToken:
result = self(functions.account.get_password()) result = client(functions.account.get_password())
assert isinstance(result, types.account.Password) assert isinstance(result, types.account.Password)
return PasswordToken._new(result) return PasswordToken._new(result)
@ -144,6 +144,6 @@ async def sign_out(self: Client) -> None:
await self(functions.auth.log_out()) await self(functions.auth.log_out())
def session(self: Client) -> Session: def session(client: Client) -> Session:
self._config.session.state = self._message_box.session_state() client._config.session.state = client._message_box.session_state()
return self._config.session return client._config.session

View File

@ -1,8 +1,19 @@
import asyncio import asyncio
import datetime import datetime
from collections import deque from collections import deque
from pathlib import Path
from types import TracebackType from types import TracebackType
from typing import Deque, List, Literal, Optional, Self, Type, TypeVar, Union from typing import (
AsyncIterator,
Deque,
List,
Literal,
Optional,
Self,
Type,
TypeVar,
Union,
)
from ...mtsender.sender import Sender from ...mtsender.sender import Sender
from ...session.chat.hash_cache import ChatHashCache from ...session.chat.hash_cache import ChatHashCache
@ -27,7 +38,7 @@ from .auth import (
sign_in, sign_in,
sign_out, sign_out,
) )
from .bots import inline_query from .bots import InlineResult, inline_query
from .buttons import build_reply_markup from .buttons import build_reply_markup
from .chats import ( from .chats import (
action, action,
@ -42,6 +53,10 @@ from .chats import (
) )
from .dialogs import conversation, delete_dialog, edit_folder, iter_dialogs, iter_drafts from .dialogs import conversation, delete_dialog, edit_folder, iter_dialogs, iter_drafts
from .files import ( from .files import (
File,
InFileLike,
MediaLike,
OutFileLike,
download, download,
iter_download, iter_download,
send_audio, send_audio,
@ -71,7 +86,6 @@ from .net import (
disconnect, disconnect,
invoke_request, invoke_request,
run_until_disconnected, run_until_disconnected,
step,
) )
from .updates import ( from .updates import (
add_event_handler, add_event_handler,
@ -87,6 +101,8 @@ from .users import (
get_me, get_me,
get_peer_id, get_peer_id,
input_to_peer, input_to_peer,
is_bot,
is_user_authorized,
resolve_to_packed, resolve_to_packed,
) )
@ -110,177 +126,60 @@ class Client:
if config.catch_up and config.session.state: if config.catch_up and config.session.state:
self._message_box.load(config.session.state) self._message_box.load(config.session.state)
def takeout(self) -> None: def action(self) -> None:
takeout(self) action(self)
async def end_takeout(self) -> None: def add_event_handler(self) -> None:
await end_takeout(self) add_event_handler(self)
async def edit_2fa(self) -> None:
await edit_2fa(self)
async def is_authorized(self) -> bool:
return await is_authorized(self)
async def bot_sign_in(self, token: str) -> User: async def bot_sign_in(self, token: str) -> User:
return await bot_sign_in(self, token) return await bot_sign_in(self, token)
async def request_login_code(self, phone: str) -> LoginToken: def build_reply_markup(self) -> None:
return await request_login_code(self, phone) build_reply_markup(self)
async def sign_in(self, token: LoginToken, code: str) -> Union[User, PasswordToken]: async def catch_up(self) -> None:
return await sign_in(self, token, code) await catch_up(self)
async def check_password( async def check_password(
self, token: PasswordToken, password: Union[str, bytes] self, token: PasswordToken, password: Union[str, bytes]
) -> User: ) -> User:
return await check_password(self, token, password) return await check_password(self, token, password)
async def sign_out(self) -> None: async def connect(self) -> None:
await sign_out(self) await connect(self)
@property
def session(self) -> Session:
"""
Up-to-date session state, useful for persisting it to storage.
Mutating the returned object may cause the library to misbehave.
"""
return session(self)
async def inline_query(
self, bot: ChatLike, query: str, *, chat: Optional[ChatLike] = None
) -> None:
await inline_query(self, bot, query, chat=chat)
def build_reply_markup(self) -> None:
build_reply_markup(self)
def iter_participants(self) -> None:
iter_participants(self)
def iter_admin_log(self) -> None:
iter_admin_log(self)
def iter_profile_photos(self) -> None:
iter_profile_photos(self)
def action(self) -> None:
action(self)
async def edit_admin(self) -> None:
await edit_admin(self)
async def edit_permissions(self) -> None:
await edit_permissions(self)
async def kick_participant(self) -> None:
await kick_participant(self)
async def get_permissions(self) -> None:
await get_permissions(self)
async def get_stats(self) -> None:
await get_stats(self)
def iter_dialogs(self) -> None:
iter_dialogs(self)
def iter_drafts(self) -> None:
iter_drafts(self)
async def edit_folder(self) -> None:
await edit_folder(self)
async def delete_dialog(self) -> None:
await delete_dialog(self)
def conversation(self) -> None: def conversation(self) -> None:
conversation(self) conversation(self)
async def send_photo(self, *args, **kwargs) -> None: async def delete_dialog(self) -> None:
""" await delete_dialog(self)
Send a photo file.
Exactly one of path, url or file must be specified. async def delete_messages(
A `File` can also be used as the second parameter. self, chat: ChatLike, message_ids: List[int], *, revoke: bool = True
) -> int:
return await delete_messages(self, chat, message_ids, revoke=revoke)
By default, the server will be allowed to `compress` the image. async def disconnect(self) -> None:
Only compressed images can be displayed as photos in applications. await disconnect(self)
Images that cannot be compressed will be sent as file documents,
with a thumbnail if possible.
Unlike `send_file`, this method will attempt to guess the values for async def download(self, media: MediaLike, file: OutFileLike) -> None:
width and height if they are not provided and the can't be compressed.
"""
return send_photo(self, *args, **kwargs)
async def send_audio(self, *args, **kwargs) -> None:
"""
Send an audio file.
Unlike `send_file`, this method will attempt to guess the values for
duration, title and performer if they are not provided.
"""
return send_audio(self, *args, **kwargs)
async def send_video(self, *args, **kwargs) -> None:
"""
Send a video file.
Unlike `send_file`, this method will attempt to guess the values for
duration, width and height if they are not provided.
"""
return send_video(self, *args, **kwargs)
async def send_file(self, *args, **kwargs) -> None:
"""
Send any type of file with any amount of attributes.
This method will not attempt to guess any of the file metadata such as
width, duration, title, etc. If you want to let the library attempt to
guess the file metadata, use the type-specific methods to send media:
`send_photo`, `send_audio` or `send_file`.
Unlike `send_photo`, image files will be sent as documents by default.
The parameters are used to construct a `File`. See the documentation
for `File.new` to learn what they do and when they are in effect.
"""
return send_file(self, *args, **kwargs)
async def iter_download(self, *args, **kwargs) -> None:
"""
Stream server media by iterating over its bytes in chunks.
"""
return iter_download(self, *args, **kwargs)
async def download(self, *args, **kwargs) -> None:
""" """
Download a file. Download a file.
This is simply a more convenient method to `iter_download`, This is simply a more convenient method to `iter_download`,
as it will handle dealing with the file chunks and writes by itself. as it will handle dealing with the file chunks and writes by itself.
""" """
return download(self, *args, **kwargs) await download(self, media, file)
async def send_message( async def edit_2fa(self) -> None:
self, await edit_2fa(self)
chat: ChatLike,
*, async def edit_admin(self) -> None:
text: Optional[str] = None, await edit_admin(self)
markdown: Optional[str] = None,
html: Optional[str] = None, async def edit_folder(self) -> None:
link_preview: Optional[bool] = None, await edit_folder(self)
) -> Message:
return await send_message(
self,
chat,
text=text,
markdown=markdown,
html=html,
link_preview=link_preview,
)
async def edit_message( async def edit_message(
self, self,
@ -302,16 +201,26 @@ class Client:
link_preview=link_preview, link_preview=link_preview,
) )
async def delete_messages( async def edit_permissions(self) -> None:
self, chat: ChatLike, message_ids: List[int], *, revoke: bool = True await edit_permissions(self)
) -> int:
return await delete_messages(self, chat, message_ids, revoke=revoke) async def end_takeout(self) -> None:
await end_takeout(self)
async def forward_messages( async def forward_messages(
self, target: ChatLike, message_ids: List[int], source: ChatLike self, target: ChatLike, message_ids: List[int], source: ChatLike
) -> List[Message]: ) -> List[Message]:
return await forward_messages(self, target, message_ids, source) return await forward_messages(self, target, message_ids, source)
async def get_entity(self) -> None:
await get_entity(self)
async def get_input_entity(self) -> None:
await get_input_entity(self)
async def get_me(self) -> None:
await get_me(self)
def get_messages( def get_messages(
self, self,
chat: ChatLike, chat: ChatLike,
@ -325,12 +234,90 @@ class Client:
) )
def get_messages_with_ids( def get_messages_with_ids(
self, self, chat: ChatLike, message_ids: List[int]
chat: ChatLike,
message_ids: List[int],
) -> AsyncList[Message]: ) -> AsyncList[Message]:
return get_messages_with_ids(self, chat, message_ids) return get_messages_with_ids(self, chat, message_ids)
async def get_peer_id(self) -> None:
await get_peer_id(self)
async def get_permissions(self) -> None:
await get_permissions(self)
async def get_stats(self) -> None:
await get_stats(self)
async def inline_query(
self, bot: ChatLike, query: str, *, chat: Optional[ChatLike] = None
) -> AsyncIterator[InlineResult]:
return await inline_query(self, bot, query, chat=chat)
async def is_authorized(self) -> bool:
return await is_authorized(self)
async def is_bot(self) -> None:
await is_bot(self)
async def is_user_authorized(self) -> None:
await is_user_authorized(self)
def iter_admin_log(self) -> None:
iter_admin_log(self)
def iter_dialogs(self) -> None:
iter_dialogs(self)
async def iter_download(self) -> None:
"""
Stream server media by iterating over its bytes in chunks.
"""
await iter_download(self)
def iter_drafts(self) -> None:
iter_drafts(self)
def iter_participants(self) -> None:
iter_participants(self)
def iter_profile_photos(self) -> None:
iter_profile_photos(self)
async def kick_participant(self) -> None:
await kick_participant(self)
def list_event_handlers(self) -> None:
list_event_handlers(self)
def on(self) -> None:
on(self)
async def pin_message(self, chat: ChatLike, message_id: int) -> Message:
return await pin_message(self, chat, message_id)
def remove_event_handler(self) -> None:
remove_event_handler(self)
async def request_login_code(self, phone: str) -> LoginToken:
return await request_login_code(self, phone)
async def resolve_to_packed(self, chat: ChatLike) -> PackedChat:
return await resolve_to_packed(self, chat)
async def run_until_disconnected(self) -> None:
await run_until_disconnected(self)
def search_all_messages(
self,
limit: Optional[int] = None,
*,
query: Optional[str] = None,
offset_id: Optional[int] = None,
offset_date: Optional[datetime.datetime] = None,
) -> AsyncList[Message]:
return search_all_messages(
self, limit, query=query, offset_id=offset_id, offset_date=offset_date
)
def search_messages( def search_messages(
self, self,
chat: ChatLike, chat: ChatLike,
@ -344,25 +331,230 @@ class Client:
self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date
) )
def search_all_messages( async def send_audio(
self, self,
limit: Optional[int] = None, chat: ChatLike,
path: Optional[Union[str, Path, File]] = None,
*, *,
query: Optional[str] = None, url: Optional[str] = None,
offset_id: Optional[int] = None, file: Optional[InFileLike] = None,
offset_date: Optional[datetime.datetime] = None, size: Optional[int] = None,
) -> AsyncList[Message]: name: Optional[str] = None,
return search_all_messages( duration: Optional[float] = None,
self, limit, query=query, offset_id=offset_id, offset_date=offset_date voice: bool = False,
title: Optional[str] = None,
performer: Optional[str] = None,
) -> Message:
"""
Send an audio file.
Unlike `send_file`, this method will attempt to guess the values for
duration, title and performer if they are not provided.
"""
return await send_audio(
self,
chat,
path,
url=url,
file=file,
size=size,
name=name,
duration=duration,
voice=voice,
title=title,
performer=performer,
) )
async def pin_message(self, chat: ChatLike, message_id: int) -> Message: async def send_file(
return await pin_message(self, chat, message_id) self,
chat: ChatLike,
path: Optional[Union[str, Path, File]] = None,
*,
url: Optional[str] = None,
file: Optional[InFileLike] = None,
size: Optional[int] = None,
name: Optional[str] = None,
mime_type: Optional[str] = None,
compress: bool = False,
animated: bool = False,
duration: Optional[float] = None,
voice: bool = False,
title: Optional[str] = None,
performer: Optional[str] = None,
emoji: Optional[str] = None,
emoji_sticker: Optional[str] = None,
width: Optional[int] = None,
height: Optional[int] = None,
round: bool = False,
supports_streaming: bool = False,
muted: bool = False,
caption: Optional[str] = None,
caption_markdown: Optional[str] = None,
caption_html: Optional[str] = None,
) -> Message:
"""
Send any type of file with any amount of attributes.
This method will not attempt to guess any of the file metadata such as
width, duration, title, etc. If you want to let the library attempt to
guess the file metadata, use the type-specific methods to send media:
`send_photo`, `send_audio` or `send_file`.
Unlike `send_photo`, image files will be sent as documents by default.
The parameters are used to construct a `File`. See the documentation
for `File.new` to learn what they do and when they are in effect.
"""
return await send_file(
self,
chat,
path,
url=url,
file=file,
size=size,
name=name,
mime_type=mime_type,
compress=compress,
animated=animated,
duration=duration,
voice=voice,
title=title,
performer=performer,
emoji=emoji,
emoji_sticker=emoji_sticker,
width=width,
height=height,
round=round,
supports_streaming=supports_streaming,
muted=muted,
caption=caption,
caption_markdown=caption_markdown,
caption_html=caption_html,
)
async def send_message(
self,
chat: ChatLike,
*,
text: Optional[str] = None,
markdown: Optional[str] = None,
html: Optional[str] = None,
link_preview: Optional[bool] = None,
) -> Message:
return await send_message(
self,
chat,
text=text,
markdown=markdown,
html=html,
link_preview=link_preview,
)
async def send_photo(
self,
chat: ChatLike,
path: Optional[Union[str, Path, File]] = None,
*,
url: Optional[str] = None,
file: Optional[InFileLike] = None,
size: Optional[int] = None,
name: Optional[str] = None,
compress: bool = True,
width: Optional[int] = None,
height: Optional[int] = None,
) -> Message:
"""
Send a photo file.
Exactly one of path, url or file must be specified.
A `File` can also be used as the second parameter.
By default, the server will be allowed to `compress` the image.
Only compressed images can be displayed as photos in applications.
Images that cannot be compressed will be sent as file documents,
with a thumbnail if possible.
Unlike `send_file`, this method will attempt to guess the values for
width and height if they are not provided and the can't be compressed.
"""
return await send_photo(
self,
chat,
path,
url=url,
file=file,
size=size,
name=name,
compress=compress,
width=width,
height=height,
)
async def send_video(
self,
chat: ChatLike,
path: Optional[Union[str, Path, File]] = None,
*,
url: Optional[str] = None,
file: Optional[InFileLike] = None,
size: Optional[int] = None,
name: Optional[str] = None,
duration: Optional[float] = None,
width: Optional[int] = None,
height: Optional[int] = None,
round: bool = False,
supports_streaming: bool = False,
) -> Message:
"""
Send a video file.
Unlike `send_file`, this method will attempt to guess the values for
duration, width and height if they are not provided.
"""
return await send_video(
self,
chat,
path,
url=url,
file=file,
size=size,
name=name,
duration=duration,
width=width,
height=height,
round=round,
supports_streaming=supports_streaming,
)
async def set_receive_updates(self) -> None:
await set_receive_updates(self)
async def sign_in(self, token: LoginToken, code: str) -> Union[User, PasswordToken]:
return await sign_in(self, token, code)
async def sign_out(self) -> None:
await sign_out(self)
def takeout(self) -> None:
takeout(self)
async def unpin_message( async def unpin_message(
self, chat: ChatLike, message_id: Union[int, Literal["all"]] self, chat: ChatLike, message_id: Union[int, Literal["all"]]
) -> None: ) -> None:
return await unpin_message(self, chat, message_id) await unpin_message(self, chat, message_id)
@property
def connected(self) -> bool:
return connected(self)
@property
def session(self) -> Session:
"""
Up-to-date session state, useful for persisting it to storage.
Mutating the returned object may cause the library to misbehave.
"""
return session(self)
def _build_message_map( def _build_message_map(
self, self,
@ -371,64 +563,18 @@ class Client:
) -> MessageMap: ) -> MessageMap:
return build_message_map(self, result, peer) return build_message_map(self, result, peer)
async def set_receive_updates(self) -> None:
await set_receive_updates(self)
def on(self) -> None:
on(self)
def add_event_handler(self) -> None:
add_event_handler(self)
def remove_event_handler(self) -> None:
remove_event_handler(self)
def list_event_handlers(self) -> None:
list_event_handlers(self)
async def catch_up(self) -> None:
await catch_up(self)
async def get_me(self) -> None:
await get_me(self)
async def get_entity(self) -> None:
await get_entity(self)
async def get_input_entity(self) -> None:
await get_input_entity(self)
async def _resolve_to_packed(self, chat: ChatLike) -> PackedChat: async def _resolve_to_packed(self, chat: ChatLike) -> PackedChat:
return await resolve_to_packed(self, chat) return await resolve_to_packed(self, chat)
def _input_to_peer(self, input: Optional[abcs.InputPeer]) -> Optional[abcs.Peer]: def _input_to_peer(self, input: Optional[abcs.InputPeer]) -> Optional[abcs.Peer]:
return input_to_peer(self, input) return input_to_peer(self, input)
async def get_peer_id(self) -> None:
await get_peer_id(self)
async def connect(self) -> None:
await connect(self)
async def disconnect(self) -> None:
await disconnect(self)
async def __call__(self, request: Request[Return]) -> Return: async def __call__(self, request: Request[Return]) -> Return:
if not self._sender: if not self._sender:
raise ConnectionError("not connected") raise ConnectionError("not connected")
return await invoke_request(self, self._sender, self._sender_lock, request) return await invoke_request(self, self._sender, self._sender_lock, request)
async def step(self) -> None:
await step(self)
async def run_until_disconnected(self) -> None:
await run_until_disconnected(self)
@property
def connected(self) -> bool:
return connected(self)
async def __aenter__(self) -> Self: async def __aenter__(self) -> Self:
await self.connect() await self.connect()
return self return self

View File

@ -414,7 +414,7 @@ async def send_file(
async def upload( async def upload(
self: Client, client: Client,
file: File, file: File,
) -> abcs.InputFile: ) -> abcs.InputFile:
file_id = generate_random_id() file_id = generate_random_id()
@ -442,7 +442,7 @@ async def upload(
continue continue
if is_big: if is_big:
await self( await client(
functions.upload.save_big_file_part( functions.upload.save_big_file_part(
file_id=file_id, file_id=file_id,
file_part=part, file_part=part,
@ -451,7 +451,7 @@ async def upload(
) )
) )
else: else:
await self( await client(
functions.upload.save_file_part( functions.upload.save_file_part(
file_id=file_id, file_part=total_parts, bytes=to_store file_id=file_id, file_part=total_parts, bytes=to_store
) )

View File

@ -502,7 +502,7 @@ class MessageMap:
def build_message_map( def build_message_map(
self: Client, client: Client,
result: abcs.Updates, result: abcs.Updates,
peer: Optional[abcs.InputPeer], peer: Optional[abcs.InputPeer],
) -> MessageMap: ) -> MessageMap:
@ -514,7 +514,7 @@ def build_message_map(
entities = {} entities = {}
raise NotImplementedError() raise NotImplementedError()
else: else:
return MessageMap(self, peer, {}, {}) return MessageMap(client, peer, {}, {})
random_id_to_id = {} random_id_to_id = {}
id_to_message = {} id_to_message = {}
@ -542,7 +542,7 @@ def build_message_map(
raise NotImplementedError() raise NotImplementedError()
return MessageMap( return MessageMap(
self, client,
peer, peer,
random_id_to_id, random_id_to_id,
id_to_message, id_to_message,

View File

@ -141,17 +141,17 @@ async def disconnect(self: Client) -> None:
async def invoke_request( async def invoke_request(
self: Client, client: Client,
sender: Sender, sender: Sender,
lock: asyncio.Lock, lock: asyncio.Lock,
request: Request[Return], request: Request[Return],
) -> Return: ) -> Return:
slept_flood = False slept_flood = False
sleep_thresh = self._config.flood_sleep_threshold or 0 sleep_thresh = client._config.flood_sleep_threshold or 0
rx = sender.enqueue(request) rx = sender.enqueue(request)
while True: while True:
while not rx.done(): while not rx.done():
await step_sender(self, sender, lock) await step_sender(client, sender, lock)
try: try:
response = rx.result() response = rx.result()
break break
@ -171,25 +171,25 @@ async def invoke_request(
return request.deserialize_response(response) return request.deserialize_response(response)
async def step(self: Client) -> None: async def step(client: Client) -> None:
if self._sender: if client._sender:
await step_sender(self, self._sender, self._sender_lock) await step_sender(client, client._sender, client._sender_lock)
async def step_sender(self: Client, sender: Sender, lock: asyncio.Lock) -> None: async def step_sender(client: Client, sender: Sender, lock: asyncio.Lock) -> None:
if lock.locked(): if lock.locked():
async with lock: async with lock:
pass pass
else: else:
async with lock: async with lock:
updates = await sender.step() updates = await sender.step()
# self._process_socket_updates(updates) # client._process_socket_updates(updates)
async def run_until_disconnected(self: Client) -> None: async def run_until_disconnected(self: Client) -> None:
while self.connected: while self.connected:
await self.step() await step(self)
def connected(self: Client) -> bool: def connected(client: Client) -> bool:
return self._sender is not None return client._sender is not None

View File

@ -11,11 +11,6 @@ async def set_receive_updates(self: Client) -> None:
raise NotImplementedError raise NotImplementedError
def run_until_disconnected(self: Client) -> None:
self
raise NotImplementedError
def on(self: Client) -> None: def on(self: Client) -> None:
self self
raise NotImplementedError raise NotImplementedError

View File

@ -86,15 +86,17 @@ async def resolve_to_packed(self: Client, chat: ChatLike) -> PackedChat:
raise ValueError("Cannot resolve chat") raise ValueError("Cannot resolve chat")
def input_to_peer(self: Client, input: Optional[abcs.InputPeer]) -> Optional[abcs.Peer]: def input_to_peer(
client: Client, input: Optional[abcs.InputPeer]
) -> Optional[abcs.Peer]:
if input is None: if input is None:
return None return None
elif isinstance(input, types.InputPeerEmpty): elif isinstance(input, types.InputPeerEmpty):
return None return None
elif isinstance(input, types.InputPeerSelf): elif isinstance(input, types.InputPeerSelf):
return ( return (
types.PeerUser(user_id=self._config.session.user.id) types.PeerUser(user_id=client._config.session.user.id)
if self._config.session.user if client._config.session.user
else None else None
) )
elif isinstance(input, types.InputPeerChat): elif isinstance(input, types.InputPeerChat):

View File

@ -0,0 +1,114 @@
"""
Scan the `client/` directory, take all function definitions with `self: Client`
as the first parameter, and generate the corresponding `Client` methods to call
them, with matching signatures.
The documentation previously existing in the `Client` definitions is preserved.
Imports of new definitions and formatting must be added with other tools.
Properties and private methods can use a different parameter name than `self`
to avoid being included.
"""
import ast
import sys
from _ast import AsyncFunctionDef, ClassDef
from pathlib import Path
from typing import Dict, List, Union
class FunctionMethodsVisitor(ast.NodeVisitor):
def __init__(self) -> None:
self.methods: List[Union[ast.FunctionDef, ast.AsyncFunctionDef]] = []
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self._try_add_def(node)
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
self._try_add_def(node)
def _try_add_def(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> None:
match node.args.args:
case [ast.arg(arg="self", annotation=ast.Name(id="Client")), *_]:
self.methods.append(node)
class MethodVisitor(ast.NodeVisitor):
def __init__(self) -> None:
self._in_client = False
self.method_docs: Dict[str, str] = {}
def visit_ClassDef(self, node: ClassDef) -> None:
if node.name == "Client":
assert not self._in_client
self._in_client = True
for subnode in node.body:
self.visit(subnode)
self._in_client = False
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self._try_add_doc(node)
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
self._try_add_doc(node)
def _try_add_doc(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> None:
if not self._in_client:
return
match node.body:
case [ast.Expr(value=ast.Constant(value=str(doc))), *_]:
self.method_docs[node.name] = doc
def main() -> None:
client_root = Path.cwd() / sys.argv[1]
fm_visitor = FunctionMethodsVisitor()
m_visitor = MethodVisitor()
for file in client_root.glob("*.py"):
if file.stem in ("__init__", "client"):
pass
with file.open(encoding="utf-8") as fd:
contents = fd.read()
fm_visitor.visit(ast.parse(contents))
with (client_root / "client.py").open(encoding="utf-8") as fd:
contents = fd.read()
m_visitor.visit(ast.parse(contents))
for function in sorted(fm_visitor.methods, key=lambda f: f.name):
function.body = []
if doc := m_visitor.method_docs.get(function.name):
function.body.append(ast.Expr(value=ast.Constant(value=doc)))
call: ast.AST = ast.Call(
func=ast.Name(id=function.name, ctx=ast.Load()),
args=[ast.Name(id=a.arg, ctx=ast.Load()) for a in function.args.args],
keywords=[
ast.keyword(arg=a.arg, value=ast.Name(id=a.arg, ctx=ast.Load()))
for a in function.args.kwonlyargs
],
)
function.args.args[0].annotation = None
if isinstance(function, ast.AsyncFunctionDef):
call = ast.Await(value=call)
match function.returns:
case ast.Constant(value=None):
call = ast.Expr(value=call)
case _:
call = ast.Return(value=call)
function.body.append(call)
print(ast.unparse(function))
if __name__ == "__main__":
main()