From 5e43efc55daab7ead184e6899df5a31a4143e83d Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Sat, 2 Sep 2023 00:49:16 +0200 Subject: [PATCH] Implement message-related methods --- .../src/telethon/_impl/client/client/bots.py | 5 +- .../telethon/_impl/client/client/client.py | 130 ++++- .../telethon/_impl/client/client/messages.py | 528 ++++++++++++++++-- .../telethon/_impl/client/types/async_list.py | 67 +++ .../telethon/_impl/client/types/message.py | 31 +- client/src/telethon/_impl/client/utils.py | 11 + 6 files changed, 702 insertions(+), 70 deletions(-) create mode 100644 client/src/telethon/_impl/client/types/async_list.py create mode 100644 client/src/telethon/_impl/client/utils.py diff --git a/client/src/telethon/_impl/client/client/bots.py b/client/src/telethon/_impl/client/client/bots.py index 77486316..5bc11465 100644 --- a/client/src/telethon/_impl/client/client/bots.py +++ b/client/src/telethon/_impl/client/client/bots.py @@ -100,7 +100,7 @@ class InlineResult(metaclass=NoPublicConstructor): peer = self._default_peer random_id = generate_random_id() - return self._client._find_updates_message( + return self._client._build_message_map( await self._client( functions.messages.send_inline_bot_result( silent=False, @@ -117,9 +117,8 @@ class InlineResult(metaclass=NoPublicConstructor): send_as=None, ) ), - random_id, peer, - ) + ).with_random_id(random_id) async def inline_query( diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index 757497ac..fbea02ef 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -1,7 +1,8 @@ import asyncio +import datetime from collections import deque from types import TracebackType -from typing import Deque, Optional, Self, Type, TypeVar, Union +from typing import Deque, List, Literal, Optional, Self, Type, TypeVar, Union from ...mtsender.sender import Sender from ...session.chat.hash_cache import ChatHashCache @@ -10,6 +11,7 @@ from ...session.message_box.defs import Session from ...session.message_box.messagebox import MessageBox from ...tl import abcs from ...tl.core.request import Request +from ..types.async_list import AsyncList from ..types.chat import ChatLike from ..types.chat.user import User from ..types.login_token import LoginToken @@ -41,14 +43,17 @@ from .chats import ( from .dialogs import conversation, delete_dialog, edit_folder, iter_dialogs, iter_drafts from .downloads import download_media, download_profile_photo, iter_download from .messages import ( + MessageMap, + build_message_map, delete_messages, edit_message, - find_updates_message, forward_messages, - iter_messages, + get_messages, + get_messages_with_ids, pin_message, + search_all_messages, + search_messages, send_message, - send_read_acknowledge, unpin_message, ) from .net import ( @@ -196,37 +201,112 @@ class Client: def iter_download(self) -> None: iter_download(self) - def iter_messages(self) -> None: - iter_messages(self) + async def send_message( + self, + chat: ChatLike, + *, + text: Optional[str] = None, + markdown: Optional[str] = None, + html: Optional[str] = None, + link_preview: Optional[bool] = None, + ) -> Message: + return await send_message( + self, + chat, + text=text, + markdown=markdown, + html=html, + link_preview=link_preview, + ) - async def send_message(self) -> None: - await send_message(self) + async def edit_message( + self, + chat: ChatLike, + message_id: int, + *, + text: Optional[str] = None, + markdown: Optional[str] = None, + html: Optional[str] = None, + link_preview: Optional[bool] = None, + ) -> Message: + return await edit_message( + self, + chat, + message_id, + text=text, + markdown=markdown, + html=html, + link_preview=link_preview, + ) - async def forward_messages(self) -> None: - await forward_messages(self) + async def delete_messages( + self, chat: ChatLike, message_ids: List[int], *, revoke: bool = True + ) -> int: + return await delete_messages(self, chat, message_ids, revoke=revoke) - async def edit_message(self) -> None: - await edit_message(self) + async def forward_messages( + self, target: ChatLike, message_ids: List[int], source: ChatLike + ) -> List[Message]: + return await forward_messages(self, target, message_ids, source) - async def delete_messages(self) -> None: - await delete_messages(self) + def get_messages( + self, + chat: ChatLike, + limit: Optional[int] = None, + *, + offset_id: Optional[int], + offset_date: Optional[datetime.datetime], + ) -> AsyncList[Message]: + return get_messages( + self, chat, limit, offset_id=offset_id, offset_date=offset_date + ) - async def send_read_acknowledge(self) -> None: - await send_read_acknowledge(self) + def get_messages_with_ids( + self, + chat: ChatLike, + message_ids: List[int], + ) -> AsyncList[Message]: + return get_messages_with_ids(self, chat, message_ids) - async def pin_message(self) -> None: - await pin_message(self) + def search_messages( + self, + chat: ChatLike, + limit: Optional[int] = None, + *, + query: Optional[str] = None, + offset_id: int, + offset_date: datetime.datetime, + ) -> AsyncList[Message]: + return search_messages( + self, chat, limit, query=query, offset_id=offset_id, offset_date=offset_date + ) - async def unpin_message(self) -> None: - await unpin_message(self) + def search_all_messages( + self, + limit: Optional[int] = None, + *, + query: Optional[str] = None, + offset_id: int, + offset_date: datetime.datetime, + ) -> AsyncList[Message]: + return search_all_messages( + self, limit, query=query, offset_id=offset_id, offset_date=offset_date + ) - def _find_updates_message( + async def pin_message(self, chat: ChatLike, message_id: int) -> Message: + return await pin_message(self, chat, message_id) + + async def unpin_message( + self, chat: ChatLike, message_id: Union[int, Literal["all"]] + ) -> None: + return await unpin_message(self, chat, message_id) + + def _build_message_map( self, result: abcs.Updates, - random_id: int, - chat: Optional[abcs.InputPeer], - ) -> Message: - return find_updates_message(self, result, random_id, chat) + peer: Optional[abcs.InputPeer], + ) -> MessageMap: + return build_message_map(self, result, peer) async def set_receive_updates(self) -> None: await set_receive_updates(self) diff --git a/client/src/telethon/_impl/client/client/messages.py b/client/src/telethon/_impl/client/client/messages.py index 943edaae..05025e75 100644 --- a/client/src/telethon/_impl/client/client/messages.py +++ b/client/src/telethon/_impl/client/client/messages.py @@ -1,60 +1,511 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List, Optional, Union +import datetime +import sys +from typing import ( + TYPE_CHECKING, + Any, + Coroutine, + Dict, + List, + Literal, + Optional, + Tuple, + Union, +) -from ...tl import abcs, types +from telethon._impl.client.types.async_list import AsyncList +from telethon._impl.session.chat.packed import PackedChat + +from ...tl import abcs, functions, types +from ..parsers import parse_html_message, parse_markdown_message +from ..types.chat import ChatLike from ..types.message import Message +from ..utils import generate_random_id if TYPE_CHECKING: from .client import Client -def iter_messages(self: Client) -> None: - self - raise NotImplementedError +def parse_message( + *, + text: Optional[str] = None, + markdown: Optional[str] = None, + html: Optional[str] = None, +) -> Tuple[str, Optional[List[abcs.MessageEntity]]]: + if sum((text is not None, markdown is not None, html is not None)) != 1: + raise ValueError("must specify exactly one of text, markdown or html") + + if text is not None: + parsed, entities = text, None + elif markdown is not None: + parsed, entities = parse_markdown_message(markdown) + elif html is not None: + parsed, entities = parse_html_message(html) + else: + raise RuntimeError("unexpected case") + + return parsed, entities or None -async def send_message(self: Client) -> None: - self - raise NotImplementedError +async def send_message( + self: Client, + chat: ChatLike, + *, + text: Optional[str] = None, + markdown: Optional[str] = None, + html: Optional[str] = None, + link_preview: Optional[bool] = None, +) -> Message: + peer = (await self._resolve_to_packed(chat))._to_input_peer() + message, entities = parse_message(text=text, markdown=markdown, html=html) + random_id = generate_random_id() + return self._build_message_map( + await self( + functions.messages.send_message( + no_webpage=not link_preview, + silent=False, + background=False, + clear_draft=False, + noforwards=False, + update_stickersets_order=False, + peer=peer, + reply_to_msg_id=None, + top_msg_id=None, + message=message, + random_id=random_id, + reply_markup=None, + entities=entities, + schedule_date=None, + send_as=None, + ) + ), + peer, + ).with_random_id(random_id) -async def forward_messages(self: Client) -> None: - self - raise NotImplementedError +async def edit_message( + self: Client, + chat: ChatLike, + message_id: int, + *, + text: Optional[str] = None, + markdown: Optional[str] = None, + html: Optional[str] = None, + link_preview: Optional[bool] = None, +) -> Message: + peer = (await self._resolve_to_packed(chat))._to_input_peer() + message, entities = parse_message(text=text, markdown=markdown, html=html) + return self._build_message_map( + await self( + functions.messages.edit_message( + no_webpage=not link_preview, + peer=peer, + id=message_id, + message=message, + media=None, + reply_markup=None, + entities=entities, + schedule_date=None, + ) + ), + peer, + ).with_id(message_id) -async def edit_message(self: Client) -> None: - self - raise NotImplementedError +async def delete_messages( + self: Client, chat: ChatLike, message_ids: List[int], *, revoke: bool = True +) -> int: + packed_chat = await self._resolve_to_packed(chat) + if packed_chat.is_channel(): + affected = await self( + functions.channels.delete_messages( + channel=packed_chat._to_input_channel(), id=message_ids + ) + ) + else: + affected = await self( + functions.messages.delete_messages(revoke=revoke, id=message_ids) + ) + assert isinstance(affected, types.messages.AffectedMessages) + return affected.pts_count -async def delete_messages(self: Client) -> None: - self - raise NotImplementedError +async def forward_messages( + self: Client, target: ChatLike, message_ids: List[int], source: ChatLike +) -> List[Message]: + to_peer = (await self._resolve_to_packed(target))._to_input_peer() + from_peer = (await self._resolve_to_packed(source))._to_input_peer() + random_ids = [generate_random_id() for _ in message_ids] + map = self._build_message_map( + await self( + functions.messages.forward_messages( + silent=False, + background=False, + with_my_score=False, + drop_author=False, + drop_media_captions=False, + noforwards=False, + from_peer=from_peer, + id=message_ids, + random_id=random_ids, + to_peer=to_peer, + top_msg_id=None, + schedule_date=None, + send_as=None, + ) + ), + to_peer, + ) + return [map.with_random_id(id) for id in random_ids] -async def send_read_acknowledge(self: Client) -> None: - self - raise NotImplementedError +class MessageList(AsyncList[Message]): + def _extend_buffer(self, client: Client, messages: abcs.messages.Messages) -> None: + if isinstance(messages, types.messages.Messages): + self._buffer.extend(Message._from_raw(m) for m in messages.messages) + self._total = len(messages.messages) + self._done = True + elif isinstance(messages, types.messages.MessagesSlice): + self._buffer.extend(Message._from_raw(m) for m in messages.messages) + self._total = messages.count + elif isinstance(messages, types.messages.ChannelMessages): + self._buffer.extend(Message._from_raw(m) for m in messages.messages) + self._total = messages.count + elif isinstance(messages, types.messages.MessagesNotModified): + self._total = messages.count + else: + raise RuntimeError("unexpected case") + + def _last_non_empty_message(self) -> Message: + return next( + ( + m + for m in reversed(self._buffer) + if not isinstance(m._raw, types.MessageEmpty) + ), + Message._from_raw(types.MessageEmpty(id=0, peer_id=None)), + ) -async def pin_message(self: Client) -> None: - self - raise NotImplementedError +class HistoryList(MessageList): + def __init__( + self, + client: Client, + chat: ChatLike, + limit: int, + *, + offset_id: int, + offset_date: int, + ): + super().__init__() + self._client = client + self._chat = chat + self._peer: Optional[abcs.InputPeer] = None + self._limit = limit + self._offset_id = offset_id + self._offset_date = offset_date + + async def _fetch_next(self) -> None: + if self._peer is None: + self._peer = ( + await self._client._resolve_to_packed(self._chat) + )._to_input_peer() + + result = await self._client( + functions.messages.get_history( + peer=self._peer, + offset_id=self._offset_id, + offset_date=self._offset_date, + add_offset=0, + limit=min(max(self._limit, 1), 100), + max_id=0, + min_id=0, + hash=0, + ) + ) + + self._extend_buffer(self._client, result) + self._limit -= len(self._buffer) + if self._buffer: + last = self._last_non_empty_message() + self._offset_id = self._buffer[-1].id + if (date := getattr(last._raw, "date", None)) is not None: + self._offset_date = date -async def unpin_message(self: Client) -> None: - self - raise NotImplementedError +def get_messages( + self: Client, + chat: ChatLike, + limit: Optional[int] = None, + *, + offset_id: Optional[int], + offset_date: Optional[datetime.datetime], +) -> AsyncList[Message]: + return HistoryList( + self, + chat, + sys.maxsize if limit is None else limit, + offset_id=offset_id or 0, + offset_date=int(offset_date.timestamp()) if offset_date is not None else 0, + ) -def find_updates_message( +class CherryPickedList(MessageList): + def __init__( + self, + client: Client, + chat: ChatLike, + ids: List[int], + ): + super().__init__() + self._client = client + self._chat = chat + self._packed: Optional[PackedChat] = None + self._ids = ids + + async def _fetch_next(self) -> None: + if not self._ids: + return + if self._packed is None: + self._packed = await self._client._resolve_to_packed(self._chat) + + if self._packed.is_channel(): + result = await self._client( + functions.channels.get_messages( + channel=self._packed._to_input_channel(), + id=[types.InputMessageId(id=id) for id in self._ids[:100]], + ) + ) + else: + result = await self._client( + functions.messages.get_messages( + id=[types.InputMessageId(id=id) for id in self._ids[:100]] + ) + ) + + self._extend_buffer(self._client, result) + self._ids = self._ids[100:] + + +def get_messages_with_ids( + self: Client, + chat: ChatLike, + message_ids: List[int], +) -> AsyncList[Message]: + return CherryPickedList(self, chat, message_ids) + + +class SearchList(MessageList): + def __init__( + self, + client: Client, + chat: ChatLike, + limit: int, + *, + query: str, + offset_id: int, + offset_date: int, + ): + super().__init__() + self._client = client + self._chat = chat + self._peer: Optional[abcs.InputPeer] = None + self._limit = limit + self._query = query + self._offset_id = offset_id + self._offset_date = offset_date + + async def _fetch_next(self) -> None: + if self._peer is None: + self._peer = ( + await self._client._resolve_to_packed(self._chat) + )._to_input_peer() + + result = await self._client( + functions.messages.search( + peer=self._peer, + q=self._query, + from_id=None, + top_msg_id=None, + filter=types.InputMessagesFilterEmpty(), + min_date=0, + max_date=self._offset_date, + offset_id=self._offset_id, + add_offset=0, + limit=min(max(self._limit, 1), 100), + max_id=0, + min_id=0, + hash=0, + ) + ) + + self._extend_buffer(self._client, result) + self._limit -= len(self._buffer) + if self._buffer: + last = self._last_non_empty_message() + self._offset_id = self._buffer[-1].id + if (date := getattr(last._raw, "date", None)) is not None: + self._offset_date = date + + +def search_messages( + self: Client, + chat: ChatLike, + limit: Optional[int] = None, + *, + query: Optional[str] = None, + offset_id: int, + offset_date: datetime.datetime, +) -> AsyncList[Message]: + return SearchList( + self, + chat, + sys.maxsize if limit is None else limit, + query=query or "", + offset_id=offset_id or 0, + offset_date=int(offset_date.timestamp()) if offset_date is not None else 0, + ) + + +class GlobalSearchList(MessageList): + def __init__( + self, + client: Client, + limit: int, + *, + query: str, + offset_id: int, + offset_date: int, + ): + super().__init__() + self._client = client + self._limit = limit + self._query = query + self._offset_id = offset_id + self._offset_date = offset_date + self._offset_rate = 0 + self._offset_peer: abcs.InputPeer = types.InputPeerEmpty() + + async def _fetch_next(self) -> None: + result = await self._client( + functions.messages.search_global( + folder_id=None, + q=self._query, + filter=types.InputMessagesFilterEmpty(), + min_date=0, + max_date=self._offset_date, + offset_rate=self._offset_rate, + offset_peer=self._offset_peer, + offset_id=self._offset_id, + limit=min(max(self._limit, 1), 100), + ) + ) + + self._extend_buffer(self._client, result) + self._limit -= len(self._buffer) + if self._buffer: + last = self._last_non_empty_message() + last_packed = last.chat.pack() + + self._offset_id = self._buffer[-1].id + if (date := getattr(last._raw, "date", None)) is not None: + self._offset_date = date + if isinstance(result, types.messages.MessagesSlice): + self._offset_rate = result.next_rate or 0 + self._offset_peer = ( + last_packed._to_input_peer() if last_packed else types.InputPeerEmpty() + ) + + +def search_all_messages( + self: Client, + limit: Optional[int] = None, + *, + query: Optional[str] = None, + offset_id: int, + offset_date: datetime.datetime, +) -> AsyncList[Message]: + return GlobalSearchList( + self, + sys.maxsize if limit is None else limit, + query=query or "", + offset_id=offset_id or 0, + offset_date=int(offset_date.timestamp()) if offset_date is not None else 0, + ) + + +async def pin_message(self: Client, chat: ChatLike, message_id: int) -> Message: + peer = (await self._resolve_to_packed(chat))._to_input_peer() + return self._build_message_map( + await self( + functions.messages.update_pinned_message( + silent=True, unpin=False, pm_oneside=False, peer=peer, id=message_id + ) + ), + peer, + ).get_single() + + +async def unpin_message( + self: Client, chat: ChatLike, message_id: Union[int, Literal["all"]] +) -> None: + peer = (await self._resolve_to_packed(chat))._to_input_peer() + if message_id == "all": + await self( + functions.messages.unpin_all_messages( + peer=peer, + top_msg_id=None, + ) + ) + else: + await self( + functions.messages.update_pinned_message( + silent=True, unpin=True, pm_oneside=False, peer=peer, id=message_id + ) + ) + + +class MessageMap: + __slots__ = ("_client", "_peer", "_random_id_to_id", "_id_to_message") + + def __init__( + self, + client: Client, + peer: Optional[abcs.InputPeer], + random_id_to_id: Dict[int, int], + id_to_message: Dict[int, Message], + ) -> None: + self._client = client + self._peer = peer + self._random_id_to_id = random_id_to_id + self._id_to_message = id_to_message + + def with_random_id(self, random_id: int) -> Message: + id = self._random_id_to_id.get(random_id) + return self.with_id(id) if id is not None else self._empty() + + def with_id(self, id: int) -> Message: + message = self._id_to_message.get(id) + return message if message is not None else self._empty(id) + + def get_single(self) -> Message: + if len(self._id_to_message) == 1: + for message in self._id_to_message.values(): + return message + return self._empty() + + def _empty(self, id: int = 0) -> Message: + return Message._from_raw( + types.MessageEmpty(id=id, peer_id=self._client._input_as_peer(self._peer)) + ) + + +def build_message_map( self: Client, result: abcs.Updates, - random_id: int, - chat: Optional[abcs.InputPeer], -) -> Message: + peer: Optional[abcs.InputPeer], +) -> MessageMap: if isinstance(result, types.UpdateShort): updates = [result.update] entities: Dict[int, object] = {} @@ -63,15 +514,13 @@ def find_updates_message( entities = {} raise NotImplementedError() else: - return Message._from_raw( - types.MessageEmpty(id=0, peer_id=self._input_as_peer(chat)) - ) + return MessageMap(self, peer, {}, {}) - random_to_id = {} + random_id_to_id = {} id_to_message = {} for update in updates: if isinstance(update, types.UpdateMessageId): - random_to_id[update.random_id] = update.id + random_id_to_id[update.random_id] = update.id elif isinstance( update, @@ -87,9 +536,14 @@ def find_updates_message( update.message, (types.Message, types.MessageService, types.MessageEmpty), ) - id_to_message[update.message.id] = update.message + id_to_message[update.message.id] = Message._from_raw(update.message) elif isinstance(update, types.UpdateMessagePoll): raise NotImplementedError() - return Message._from_raw(id_to_message[random_to_id[random_id]]) + return MessageMap( + self, + peer, + random_id_to_id, + id_to_message, + ) diff --git a/client/src/telethon/_impl/client/types/async_list.py b/client/src/telethon/_impl/client/types/async_list.py new file mode 100644 index 00000000..fdab0af3 --- /dev/null +++ b/client/src/telethon/_impl/client/types/async_list.py @@ -0,0 +1,67 @@ +import abc +from collections import deque +from typing import Any, Deque, Generator, Generic, List, Self, TypeVar + +T = TypeVar("T") + + +class AsyncList(abc.ABC, Generic[T]): + """ + An asynchronous list. + + It can be awaited to get all the items as a normal `list`, + or iterated over via `async for`. + + Both approaches will perform as many requests as needed to retrieve the + items, but awaiting will need to do it all at once, which can be slow. + + Using asynchronous iteration will perform the requests lazily as needed, + and lets you break out of the loop at any time to stop fetching items. + + The `len()` of the asynchronous list will be the "total count" reported + by the server. It does not necessarily reflect how many items will + actually be returned. This count can change as more items are fetched. + """ + + def __init__(self) -> None: + self._buffer: Deque[T] = deque() + self._total: int = 0 + self._done = False + + @abc.abstractmethod + async def _fetch_next(self) -> None: + """ + Fetch the next chunk of items. + + The `_buffer` should be extended from the end, not the front. + The `_total` should be updated with the count reported by the server. + The `_done` flag should be set if it is known that the end was reached + """ + + async def _collect(self) -> List[T]: + prev = -1 + while prev != len(self._buffer): + prev = len(self._buffer) + await self._fetch_next() + return list(self._buffer) + + def __await__(self) -> Generator[Any, None, List[T]]: + return self._collect().__await__() + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> T: + if not self._buffer: + if self._done: + raise StopAsyncIteration + await self._fetch_next() + + if not self._buffer: + self._done = True + raise StopAsyncIteration + + return self._buffer.popleft() + + def __len__(self) -> int: + return self._total diff --git a/client/src/telethon/_impl/client/types/message.py b/client/src/telethon/_impl/client/types/message.py index 700d5eb9..e01bdf4e 100644 --- a/client/src/telethon/_impl/client/types/message.py +++ b/client/src/telethon/_impl/client/types/message.py @@ -1,16 +1,37 @@ -from typing import Self - -from telethon._impl.tl import abcs +import datetime +from typing import Optional, Self +from ...client.types.chat import Chat +from ...tl import abcs, types from .meta import NoPublicConstructor class Message(metaclass=NoPublicConstructor): - __slots__ = ("_message",) + __slots__ = ("_raw",) def __init__(self, message: abcs.Message) -> None: - self._message = message + assert isinstance( + message, (types.Message, types.MessageService, types.MessageEmpty) + ) + self._raw = message @classmethod def _from_raw(cls, message: abcs.Message) -> Self: return cls._create(message) + + @property + def id(self) -> int: + return self._raw.id + + @property + def date(self) -> Optional[datetime.datetime]: + date = getattr(self._raw, "date", None) + return ( + datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc) + if date is not None + else None + ) + + @property + def chat(self) -> Chat: + raise NotImplementedError diff --git a/client/src/telethon/_impl/client/utils.py b/client/src/telethon/_impl/client/utils.py new file mode 100644 index 00000000..8753dbd9 --- /dev/null +++ b/client/src/telethon/_impl/client/utils.py @@ -0,0 +1,11 @@ +import time + +_last_id = 0 + + +def generate_random_id() -> int: + global _last_id + if _last_id == 0: + _last_id = int(time.time() * 1e9) + _last_id += 1 + return _last_id