Implemnet chat hash cache and adapting updates

This commit is contained in:
Lonami Exo 2023-08-31 17:36:08 +02:00
parent 7c112d8b0f
commit c77c10b48f
8 changed files with 697 additions and 0 deletions

View File

@ -0,0 +1,4 @@
from .hash_cache import ChatHashCache
from .packed import PackedChat, PackedType
__all__ = ["ChatHashCache", "PackedChat", "PackedType"]

View File

@ -0,0 +1,314 @@
from typing import Dict, List, Optional, Tuple
from ...tl import abcs, types
from .packed import PackedChat, PackedType
class ChatHashCache:
__slots__ = ("_hash_map", "_self_id", "_self_bot")
def __init__(self, self_user: Optional[Tuple[int, bool]]):
self._hash_map: Dict[int, Tuple[int, PackedType]] = {}
self._self_id = self_user[0] if self_user else None
self._self_bot = self_user[1] if self_user else False
@property
def self_id(self) -> int:
assert self._self_id is not None
return self._self_id
@property
def is_self_bot(self) -> bool:
return self._self_bot
def set_self_user(self, user: PackedChat) -> None:
assert user.ty in (PackedType.USER, PackedType.BOT)
self._self_bot = user.ty == PackedType.BOT
self._self_id = user.id
def get(self, id: int) -> Optional[PackedChat]:
if (entry := self._hash_map.get(id)) is not None:
hash, ty = entry
return PackedChat(ty, id, hash)
else:
return None
def _has(self, id: int) -> bool:
return id in self._hash_map
def _has_peer(self, peer: abcs.Peer) -> bool:
if isinstance(peer, types.PeerUser):
return self._has(peer.user_id)
elif isinstance(peer, types.PeerChat):
return True # no hash needed, so we always have it
elif isinstance(peer, types.PeerChannel):
return self._has(peer.channel_id)
else:
raise RuntimeError("unexpected case")
def _has_dialog_peer(self, peer: abcs.DialogPeer) -> bool:
if isinstance(peer, types.DialogPeer):
return self._has_peer(peer.peer)
elif isinstance(peer, types.DialogPeerFolder):
return True
else:
raise RuntimeError("unexpected case")
def _has_notify_peer(self, peer: abcs.NotifyPeer) -> bool:
if isinstance(peer, types.NotifyPeer):
return self._has_peer(peer.peer)
elif isinstance(peer, types.NotifyForumTopic):
return self._has_peer(peer.peer)
elif isinstance(
peer, (types.NotifyUsers, types.NotifyChats, types.NotifyBroadcasts)
):
return True
else:
raise RuntimeError("unexpected case")
def _has_button(self, button: abcs.KeyboardButton) -> bool:
if isinstance(button, types.InputKeyboardButtonUrlAuth):
return self._has_user(button.bot)
elif isinstance(button, types.InputKeyboardButtonUserProfile):
return self._has_user(button.user_id)
elif isinstance(button, types.KeyboardButtonUserProfile):
return self._has(button.user_id)
else:
return True
def _has_entity(self, entity: abcs.MessageEntity) -> bool:
if isinstance(entity, types.MessageEntityMentionName):
return self._has(entity.user_id)
elif isinstance(entity, types.InputMessageEntityMentionName):
return self._has_user(entity.user_id)
else:
return True
def _has_user(self, peer: abcs.InputUser) -> bool:
if isinstance(peer, (types.InputUserEmpty, types.InputUserSelf)):
return True
elif isinstance(peer, types.InputUser):
return self._has(peer.user_id)
elif isinstance(peer, types.InputUserFromMessage):
return self._has(peer.user_id)
else:
raise RuntimeError("unexpected case")
def _has_participant(self, participant: abcs.ChatParticipant) -> bool:
if isinstance(participant, types.ChatParticipant):
return self._has(participant.user_id) and self._has(participant.inviter_id)
elif isinstance(participant, types.ChatParticipantCreator):
return self._has(participant.user_id)
elif isinstance(participant, types.ChatParticipantAdmin):
return self._has(participant.user_id) and self._has(participant.inviter_id)
else:
raise RuntimeError("unexpected case")
def _has_channel_participant(self, participant: abcs.ChannelParticipant) -> bool:
if isinstance(participant, types.ChannelParticipant):
return self._has(participant.user_id)
elif isinstance(participant, types.ChannelParticipantSelf):
return self._has(participant.user_id) and self._has(participant.inviter_id)
elif isinstance(participant, types.ChannelParticipantCreator):
return self._has(participant.user_id)
elif isinstance(participant, types.ChannelParticipantAdmin):
return (
self._has(participant.user_id)
and (
participant.inviter_id is None or self._has(participant.inviter_id)
)
and self._has(participant.promoted_by)
)
elif isinstance(participant, types.ChannelParticipantBanned):
return self._has_peer(participant.peer) and self._has(participant.kicked_by)
elif isinstance(participant, types.ChannelParticipantLeft):
return self._has_peer(participant.peer)
else:
raise RuntimeError("unexpected case")
def extend(self, users: List[abcs.User], chats: List[abcs.Chat]) -> bool:
# See https://core.telegram.org/api/min for "issues" with "min constructors".
success = True
for user in users:
if isinstance(user, types.UserEmpty):
pass
elif isinstance(user, types.User):
if not user.min and user.access_hash is not None:
ty = PackedType.BOT if user.bot else PackedType.USER
self._hash_map[user.id] = (user.access_hash, ty)
else:
success &= user.id in self._hash_map
else:
raise RuntimeError("unexpected case")
for chat in chats:
if isinstance(chat, (types.ChatEmpty, types.Chat, types.ChatForbidden)):
pass
elif isinstance(chat, types.Channel):
if not chat.min and chat.access_hash is not None:
if chat.megagroup:
ty = PackedType.MEGAGROUP
elif chat.gigagroup:
ty = PackedType.GIGAGROUP
else:
ty = PackedType.BROADCAST
self._hash_map[chat.id] = (chat.access_hash, ty)
else:
success &= chat.id in self._hash_map
elif isinstance(chat, types.ChannelForbidden):
ty = PackedType.MEGAGROUP if chat.megagroup else PackedType.BROADCAST
self._hash_map[chat.id] = (chat.access_hash, ty)
else:
raise RuntimeError("unexpected case")
return success
def extend_from_updates(self, updates: abcs.Updates) -> bool:
if isinstance(updates, types.UpdatesTooLong):
return True
elif isinstance(updates, types.UpdateShortMessage):
return self._has(updates.user_id)
elif isinstance(updates, types.UpdateShortChatMessage):
return self._has(updates.from_id)
elif isinstance(updates, types.UpdateShort):
success = True
update = updates.update
# In Python, we get to cheat rather than having hundreds of `if isinstance`
for field in ("message",):
message = getattr(update, field, None)
if isinstance(message, abcs.Message):
success &= self.extend_from_message(message)
for field in ("user_id", "inviter_id", "channel_id", "bot_id", "actor_id"):
int_id = getattr(update, field, None)
if isinstance(int_id, int):
success &= self._has(int_id)
for field in ("from_id", "peer"):
peer = getattr(update, field, None)
if isinstance(peer, abcs.Peer):
success &= self._has_peer(peer)
elif isinstance(peer, abcs.DialogPeer):
success &= self._has_dialog_peer(peer)
elif isinstance(peer, abcs.NotifyPeer):
success &= self._has_notify_peer(peer)
# TODO cover?:
# ChatParticipants.participants
# PinnedDialogs.order
# FolderPeers.folder_peers
# PeerLocated.peers
# GroupCallParticipants.participants
# ChatParticipant and ChannelParticipant .prev_participant, new_participant, invite
# BotChatInviteRequester.invite
return success
elif isinstance(updates, types.UpdatesCombined):
return self.extend(updates.users, updates.chats)
elif isinstance(updates, types.Updates):
return self.extend(updates.users, updates.chats)
elif isinstance(updates, types.UpdateShortSentMessage):
return True
else:
raise RuntimeError("unexpected case")
def extend_from_message(self, message: abcs.Message) -> bool:
if isinstance(message, types.MessageEmpty):
return message.peer_id is None or self._has_peer(message.peer_id)
elif isinstance(message, types.Message):
success = True
if message.from_id is not None:
success &= self._has_peer(message.from_id)
success &= self._has_peer(message.peer_id)
if isinstance(message.fwd_from, types.MessageFwdHeader):
if message.fwd_from.from_id:
success &= self._has_peer(message.fwd_from.from_id)
if message.fwd_from.saved_from_peer:
success &= self._has_peer(message.fwd_from.saved_from_peer)
elif message.fwd_from is not None:
raise RuntimeError("unexpected case")
if isinstance(message.reply_to, types.MessageReplyHeader):
if message.reply_to.reply_to_peer_id:
success &= self._has_peer(message.reply_to.reply_to_peer_id)
elif message.reply_to is not None:
raise RuntimeError("unexpected case")
if message.reply_markup is not None:
if isinstance(message.reply_markup, types.ReplyKeyboardMarkup):
for row in message.reply_markup.rows:
if isinstance(row, types.KeyboardButtonRow):
for button in row.buttons:
success &= self._has_button(button)
elif isinstance(message.reply_markup, types.ReplyInlineMarkup):
for row in message.reply_markup.rows:
if isinstance(row, types.KeyboardButtonRow):
for button in row.buttons:
success &= self._has_button(button)
if message.entities:
for entity in message.entities:
success &= self._has_entity(entity)
if isinstance(message.replies, types.MessageReplies):
if message.replies.recent_repliers:
for p in message.replies.recent_repliers:
success &= self._has_peer(p)
elif message.replies is not None:
raise RuntimeError("unexpected case")
if isinstance(message.reactions, types.MessageReactions):
if message.reactions.recent_reactions:
for r in message.reactions.recent_reactions:
if isinstance(r, types.MessagePeerReaction):
success &= self._has_peer(r.peer_id)
else:
raise RuntimeError("unexpected case")
elif message.reactions is not None:
raise RuntimeError("unexpected case")
return success
elif isinstance(message, types.MessageService):
success = True
if message.from_id:
success &= self._has_peer(message.from_id)
if message.peer_id:
success &= self._has_peer(message.peer_id)
if isinstance(message.reply_to, types.MessageReplyHeader):
if message.reply_to.reply_to_peer_id:
success &= self._has_peer(message.reply_to.reply_to_peer_id)
elif message.reply_to is not None:
raise RuntimeError("unexpected case")
for field in ("user_id", "inviter_id", "channel_id"):
int_id = getattr(message.action, field, None)
if isinstance(int_id, int):
success &= self._has(int_id)
for field in ("from_id", "to_id", "peer"):
peer = getattr(message.action, field, None)
if isinstance(peer, abcs.Peer):
success &= self._has_peer(peer)
elif isinstance(peer, abcs.DialogPeer):
success &= self._has_dialog_peer(peer)
elif isinstance(peer, abcs.NotifyPeer):
success &= self._has_notify_peer(peer)
for field in ("users",):
users = getattr(message.action, field, None)
if isinstance(users, list):
for user in users:
if isinstance(user, int):
success &= self._has(user)
return success
else:
raise RuntimeError("unexpected case")

View File

@ -0,0 +1,107 @@
import struct
from enum import Enum
from typing import Optional, Self
from telethon._impl.tl import abcs, types
class PackedType(Enum):
# bits: zero, has-access-hash, channel, broadcast, group, chat, user, bot
USER = 0b0000_0010
BOT = 0b0000_0011
CHAT = 0b0000_0100
MEGAGROUP = 0b0010_1000
BROADCAST = 0b0011_0000
GIGAGROUP = 0b0011_1000
class PackedChat:
__slots__ = ("ty", "id", "access_hash")
def __init__(self, ty: PackedType, id: int, access_hash: Optional[int]) -> None:
self.ty = ty
self.id = id
self.access_hash = access_hash
def __bytes__(self) -> bytes:
return struct.pack(
"<Bqq",
self.ty.value | (0 if self.access_hash is None else 0b0100_0000),
self.id,
self.access_hash or 0,
)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
ty_byte, id, access_hash = struct.unpack("<Bqq", data)
has_hash = (ty_byte & 0b0100_0000) != 0
ty = PackedType(ty_byte & 0b0011_1111)
return cls(ty, id, access_hash if has_hash else None)
def is_user(self) -> bool:
return self.ty in (PackedType.USER, PackedType.BOT)
def is_chat(self) -> bool:
return self.ty in (PackedType.CHAT,)
def is_channel(self) -> bool:
return self.ty in (
PackedType.MEGAGROUP,
PackedType.BROADCAST,
PackedType.GIGAGROUP,
)
def to_peer(self) -> abcs.Peer:
if self.is_user():
return types.PeerUser(user_id=self.id)
elif self.is_chat():
return types.PeerChat(chat_id=self.id)
elif self.is_channel():
return types.PeerChannel(channel_id=self.id)
else:
raise RuntimeError("unexpected case")
def to_input_peer(self) -> abcs.InputPeer:
if self.is_user():
return types.InputPeerUser(
user_id=self.id, access_hash=self.access_hash or 0
)
elif self.is_chat():
return types.InputPeerChat(chat_id=self.id)
elif self.is_channel():
return types.InputPeerChannel(
channel_id=self.id, access_hash=self.access_hash or 0
)
else:
raise RuntimeError("unexpected case")
def try_to_input_user(self) -> Optional[abcs.InputUser]:
if self.is_user():
return types.InputUser(user_id=self.id, access_hash=self.access_hash or 0)
else:
return None
def to_input_user_lossy(self) -> abcs.InputUser:
return self.try_to_input_user() or types.InputUser(user_id=0, access_hash=0)
def try_to_chat_id(self) -> Optional[int]:
return self.id if self.is_chat() else None
def try_to_input_channel(self) -> Optional[abcs.InputChannel]:
return (
types.InputChannel(channel_id=self.id, access_hash=self.access_hash or 0)
if self.is_channel()
else None
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return (
self.ty == other.ty
and self.id == other.id
and self.access_hash == other.access_hash
)
def __str__(self) -> str:
return f"PackedChat.{self.ty.name}({self.id})"

View File

@ -0,0 +1,243 @@
from typing import Optional, Tuple
from ...tl import abcs, types
from ..chat.hash_cache import ChatHashCache
from .defs import ACCOUNT_WIDE, NO_SEQ, SECRET_CHATS, Gap
def updates_(updates: types.Updates) -> types.UpdatesCombined:
return types.UpdatesCombined(
updates=updates.updates,
users=updates.users,
chats=updates.chats,
date=updates.date,
seq_start=updates.seq,
seq=updates.seq,
)
def update_short(short: types.UpdateShort) -> types.UpdatesCombined:
return types.UpdatesCombined(
updates=[short.update],
users=[],
chats=[],
date=short.date,
seq_start=NO_SEQ,
seq=NO_SEQ,
)
def update_short_message(
short: types.UpdateShortMessage, self_id: int
) -> types.UpdatesCombined:
return update_short(
types.UpdateShort(
update=types.UpdateNewMessage(
message=types.Message(
out=short.out,
mentioned=short.mentioned,
media_unread=short.media_unread,
silent=short.silent,
post=False,
from_scheduled=False,
legacy=False,
edit_hide=False,
pinned=False,
noforwards=False,
reactions=None,
id=short.id,
from_id=types.PeerUser(
user_id=self_id if short.out else short.user_id
),
peer_id=types.PeerChat(
chat_id=short.user_id,
),
fwd_from=short.fwd_from,
via_bot_id=short.via_bot_id,
reply_to=short.reply_to,
date=short.date,
message=short.message,
media=None,
reply_markup=None,
entities=short.entities,
views=None,
forwards=None,
replies=None,
edit_date=None,
post_author=None,
grouped_id=None,
restriction_reason=None,
ttl_period=short.ttl_period,
),
pts=short.pts,
pts_count=short.pts_count,
),
date=short.date,
)
)
def update_short_chat_message(
short: types.UpdateShortChatMessage,
) -> types.UpdatesCombined:
return update_short(
types.UpdateShort(
update=types.UpdateNewMessage(
message=types.Message(
out=short.out,
mentioned=short.mentioned,
media_unread=short.media_unread,
silent=short.silent,
post=False,
from_scheduled=False,
legacy=False,
edit_hide=False,
pinned=False,
noforwards=False,
reactions=None,
id=short.id,
from_id=types.PeerUser(
user_id=short.from_id,
),
peer_id=types.PeerChat(
chat_id=short.chat_id,
),
fwd_from=short.fwd_from,
via_bot_id=short.via_bot_id,
reply_to=short.reply_to,
date=short.date,
message=short.message,
media=None,
reply_markup=None,
entities=short.entities,
views=None,
forwards=None,
replies=None,
edit_date=None,
post_author=None,
grouped_id=None,
restriction_reason=None,
ttl_period=short.ttl_period,
),
pts=short.pts,
pts_count=short.pts_count,
),
date=short.date,
)
)
def update_short_sent_message(
short: types.UpdateShortSentMessage,
) -> types.UpdatesCombined:
return update_short(
types.UpdateShort(
update=types.UpdateNewMessage(
message=types.MessageEmpty(
id=short.id,
peer_id=None,
),
pts=short.pts,
pts_count=short.pts_count,
),
date=short.date,
)
)
def adapt(updates: abcs.Updates, chat_hashes: ChatHashCache) -> types.UpdatesCombined:
if isinstance(updates, types.UpdatesTooLong):
raise Gap
elif isinstance(updates, types.UpdateShortMessage):
return update_short_message(updates, chat_hashes.self_id)
elif isinstance(updates, types.UpdateShortChatMessage):
return update_short_chat_message(updates)
elif isinstance(updates, types.UpdateShort):
return update_short(updates)
elif isinstance(updates, types.UpdatesCombined):
return updates
elif isinstance(updates, types.Updates):
return updates_(updates)
elif isinstance(updates, types.UpdateShortSentMessage):
return update_short_sent_message(updates)
else:
raise RuntimeError("unexpected case")
def message_peer(message: abcs.Message) -> Optional[abcs.Peer]:
if isinstance(message, types.MessageEmpty):
return None
elif isinstance(message, types.Message):
return message.peer_id
elif isinstance(message, types.MessageService):
return message.peer_id
else:
raise RuntimeError("unexpected case")
def message_channel_id(message: abcs.Message) -> Optional[int]:
peer = message_peer(message)
return peer.channel_id if isinstance(peer, types.PeerChannel) else None
def pts_info_from_update(update: abcs.Update) -> Optional[Tuple[int | str, int, int]]:
if isinstance(update, types.UpdateNewMessage):
assert not isinstance(message_peer(update.message), types.PeerChannel)
return ACCOUNT_WIDE, update.pts, update.pts_count
elif isinstance(update, types.UpdateDeleteMessages):
return ACCOUNT_WIDE, update.pts, update.pts_count
elif isinstance(update, types.UpdateNewEncryptedMessage):
return SECRET_CHATS, update.qts, 1
elif isinstance(update, types.UpdateReadHistoryInbox):
assert not isinstance(update.peer, types.PeerChannel)
return ACCOUNT_WIDE, update.pts, update.pts_count
elif isinstance(update, types.UpdateReadHistoryOutbox):
assert not isinstance(update.peer, types.PeerChannel)
return ACCOUNT_WIDE, update.pts, update.pts_count
elif isinstance(update, types.UpdateWebPage):
return ACCOUNT_WIDE, update.pts, update.pts_count
elif isinstance(update, types.UpdateReadMessagesContents):
return ACCOUNT_WIDE, update.pts, update.pts_count
elif isinstance(update, types.UpdateChannelTooLong):
if update.pts is not None:
return update.channel_id, update.pts, 0
else:
return None
elif isinstance(update, types.UpdateNewChannelMessage):
channel_id = message_channel_id(update.message)
if channel_id is not None:
return channel_id, update.pts, update.pts_count
else:
return None
elif isinstance(update, types.UpdateReadChannelInbox):
return update.channel_id, update.pts, 0
elif isinstance(update, types.UpdateDeleteChannelMessages):
return update.channel_id, update.pts, update.pts_count
elif isinstance(update, types.UpdateEditChannelMessage):
channel_id = message_channel_id(update.message)
if channel_id is not None:
return channel_id, update.pts, update.pts_count
else:
return None
elif isinstance(update, types.UpdateEditMessage):
assert not isinstance(message_peer(update.message), types.PeerChannel)
return ACCOUNT_WIDE, update.pts, update.pts_count
elif isinstance(update, types.UpdateChannelWebPage):
return update.channel_id, update.pts, update.pts_count
elif isinstance(update, types.UpdateFolderPeers):
return ACCOUNT_WIDE, update.pts, update.pts_count
elif isinstance(update, types.UpdatePinnedMessages):
assert not isinstance(update.peer, types.PeerChannel)
return ACCOUNT_WIDE, update.pts, update.pts_count
elif isinstance(update, types.UpdatePinnedChannelMessages):
return update.channel_id, update.pts, update.pts_count
elif isinstance(update, types.UpdateChatParticipant):
return SECRET_CHATS, update.qts, 0
elif isinstance(update, types.UpdateChannelParticipant):
return SECRET_CHATS, update.qts, 0
elif isinstance(update, types.UpdateBotStopped):
return SECRET_CHATS, update.qts, 0
elif isinstance(update, types.UpdateBotChatInviteRequester):
return SECRET_CHATS, update.qts, 0
else:
return None

View File

@ -0,0 +1,19 @@
NO_SEQ = 0
NO_PTS = 0
# https://core.telegram.org/method/updates.getChannelDifference
BOT_CHANNEL_DIFF_LIMIT = 100000
USER_CHANNEL_DIFF_LIMIT = 100
POSSIBLE_GAP_TIMEOUT = 0.5
# https://core.telegram.org/api/updates
NO_UPDATES_TIMEOUT = 15 * 60
ACCOUNT_WIDE = "ACCOUNT"
SECRET_CHATS = "SECRET"
class Gap(ValueError):
pass

View File

@ -0,0 +1,10 @@
from telethon._impl.session.chat.packed import PackedChat, PackedType
def test_hash_optional() -> None:
for ty in PackedType:
pc = PackedChat(ty, 123, 456789)
assert PackedChat.from_bytes(bytes(pc)) == pc
pc = PackedChat(ty, 987, None)
assert PackedChat.from_bytes(bytes(pc)) == pc