From 69d794185257d43e186d925666cfac9dc2aeee07 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Thu, 31 Aug 2023 20:11:35 +0200 Subject: [PATCH] Implement MessageBox --- .../_impl/session/message_box/adaptor.py | 48 +- .../_impl/session/message_box/defs.py | 141 +++- .../_impl/session/message_box/messagebox.py | 637 ++++++++++++++++++ 3 files changed, 797 insertions(+), 29 deletions(-) create mode 100644 client/src/telethon/_impl/session/message_box/messagebox.py diff --git a/client/src/telethon/_impl/session/message_box/adaptor.py b/client/src/telethon/_impl/session/message_box/adaptor.py index 43081009..6bd585f4 100644 --- a/client/src/telethon/_impl/session/message_box/adaptor.py +++ b/client/src/telethon/_impl/session/message_box/adaptor.py @@ -1,8 +1,8 @@ -from typing import Optional, Tuple +from typing import Optional from ...tl import abcs, types from ..chat.hash_cache import ChatHashCache -from .defs import ACCOUNT_WIDE, NO_SEQ, SECRET_CHATS, Gap +from .defs import ENTRY_ACCOUNT, ENTRY_SECRET, NO_SEQ, Gap, PtsInfo def updates_(updates: types.Updates) -> types.UpdatesCombined: @@ -180,64 +180,64 @@ def message_channel_id(message: abcs.Message) -> Optional[int]: return peer.channel_id if isinstance(peer, types.PeerChannel) else None -def pts_info_from_update(update: abcs.Update) -> Optional[Tuple[int | str, int, int]]: +def pts_info_from_update(update: abcs.Update) -> Optional[PtsInfo]: if isinstance(update, types.UpdateNewMessage): assert not isinstance(message_peer(update.message), types.PeerChannel) - return ACCOUNT_WIDE, update.pts, update.pts_count + return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) elif isinstance(update, types.UpdateDeleteMessages): - return ACCOUNT_WIDE, update.pts, update.pts_count + return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) elif isinstance(update, types.UpdateNewEncryptedMessage): - return SECRET_CHATS, update.qts, 1 + return PtsInfo(ENTRY_SECRET, update.qts, 1) elif isinstance(update, types.UpdateReadHistoryInbox): assert not isinstance(update.peer, types.PeerChannel) - return ACCOUNT_WIDE, update.pts, update.pts_count + return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) elif isinstance(update, types.UpdateReadHistoryOutbox): assert not isinstance(update.peer, types.PeerChannel) - return ACCOUNT_WIDE, update.pts, update.pts_count + return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) elif isinstance(update, types.UpdateWebPage): - return ACCOUNT_WIDE, update.pts, update.pts_count + return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) elif isinstance(update, types.UpdateReadMessagesContents): - return ACCOUNT_WIDE, update.pts, update.pts_count + return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) elif isinstance(update, types.UpdateChannelTooLong): if update.pts is not None: - return update.channel_id, update.pts, 0 + return PtsInfo(update.channel_id, update.pts, 0) else: return None elif isinstance(update, types.UpdateNewChannelMessage): channel_id = message_channel_id(update.message) if channel_id is not None: - return channel_id, update.pts, update.pts_count + return PtsInfo(channel_id, update.pts, update.pts_count) else: return None elif isinstance(update, types.UpdateReadChannelInbox): - return update.channel_id, update.pts, 0 + return PtsInfo(update.channel_id, update.pts, 0) elif isinstance(update, types.UpdateDeleteChannelMessages): - return update.channel_id, update.pts, update.pts_count + return PtsInfo(update.channel_id, update.pts, update.pts_count) elif isinstance(update, types.UpdateEditChannelMessage): channel_id = message_channel_id(update.message) if channel_id is not None: - return channel_id, update.pts, update.pts_count + return PtsInfo(channel_id, update.pts, update.pts_count) else: return None elif isinstance(update, types.UpdateEditMessage): assert not isinstance(message_peer(update.message), types.PeerChannel) - return ACCOUNT_WIDE, update.pts, update.pts_count + return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) elif isinstance(update, types.UpdateChannelWebPage): - return update.channel_id, update.pts, update.pts_count + return PtsInfo(update.channel_id, update.pts, update.pts_count) elif isinstance(update, types.UpdateFolderPeers): - return ACCOUNT_WIDE, update.pts, update.pts_count + return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) elif isinstance(update, types.UpdatePinnedMessages): assert not isinstance(update.peer, types.PeerChannel) - return ACCOUNT_WIDE, update.pts, update.pts_count + return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) elif isinstance(update, types.UpdatePinnedChannelMessages): - return update.channel_id, update.pts, update.pts_count + return PtsInfo(update.channel_id, update.pts, update.pts_count) elif isinstance(update, types.UpdateChatParticipant): - return SECRET_CHATS, update.qts, 0 + return PtsInfo(ENTRY_SECRET, update.qts, 0) elif isinstance(update, types.UpdateChannelParticipant): - return SECRET_CHATS, update.qts, 0 + return PtsInfo(ENTRY_SECRET, update.qts, 0) elif isinstance(update, types.UpdateBotStopped): - return SECRET_CHATS, update.qts, 0 + return PtsInfo(ENTRY_SECRET, update.qts, 0) elif isinstance(update, types.UpdateBotChatInviteRequester): - return SECRET_CHATS, update.qts, 0 + return PtsInfo(ENTRY_SECRET, update.qts, 0) else: return None diff --git a/client/src/telethon/_impl/session/message_box/defs.py b/client/src/telethon/_impl/session/message_box/defs.py index e033a70d..fdafd26d 100644 --- a/client/src/telethon/_impl/session/message_box/defs.py +++ b/client/src/telethon/_impl/session/message_box/defs.py @@ -1,3 +1,133 @@ +import logging +from enum import Enum +from typing import List, Literal, Optional, Union + +from ...tl import abcs + + +class DataCenter: + __slots__ = ("id", "ip", "port", "auth") + + def __init__(self, *, id: int, ip: str, port: int, auth: Optional[bytes]) -> None: + self.id = id + self.ip = ip + self.port = port + self.auth = auth + + +class User: + __slots__ = ("id", "dc", "bot") + + def __init__(self, *, id: int, dc: int, bot: bool) -> None: + self.id = id + self.dc = dc + self.bot = bot + + +class ChannelState: + __slots__ = ("id", "pts") + + def __init__(self, *, id: int, pts: int) -> None: + self.id = id + self.pts = pts + + +class UpdateState: + __slots__ = ( + "pts", + "qts", + "date", + "seq", + "channels", + ) + + def __init__( + self, + *, + pts: int, + qts: int, + date: int, + seq: int, + channels: List[ChannelState], + ) -> None: + self.pts = pts + self.qts = qts + self.date = date + self.seq = seq + self.channels = channels + + +class Session: + __slots__ = ("dcs", "user", "state") + + def __init__( + self, + *, + dcs: List[DataCenter], + user: Optional[User], + state: Optional[UpdateState], + ): + self.dcs = dcs + self.user = user + self.state = state + + +class PtsInfo: + __slots__ = ("pts", "pts_count", "entry") + + def __init__(self, entry: "Entry", pts: int, pts_count: int) -> None: + self.pts = pts + self.pts_count = pts_count + self.entry = entry + + def __repr__(self) -> str: + return ( + f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})" + ) + + +class State: + __slots__ = ("pts", "deadline") + + def __init__( + self, + pts: int, + deadline: float, + ) -> None: + self.pts = pts + self.deadline = deadline + + def __repr__(self) -> str: + return f"State(pts={self.pts}, deadline={self.deadline})" + + +class PossibleGap: + __slots__ = ("deadline", "updates") + + def __init__( + self, + deadline: float, + updates: List[abcs.Update], + ) -> None: + self.deadline = deadline + self.updates = updates + + def __repr__(self) -> str: + return ( + f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})" + ) + + +class PrematureEndReason(Enum): + TEMPORARY_SERVER_ISSUES = "tmp" + BANNED = "ban" + + +class Gap(ValueError): + def __repr__(self) -> str: + return "Gap()" + + NO_SEQ = 0 NO_PTS = 0 @@ -11,9 +141,10 @@ POSSIBLE_GAP_TIMEOUT = 0.5 # https://core.telegram.org/api/updates NO_UPDATES_TIMEOUT = 15 * 60 -ACCOUNT_WIDE = "ACCOUNT" -SECRET_CHATS = "SECRET" +ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT" +ENTRY_SECRET: Literal["SECRET"] = "SECRET" +Entry = Union[Literal["ACCOUNT"], Literal["SECRET"], int] - -class Gap(ValueError): - pass +# Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET. +# We don't define a name for this as libraries shouldn't do that though. +LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2 diff --git a/client/src/telethon/_impl/session/message_box/messagebox.py b/client/src/telethon/_impl/session/message_box/messagebox.py new file mode 100644 index 00000000..a5f6ed58 --- /dev/null +++ b/client/src/telethon/_impl/session/message_box/messagebox.py @@ -0,0 +1,637 @@ +import asyncio +import datetime +import logging +import time +from typing import Dict, List, Optional, Set, Tuple + +from ...tl import abcs, functions, types +from ...tl.core.request import Request +from ..chat.hash_cache import ChatHashCache +from .adaptor import adapt, pts_info_from_update +from .defs import ( + BOT_CHANNEL_DIFF_LIMIT, + ENTRY_ACCOUNT, + ENTRY_SECRET, + LOG_LEVEL_TRACE, + NO_PTS, + NO_SEQ, + NO_UPDATES_TIMEOUT, + POSSIBLE_GAP_TIMEOUT, + USER_CHANNEL_DIFF_LIMIT, + ChannelState, + Entry, + Gap, + PossibleGap, + PrematureEndReason, + State, + UpdateState, +) + + +def next_updates_deadline() -> float: + return asyncio.get_running_loop().time() + NO_UPDATES_TIMEOUT + + +def epoch() -> datetime.datetime: + return datetime.datetime(*time.gmtime(0)[:6]).replace(tzinfo=datetime.timezone.utc) + + +# https://core.telegram.org/api/updates#message-related-event-sequences. +class MessageBox: + __slots__ = ( + "_log", + "map", + "date", + "seq", + "possible_gaps", + "getting_diff_for", + "next_deadline", + ) + + def __init__( + self, + *, + log: Optional[logging.Logger] = None, + ) -> None: + self._log = log or logging.getLogger("telethon.messagebox") + self.map: Dict[Entry, State] = {} + self.date = epoch() + self.seq = NO_SEQ + self.possible_gaps: Dict[Entry, PossibleGap] = {} + self.getting_diff_for: Set[Entry] = set() + self.next_deadline: Optional[Entry] = None + + if __debug__: + self._trace("MessageBox initialized") + + def _trace(self, msg: str, *args: object) -> None: + # Calls to trace can't really be removed beforehand without some dark magic. + # So every call to trace is prefixed with `if __debug__`` instead, to remove + # it when using `python -O`. Probably unnecessary, but it's nice to avoid + # paying the cost for something that is not used. + self._log.log( + LOG_LEVEL_TRACE, + "Current MessageBox state: seq = %r, date = %s, map = %r", + self.seq, + self.date.isoformat(), + self.map, + ) + self._log.log(LOG_LEVEL_TRACE, msg, *args) + + def load(self, state: UpdateState) -> None: + if __debug__: + self._trace( + "Loading MessageBox with state = %r", + state, + ) + + deadline = next_updates_deadline() + + self.map.clear() + if state.pts != NO_SEQ: + self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline) + if state.qts != NO_SEQ: + self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline) + self.map.update( + (s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels + ) + + self.date = datetime.datetime.fromtimestamp( + state.date, tz=datetime.timezone.utc + ) + self.seq = state.seq + self.possible_gaps.clear() + self.getting_diff_for.clear() + self.next_deadline = ENTRY_ACCOUNT + + def session_state(self) -> UpdateState: + return UpdateState( + pts=self.map[ENTRY_ACCOUNT].pts if ENTRY_ACCOUNT in self.map else NO_PTS, + qts=self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else NO_PTS, + date=int(self.date.timestamp()), + seq=self.seq, + channels=[ + ChannelState(id=int(entry), pts=state.pts) + for entry, state in self.map.items() + if entry not in (ENTRY_ACCOUNT, ENTRY_SECRET) + ], + ) + + def is_empty(self) -> bool: + return (self.map.get(ENTRY_ACCOUNT) or NO_PTS) == NO_PTS + + def check_deadlines(self) -> float: + now = asyncio.get_running_loop().time() + + if self.getting_diff_for: + return now + + default_deadline = next_updates_deadline() + + if self.possible_gaps: + deadline = min( + default_deadline, *(gap.deadline for gap in self.possible_gaps.values()) + ) + elif self.next_deadline in self.map: + deadline = min(default_deadline, self.map[self.next_deadline].deadline) + else: + deadline = default_deadline + + if now >= deadline: + self.getting_diff_for.update( + entry + for entry, gap in self.possible_gaps.items() + if now >= gap.deadline + ) + self.getting_diff_for.update( + entry for entry, state in self.map.items() if now >= state.deadline + ) + + if __debug__: + self._trace( + "Deadlines met, now getting diff for %r", self.getting_diff_for + ) + + for entry in self.getting_diff_for: + self.possible_gaps.pop(entry, None) + + return deadline + + def reset_deadlines(self, entries: Set[Entry], deadline: float) -> None: + if not entries: + return + + for entry in entries: + if entry not in self.map: + raise RuntimeError( + "Called reset_deadline on an entry for which we do not have state" + ) + self.map[entry].deadline = deadline + + if self.next_deadline in entries: + self.next_deadline = min( + self.map.items(), key=lambda entry_state: entry_state[1].deadline + )[0] + elif ( + self.next_deadline in self.map + and deadline < self.map[self.next_deadline].deadline + ): + self.next_deadline = entry + + def reset_channel_deadline(self, channel_id: int, timeout: Optional[float]) -> None: + self.reset_deadlines( + {channel_id}, + asyncio.get_running_loop().time() + (timeout or NO_UPDATES_TIMEOUT), + ) + + def set_state(self, state: abcs.updates.State) -> None: + if __debug__: + self._trace("Setting state %s", state) + + deadline = next_updates_deadline() + assert isinstance(state, types.updates.State) + self.map[ENTRY_ACCOUNT] = State(state.pts, deadline) + self.map[ENTRY_SECRET] = State(state.qts, deadline) + self.date = datetime.datetime.fromtimestamp( + state.date, tz=datetime.timezone.utc + ) + self.seq = state.seq + + def try_set_channel_state(self, id: int, pts: int) -> None: + if __debug__: + self._trace("Trying to set channel state for %r: %r", id, pts) + + if id not in self.map: + self.map[id] = State(pts=pts, deadline=next_updates_deadline()) + + def try_begin_get_diff(self, entry: Entry, reason: str) -> None: + if entry not in self.map: + if entry in self.possible_gaps: + raise RuntimeError( + "Should not have a possible_gap for an entry not in the state map" + ) + return + + if __debug__: + self._trace("Marking %r as needing difference because %s", entry, reason) + self.getting_diff_for.add(entry) + self.possible_gaps.pop(entry, None) + + def end_get_diff(self, entry: Entry) -> None: + try: + self.getting_diff_for.remove(entry) + except KeyError: + raise RuntimeError( + "Called end_get_diff on an entry which was not getting diff for" + ) + + self.reset_deadlines({entry}, next_updates_deadline()) + assert ( + entry not in self.possible_gaps + ), "gaps shouldn't be created while getting difference" + + def ensure_known_peer_hashes( + self, + updates: abcs.Updates, + chat_hashes: ChatHashCache, + ) -> None: + if not chat_hashes.extend_from_updates(updates): + can_recover = ( + not isinstance(updates, types.UpdateShort) + or pts_info_from_update(updates.update) is not None + ) + if can_recover: + raise Gap + + # https://core.telegram.org/api/updates + def process_updates( + self, + updates: abcs.Updates, + chat_hashes: ChatHashCache, + ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: + result: List[abcs.Update] = [] + combined = adapt(updates, chat_hashes) + + if __debug__: + self._trace( + "Processing updates with seq = %r, seq_start = %r, date = %r: %s", + combined.seq, + combined.seq_start, + combined.date, + updates, + ) + + if combined.seq_start != NO_SEQ: + if self.seq + 1 > combined.seq_start: + if __debug__: + self._trace( + "Skipping updates as they should have already been handled" + ) + return result, combined.users, combined.chats + elif self.seq + 1 < combined.seq_start: + self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap") + raise Gap + + def update_sort_key(update: abcs.Update) -> int: + pts = pts_info_from_update(update) + return pts.pts - pts.pts_count if pts else 0 + + sorted_updates = list(sorted(combined.updates, key=update_sort_key)) + + any_pts_applied = False + reset_deadlines_for = set() + for update in sorted_updates: + entry, applied = self.apply_pts_info(update) + if entry is not None: + reset_deadlines_for.add(entry) + if applied is not None: + result.append(applied) + any_pts_applied |= entry is not None + + self.reset_deadlines(reset_deadlines_for, next_updates_deadline()) + + if any_pts_applied: + if __debug__: + self._trace("Updating seq as local pts was updated too") + self.date = datetime.datetime.fromtimestamp( + combined.date, tz=datetime.timezone.utc + ) + if combined.seq != NO_SEQ: + self.seq = combined.seq + + if self.possible_gaps: + if __debug__: + self._trace( + "Trying to re-apply %r possible gaps", len(self.possible_gaps) + ) + + for key in list(self.possible_gaps.keys()): + self.possible_gaps[key].updates.sort(key=update_sort_key) + + for _ in range(len(self.possible_gaps[key].updates)): + update = self.possible_gaps[key].updates.pop(0) + _, applied = self.apply_pts_info(update) + if applied is not None: + result.append(applied) + if __debug__: + self._trace( + "Resolved gap with %r: %s", + pts_info_from_update(applied), + applied, + ) + + self.possible_gaps = { + entry: gap for entry, gap in self.possible_gaps.items() if gap.updates + } + + return result, combined.users, combined.chats + + def apply_pts_info( + self, + update: abcs.Update, + ) -> Tuple[Optional[Entry], Optional[abcs.Update]]: + if isinstance(update, types.UpdateChannelTooLong): + self.try_begin_get_diff(update.channel_id, "received updateChannelTooLong") + return None, None + + pts = pts_info_from_update(update) + if not pts: + if __debug__: + self._trace( + "No pts in update, so it can be applied in any order: %s", update + ) + return None, update + + if pts.entry in self.getting_diff_for: + if __debug__: + self._trace( + "Skipping update with %r as its difference is being fetched", pts + ) + return pts.entry, None + + if state := self.map.get(pts.entry): + local_pts = state.pts + if local_pts + pts.pts_count > pts.pts: + if __debug__: + self._trace( + "Skipping update since local pts %r > %r: %s", + local_pts, + pts, + update, + ) + return pts.entry, None + elif local_pts + pts.pts_count < pts.pts: + # TODO store chats too? + if __debug__: + self._trace( + "Possible gap since local pts %r < %r: %s", + local_pts, + pts, + update, + ) + if pts.entry not in self.possible_gaps: + self.possible_gaps[pts.entry] = PossibleGap( + deadline=asyncio.get_running_loop().time() + + POSSIBLE_GAP_TIMEOUT, + updates=[], + ) + + self.possible_gaps[pts.entry].updates.append(update) + return pts.entry, None + else: + if __debug__: + self._trace( + "Applying update pts since local pts %r = %r: %s", + local_pts, + pts, + update, + ) + + if pts.entry not in self.map: + self.map[pts.entry] = State( + pts=0, + deadline=next_updates_deadline(), + ) + self.map[pts.entry].pts = pts.pts + + return pts.entry, update + + def get_difference(self) -> Optional[Request[abcs.updates.Difference]]: + for entry in (ENTRY_ACCOUNT, ENTRY_SECRET): + if entry in self.getting_diff_for: + if entry not in self.map: + raise RuntimeError( + "Should not try to get difference for an entry without known state" + ) + + gd = functions.updates.get_difference( + pts=self.map[ENTRY_ACCOUNT].pts, + pts_total_limit=None, + date=int(self.date.timestamp()), + qts=self.map[ENTRY_SECRET].pts + if ENTRY_SECRET in self.map + else NO_SEQ, + ) + if __debug__: + self._trace("Requesting account difference %s", gd) + return gd + + return None + + def apply_difference( + self, + diff: abcs.updates.Difference, + chat_hashes: ChatHashCache, + ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: + if __debug__: + self._trace("Applying account difference %s", diff) + + finish: bool + result: Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]] + if isinstance(diff, types.updates.DifferenceEmpty): + finish = True + self.date = datetime.datetime.fromtimestamp( + diff.date, tz=datetime.timezone.utc + ) + self.seq = diff.seq + result = [], [], [] + elif isinstance(diff, types.updates.Difference): + chat_hashes.extend(diff.users, diff.chats) + finish = True + result = self.apply_difference_type(diff, chat_hashes) + elif isinstance(diff, types.updates.DifferenceSlice): + chat_hashes.extend(diff.users, diff.chats) + finish = False + result = self.apply_difference_type( + types.updates.Difference( + new_messages=diff.new_messages, + new_encrypted_messages=diff.new_encrypted_messages, + other_updates=diff.other_updates, + chats=diff.chats, + users=diff.users, + state=diff.intermediate_state, + ), + chat_hashes, + ) + elif isinstance(diff, types.updates.DifferenceTooLong): + finish = True + self.map[ENTRY_ACCOUNT].pts = diff.pts + result = [], [], [] + else: + raise RuntimeError("unexpected case") + + if finish: + account = ENTRY_ACCOUNT in self.getting_diff_for + secret = ENTRY_SECRET in self.getting_diff_for + + if not account and not secret: + raise RuntimeError( + "Should not be applying the difference when neither account or secret was diff was active" + ) + + if account: + self.end_get_diff(ENTRY_ACCOUNT) + if secret: + self.end_get_diff(ENTRY_SECRET) + + return result + + def apply_difference_type( + self, + diff: types.updates.Difference, + chat_hashes: ChatHashCache, + ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: + state = diff.state + assert isinstance(state, types.updates.State) + self.map[ENTRY_ACCOUNT].pts = state.pts + self.map[ENTRY_SECRET].pts = state.qts + self.date = datetime.datetime.fromtimestamp( + state.date, tz=datetime.timezone.utc + ) + self.seq = state.seq + + updates, users, chats = self.process_updates( + types.Updates( + updates=diff.other_updates, + users=diff.users, + chats=diff.chats, + date=int(epoch().timestamp()), + seq=NO_SEQ, + ), + chat_hashes, + ) + + updates.extend( + types.UpdateNewMessage( + message=m, + pts=NO_PTS, + pts_count=0, + ) + for m in diff.new_messages + ) + updates.extend( + types.UpdateNewEncryptedMessage( + message=m, + qts=NO_PTS, + ) + for m in diff.new_encrypted_messages + ) + + return updates, users, chats + + def get_channel_difference( + self, + chat_hashes: ChatHashCache, + ) -> Optional[Request[abcs.updates.ChannelDifference]]: + for entry in self.getting_diff_for: + if entry not in (ENTRY_ACCOUNT, ENTRY_SECRET): + id = int(entry) + break + else: + return None + + packed = chat_hashes.get(id) + if packed: + assert packed.access_hash is not None + channel = types.InputChannel( + channel_id=packed.id, + access_hash=packed.access_hash, + ) + if state := self.map.get(entry): + gd = functions.updates.get_channel_difference( + force=False, + channel=channel, + filter=types.ChannelMessagesFilterEmpty(), + pts=state.pts, + limit=BOT_CHANNEL_DIFF_LIMIT + if chat_hashes.is_self_bot + else USER_CHANNEL_DIFF_LIMIT, + ) + if __debug__: + self._trace("Requesting channel difference %s", gd) + return gd + else: + raise RuntimeError( + "Should not try to get difference for an entry without known state" + ) + else: + self.end_get_diff(entry) + self.map.pop(entry, None) + return None + + def apply_channel_difference( + self, + channel_id: int, + diff: abcs.updates.ChannelDifference, + chat_hashes: ChatHashCache, + ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: + entry: Entry = channel_id + if __debug__: + self._trace("Applying channel difference for %r: %s", entry, diff) + + self.possible_gaps.pop(entry, None) + + if isinstance(diff, types.updates.ChannelDifferenceEmpty): + assert diff.final + self.end_get_diff(entry) + self.map[entry].pts = diff.pts + return [], [], [] + elif isinstance(diff, types.updates.ChannelDifferenceTooLong): + chat_hashes.extend(diff.users, diff.chats) + + assert diff.final + if isinstance(diff.dialog, types.Dialog): + assert diff.dialog.pts is not None + self.map[entry].pts = diff.dialog.pts + else: + raise RuntimeError("unexpected type on ChannelDifferenceTooLong") + self.reset_channel_deadline(channel_id, diff.timeout) + return [], [], [] + elif isinstance(diff, types.updates.ChannelDifference): + chat_hashes.extend(diff.users, diff.chats) + + if diff.final: + self.end_get_diff(entry) + + self.map[entry].pts = diff.pts + updates, users, chats = self.process_updates( + types.Updates( + updates=diff.other_updates, + users=diff.users, + chats=diff.chats, + date=int(epoch().timestamp()), + seq=NO_SEQ, + ), + chat_hashes, + ) + + updates.extend( + types.UpdateNewChannelMessage( + message=m, + pts=NO_PTS, + pts_count=0, + ) + for m in diff.new_messages + ) + self.reset_channel_deadline(channel_id, None) + + return updates, users, chats + else: + raise RuntimeError("unexpected case") + + def end_channel_difference( + self, channel_id: int, reason: PrematureEndReason + ) -> None: + entry: Entry = channel_id + if __debug__: + self._trace("Ending channel difference for %r because %s", entry, reason) + + if reason == PrematureEndReason.TEMPORARY_SERVER_ISSUES: + self.possible_gaps.pop(entry, None) + self.end_get_diff(entry) + elif reason == PrematureEndReason.BANNED: + self.possible_gaps.pop(entry, None) + self.end_get_diff(entry) + del self.map[entry] + else: + raise RuntimeError("Unknown reason to end channel difference")