mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-11-01 00:17:47 +03:00 
			
		
		
		
	Implement MessageBox
This commit is contained in:
		
							parent
							
								
									c77c10b48f
								
							
						
					
					
						commit
						69d7941852
					
				|  | @ -1,8 +1,8 @@ | |||
| from typing import Optional, Tuple | ||||
| from typing import Optional | ||||
| 
 | ||||
| from ...tl import abcs, types | ||||
| from ..chat.hash_cache import ChatHashCache | ||||
| from .defs import ACCOUNT_WIDE, NO_SEQ, SECRET_CHATS, Gap | ||||
| from .defs import ENTRY_ACCOUNT, ENTRY_SECRET, NO_SEQ, Gap, PtsInfo | ||||
| 
 | ||||
| 
 | ||||
| def updates_(updates: types.Updates) -> types.UpdatesCombined: | ||||
|  | @ -180,64 +180,64 @@ def message_channel_id(message: abcs.Message) -> Optional[int]: | |||
|     return peer.channel_id if isinstance(peer, types.PeerChannel) else None | ||||
| 
 | ||||
| 
 | ||||
| def pts_info_from_update(update: abcs.Update) -> Optional[Tuple[int | str, int, int]]: | ||||
| def pts_info_from_update(update: abcs.Update) -> Optional[PtsInfo]: | ||||
|     if isinstance(update, types.UpdateNewMessage): | ||||
|         assert not isinstance(message_peer(update.message), types.PeerChannel) | ||||
|         return ACCOUNT_WIDE, update.pts, update.pts_count | ||||
|         return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdateDeleteMessages): | ||||
|         return ACCOUNT_WIDE, update.pts, update.pts_count | ||||
|         return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdateNewEncryptedMessage): | ||||
|         return SECRET_CHATS, update.qts, 1 | ||||
|         return PtsInfo(ENTRY_SECRET, update.qts, 1) | ||||
|     elif isinstance(update, types.UpdateReadHistoryInbox): | ||||
|         assert not isinstance(update.peer, types.PeerChannel) | ||||
|         return ACCOUNT_WIDE, update.pts, update.pts_count | ||||
|         return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdateReadHistoryOutbox): | ||||
|         assert not isinstance(update.peer, types.PeerChannel) | ||||
|         return ACCOUNT_WIDE, update.pts, update.pts_count | ||||
|         return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdateWebPage): | ||||
|         return ACCOUNT_WIDE, update.pts, update.pts_count | ||||
|         return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdateReadMessagesContents): | ||||
|         return ACCOUNT_WIDE, update.pts, update.pts_count | ||||
|         return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdateChannelTooLong): | ||||
|         if update.pts is not None: | ||||
|             return update.channel_id, update.pts, 0 | ||||
|             return PtsInfo(update.channel_id, update.pts, 0) | ||||
|         else: | ||||
|             return None | ||||
|     elif isinstance(update, types.UpdateNewChannelMessage): | ||||
|         channel_id = message_channel_id(update.message) | ||||
|         if channel_id is not None: | ||||
|             return channel_id, update.pts, update.pts_count | ||||
|             return PtsInfo(channel_id, update.pts, update.pts_count) | ||||
|         else: | ||||
|             return None | ||||
|     elif isinstance(update, types.UpdateReadChannelInbox): | ||||
|         return update.channel_id, update.pts, 0 | ||||
|         return PtsInfo(update.channel_id, update.pts, 0) | ||||
|     elif isinstance(update, types.UpdateDeleteChannelMessages): | ||||
|         return update.channel_id, update.pts, update.pts_count | ||||
|         return PtsInfo(update.channel_id, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdateEditChannelMessage): | ||||
|         channel_id = message_channel_id(update.message) | ||||
|         if channel_id is not None: | ||||
|             return channel_id, update.pts, update.pts_count | ||||
|             return PtsInfo(channel_id, update.pts, update.pts_count) | ||||
|         else: | ||||
|             return None | ||||
|     elif isinstance(update, types.UpdateEditMessage): | ||||
|         assert not isinstance(message_peer(update.message), types.PeerChannel) | ||||
|         return ACCOUNT_WIDE, update.pts, update.pts_count | ||||
|         return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdateChannelWebPage): | ||||
|         return update.channel_id, update.pts, update.pts_count | ||||
|         return PtsInfo(update.channel_id, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdateFolderPeers): | ||||
|         return ACCOUNT_WIDE, update.pts, update.pts_count | ||||
|         return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdatePinnedMessages): | ||||
|         assert not isinstance(update.peer, types.PeerChannel) | ||||
|         return ACCOUNT_WIDE, update.pts, update.pts_count | ||||
|         return PtsInfo(ENTRY_ACCOUNT, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdatePinnedChannelMessages): | ||||
|         return update.channel_id, update.pts, update.pts_count | ||||
|         return PtsInfo(update.channel_id, update.pts, update.pts_count) | ||||
|     elif isinstance(update, types.UpdateChatParticipant): | ||||
|         return SECRET_CHATS, update.qts, 0 | ||||
|         return PtsInfo(ENTRY_SECRET, update.qts, 0) | ||||
|     elif isinstance(update, types.UpdateChannelParticipant): | ||||
|         return SECRET_CHATS, update.qts, 0 | ||||
|         return PtsInfo(ENTRY_SECRET, update.qts, 0) | ||||
|     elif isinstance(update, types.UpdateBotStopped): | ||||
|         return SECRET_CHATS, update.qts, 0 | ||||
|         return PtsInfo(ENTRY_SECRET, update.qts, 0) | ||||
|     elif isinstance(update, types.UpdateBotChatInviteRequester): | ||||
|         return SECRET_CHATS, update.qts, 0 | ||||
|         return PtsInfo(ENTRY_SECRET, update.qts, 0) | ||||
|     else: | ||||
|         return None | ||||
|  |  | |||
|  | @ -1,3 +1,133 @@ | |||
| import logging | ||||
| from enum import Enum | ||||
| from typing import List, Literal, Optional, Union | ||||
| 
 | ||||
| from ...tl import abcs | ||||
| 
 | ||||
| 
 | ||||
| class DataCenter: | ||||
|     __slots__ = ("id", "ip", "port", "auth") | ||||
| 
 | ||||
|     def __init__(self, *, id: int, ip: str, port: int, auth: Optional[bytes]) -> None: | ||||
|         self.id = id | ||||
|         self.ip = ip | ||||
|         self.port = port | ||||
|         self.auth = auth | ||||
| 
 | ||||
| 
 | ||||
| class User: | ||||
|     __slots__ = ("id", "dc", "bot") | ||||
| 
 | ||||
|     def __init__(self, *, id: int, dc: int, bot: bool) -> None: | ||||
|         self.id = id | ||||
|         self.dc = dc | ||||
|         self.bot = bot | ||||
| 
 | ||||
| 
 | ||||
| class ChannelState: | ||||
|     __slots__ = ("id", "pts") | ||||
| 
 | ||||
|     def __init__(self, *, id: int, pts: int) -> None: | ||||
|         self.id = id | ||||
|         self.pts = pts | ||||
| 
 | ||||
| 
 | ||||
| class UpdateState: | ||||
|     __slots__ = ( | ||||
|         "pts", | ||||
|         "qts", | ||||
|         "date", | ||||
|         "seq", | ||||
|         "channels", | ||||
|     ) | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         *, | ||||
|         pts: int, | ||||
|         qts: int, | ||||
|         date: int, | ||||
|         seq: int, | ||||
|         channels: List[ChannelState], | ||||
|     ) -> None: | ||||
|         self.pts = pts | ||||
|         self.qts = qts | ||||
|         self.date = date | ||||
|         self.seq = seq | ||||
|         self.channels = channels | ||||
| 
 | ||||
| 
 | ||||
| class Session: | ||||
|     __slots__ = ("dcs", "user", "state") | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         *, | ||||
|         dcs: List[DataCenter], | ||||
|         user: Optional[User], | ||||
|         state: Optional[UpdateState], | ||||
|     ): | ||||
|         self.dcs = dcs | ||||
|         self.user = user | ||||
|         self.state = state | ||||
| 
 | ||||
| 
 | ||||
| class PtsInfo: | ||||
|     __slots__ = ("pts", "pts_count", "entry") | ||||
| 
 | ||||
|     def __init__(self, entry: "Entry", pts: int, pts_count: int) -> None: | ||||
|         self.pts = pts | ||||
|         self.pts_count = pts_count | ||||
|         self.entry = entry | ||||
| 
 | ||||
|     def __repr__(self) -> str: | ||||
|         return ( | ||||
|             f"PtsInfo(pts={self.pts}, pts_count={self.pts_count}, entry={self.entry})" | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class State: | ||||
|     __slots__ = ("pts", "deadline") | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         pts: int, | ||||
|         deadline: float, | ||||
|     ) -> None: | ||||
|         self.pts = pts | ||||
|         self.deadline = deadline | ||||
| 
 | ||||
|     def __repr__(self) -> str: | ||||
|         return f"State(pts={self.pts}, deadline={self.deadline})" | ||||
| 
 | ||||
| 
 | ||||
| class PossibleGap: | ||||
|     __slots__ = ("deadline", "updates") | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         deadline: float, | ||||
|         updates: List[abcs.Update], | ||||
|     ) -> None: | ||||
|         self.deadline = deadline | ||||
|         self.updates = updates | ||||
| 
 | ||||
|     def __repr__(self) -> str: | ||||
|         return ( | ||||
|             f"PossibleGap(deadline={self.deadline}, update_count={len(self.updates)})" | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class PrematureEndReason(Enum): | ||||
|     TEMPORARY_SERVER_ISSUES = "tmp" | ||||
|     BANNED = "ban" | ||||
| 
 | ||||
| 
 | ||||
| class Gap(ValueError): | ||||
|     def __repr__(self) -> str: | ||||
|         return "Gap()" | ||||
| 
 | ||||
| 
 | ||||
| NO_SEQ = 0 | ||||
| 
 | ||||
| NO_PTS = 0 | ||||
|  | @ -11,9 +141,10 @@ POSSIBLE_GAP_TIMEOUT = 0.5 | |||
| # https://core.telegram.org/api/updates | ||||
| NO_UPDATES_TIMEOUT = 15 * 60 | ||||
| 
 | ||||
| ACCOUNT_WIDE = "ACCOUNT" | ||||
| SECRET_CHATS = "SECRET" | ||||
| ENTRY_ACCOUNT: Literal["ACCOUNT"] = "ACCOUNT" | ||||
| ENTRY_SECRET: Literal["SECRET"] = "SECRET" | ||||
| Entry = Union[Literal["ACCOUNT"], Literal["SECRET"], int] | ||||
| 
 | ||||
| 
 | ||||
| class Gap(ValueError): | ||||
|     pass | ||||
| # Python's logging doesn't define a TRACE level. Pick halfway between DEBUG and NOTSET. | ||||
| # We don't define a name for this as libraries shouldn't do that though. | ||||
| LOG_LEVEL_TRACE = (logging.DEBUG - logging.NOTSET) // 2 | ||||
|  |  | |||
							
								
								
									
										637
									
								
								client/src/telethon/_impl/session/message_box/messagebox.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										637
									
								
								client/src/telethon/_impl/session/message_box/messagebox.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,637 @@ | |||
| import asyncio | ||||
| import datetime | ||||
| import logging | ||||
| import time | ||||
| from typing import Dict, List, Optional, Set, Tuple | ||||
| 
 | ||||
| from ...tl import abcs, functions, types | ||||
| from ...tl.core.request import Request | ||||
| from ..chat.hash_cache import ChatHashCache | ||||
| from .adaptor import adapt, pts_info_from_update | ||||
| from .defs import ( | ||||
|     BOT_CHANNEL_DIFF_LIMIT, | ||||
|     ENTRY_ACCOUNT, | ||||
|     ENTRY_SECRET, | ||||
|     LOG_LEVEL_TRACE, | ||||
|     NO_PTS, | ||||
|     NO_SEQ, | ||||
|     NO_UPDATES_TIMEOUT, | ||||
|     POSSIBLE_GAP_TIMEOUT, | ||||
|     USER_CHANNEL_DIFF_LIMIT, | ||||
|     ChannelState, | ||||
|     Entry, | ||||
|     Gap, | ||||
|     PossibleGap, | ||||
|     PrematureEndReason, | ||||
|     State, | ||||
|     UpdateState, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| def next_updates_deadline() -> float: | ||||
|     return asyncio.get_running_loop().time() + NO_UPDATES_TIMEOUT | ||||
| 
 | ||||
| 
 | ||||
| def epoch() -> datetime.datetime: | ||||
|     return datetime.datetime(*time.gmtime(0)[:6]).replace(tzinfo=datetime.timezone.utc) | ||||
| 
 | ||||
| 
 | ||||
| # https://core.telegram.org/api/updates#message-related-event-sequences. | ||||
| class MessageBox: | ||||
|     __slots__ = ( | ||||
|         "_log", | ||||
|         "map", | ||||
|         "date", | ||||
|         "seq", | ||||
|         "possible_gaps", | ||||
|         "getting_diff_for", | ||||
|         "next_deadline", | ||||
|     ) | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         *, | ||||
|         log: Optional[logging.Logger] = None, | ||||
|     ) -> None: | ||||
|         self._log = log or logging.getLogger("telethon.messagebox") | ||||
|         self.map: Dict[Entry, State] = {} | ||||
|         self.date = epoch() | ||||
|         self.seq = NO_SEQ | ||||
|         self.possible_gaps: Dict[Entry, PossibleGap] = {} | ||||
|         self.getting_diff_for: Set[Entry] = set() | ||||
|         self.next_deadline: Optional[Entry] = None | ||||
| 
 | ||||
|         if __debug__: | ||||
|             self._trace("MessageBox initialized") | ||||
| 
 | ||||
|     def _trace(self, msg: str, *args: object) -> None: | ||||
|         # Calls to trace can't really be removed beforehand without some dark magic. | ||||
|         # So every call to trace is prefixed with `if __debug__`` instead, to remove | ||||
|         # it when using `python -O`. Probably unnecessary, but it's nice to avoid | ||||
|         # paying the cost for something that is not used. | ||||
|         self._log.log( | ||||
|             LOG_LEVEL_TRACE, | ||||
|             "Current MessageBox state: seq = %r, date = %s, map = %r", | ||||
|             self.seq, | ||||
|             self.date.isoformat(), | ||||
|             self.map, | ||||
|         ) | ||||
|         self._log.log(LOG_LEVEL_TRACE, msg, *args) | ||||
| 
 | ||||
|     def load(self, state: UpdateState) -> None: | ||||
|         if __debug__: | ||||
|             self._trace( | ||||
|                 "Loading MessageBox with state = %r", | ||||
|                 state, | ||||
|             ) | ||||
| 
 | ||||
|         deadline = next_updates_deadline() | ||||
| 
 | ||||
|         self.map.clear() | ||||
|         if state.pts != NO_SEQ: | ||||
|             self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline) | ||||
|         if state.qts != NO_SEQ: | ||||
|             self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline) | ||||
|         self.map.update( | ||||
|             (s.id, State(pts=s.pts, deadline=deadline)) for s in state.channels | ||||
|         ) | ||||
| 
 | ||||
|         self.date = datetime.datetime.fromtimestamp( | ||||
|             state.date, tz=datetime.timezone.utc | ||||
|         ) | ||||
|         self.seq = state.seq | ||||
|         self.possible_gaps.clear() | ||||
|         self.getting_diff_for.clear() | ||||
|         self.next_deadline = ENTRY_ACCOUNT | ||||
| 
 | ||||
|     def session_state(self) -> UpdateState: | ||||
|         return UpdateState( | ||||
|             pts=self.map[ENTRY_ACCOUNT].pts if ENTRY_ACCOUNT in self.map else NO_PTS, | ||||
|             qts=self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else NO_PTS, | ||||
|             date=int(self.date.timestamp()), | ||||
|             seq=self.seq, | ||||
|             channels=[ | ||||
|                 ChannelState(id=int(entry), pts=state.pts) | ||||
|                 for entry, state in self.map.items() | ||||
|                 if entry not in (ENTRY_ACCOUNT, ENTRY_SECRET) | ||||
|             ], | ||||
|         ) | ||||
| 
 | ||||
|     def is_empty(self) -> bool: | ||||
|         return (self.map.get(ENTRY_ACCOUNT) or NO_PTS) == NO_PTS | ||||
| 
 | ||||
|     def check_deadlines(self) -> float: | ||||
|         now = asyncio.get_running_loop().time() | ||||
| 
 | ||||
|         if self.getting_diff_for: | ||||
|             return now | ||||
| 
 | ||||
|         default_deadline = next_updates_deadline() | ||||
| 
 | ||||
|         if self.possible_gaps: | ||||
|             deadline = min( | ||||
|                 default_deadline, *(gap.deadline for gap in self.possible_gaps.values()) | ||||
|             ) | ||||
|         elif self.next_deadline in self.map: | ||||
|             deadline = min(default_deadline, self.map[self.next_deadline].deadline) | ||||
|         else: | ||||
|             deadline = default_deadline | ||||
| 
 | ||||
|         if now >= deadline: | ||||
|             self.getting_diff_for.update( | ||||
|                 entry | ||||
|                 for entry, gap in self.possible_gaps.items() | ||||
|                 if now >= gap.deadline | ||||
|             ) | ||||
|             self.getting_diff_for.update( | ||||
|                 entry for entry, state in self.map.items() if now >= state.deadline | ||||
|             ) | ||||
| 
 | ||||
|             if __debug__: | ||||
|                 self._trace( | ||||
|                     "Deadlines met, now getting diff for %r", self.getting_diff_for | ||||
|                 ) | ||||
| 
 | ||||
|             for entry in self.getting_diff_for: | ||||
|                 self.possible_gaps.pop(entry, None) | ||||
| 
 | ||||
|         return deadline | ||||
| 
 | ||||
|     def reset_deadlines(self, entries: Set[Entry], deadline: float) -> None: | ||||
|         if not entries: | ||||
|             return | ||||
| 
 | ||||
|         for entry in entries: | ||||
|             if entry not in self.map: | ||||
|                 raise RuntimeError( | ||||
|                     "Called reset_deadline on an entry for which we do not have state" | ||||
|                 ) | ||||
|             self.map[entry].deadline = deadline | ||||
| 
 | ||||
|         if self.next_deadline in entries: | ||||
|             self.next_deadline = min( | ||||
|                 self.map.items(), key=lambda entry_state: entry_state[1].deadline | ||||
|             )[0] | ||||
|         elif ( | ||||
|             self.next_deadline in self.map | ||||
|             and deadline < self.map[self.next_deadline].deadline | ||||
|         ): | ||||
|             self.next_deadline = entry | ||||
| 
 | ||||
|     def reset_channel_deadline(self, channel_id: int, timeout: Optional[float]) -> None: | ||||
|         self.reset_deadlines( | ||||
|             {channel_id}, | ||||
|             asyncio.get_running_loop().time() + (timeout or NO_UPDATES_TIMEOUT), | ||||
|         ) | ||||
| 
 | ||||
|     def set_state(self, state: abcs.updates.State) -> None: | ||||
|         if __debug__: | ||||
|             self._trace("Setting state %s", state) | ||||
| 
 | ||||
|         deadline = next_updates_deadline() | ||||
|         assert isinstance(state, types.updates.State) | ||||
|         self.map[ENTRY_ACCOUNT] = State(state.pts, deadline) | ||||
|         self.map[ENTRY_SECRET] = State(state.qts, deadline) | ||||
|         self.date = datetime.datetime.fromtimestamp( | ||||
|             state.date, tz=datetime.timezone.utc | ||||
|         ) | ||||
|         self.seq = state.seq | ||||
| 
 | ||||
|     def try_set_channel_state(self, id: int, pts: int) -> None: | ||||
|         if __debug__: | ||||
|             self._trace("Trying to set channel state for %r: %r", id, pts) | ||||
| 
 | ||||
|         if id not in self.map: | ||||
|             self.map[id] = State(pts=pts, deadline=next_updates_deadline()) | ||||
| 
 | ||||
|     def try_begin_get_diff(self, entry: Entry, reason: str) -> None: | ||||
|         if entry not in self.map: | ||||
|             if entry in self.possible_gaps: | ||||
|                 raise RuntimeError( | ||||
|                     "Should not have a possible_gap for an entry not in the state map" | ||||
|                 ) | ||||
|             return | ||||
| 
 | ||||
|         if __debug__: | ||||
|             self._trace("Marking %r as needing difference because %s", entry, reason) | ||||
|         self.getting_diff_for.add(entry) | ||||
|         self.possible_gaps.pop(entry, None) | ||||
| 
 | ||||
|     def end_get_diff(self, entry: Entry) -> None: | ||||
|         try: | ||||
|             self.getting_diff_for.remove(entry) | ||||
|         except KeyError: | ||||
|             raise RuntimeError( | ||||
|                 "Called end_get_diff on an entry which was not getting diff for" | ||||
|             ) | ||||
| 
 | ||||
|         self.reset_deadlines({entry}, next_updates_deadline()) | ||||
|         assert ( | ||||
|             entry not in self.possible_gaps | ||||
|         ), "gaps shouldn't be created while getting difference" | ||||
| 
 | ||||
|     def ensure_known_peer_hashes( | ||||
|         self, | ||||
|         updates: abcs.Updates, | ||||
|         chat_hashes: ChatHashCache, | ||||
|     ) -> None: | ||||
|         if not chat_hashes.extend_from_updates(updates): | ||||
|             can_recover = ( | ||||
|                 not isinstance(updates, types.UpdateShort) | ||||
|                 or pts_info_from_update(updates.update) is not None | ||||
|             ) | ||||
|             if can_recover: | ||||
|                 raise Gap | ||||
| 
 | ||||
|     # https://core.telegram.org/api/updates | ||||
|     def process_updates( | ||||
|         self, | ||||
|         updates: abcs.Updates, | ||||
|         chat_hashes: ChatHashCache, | ||||
|     ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: | ||||
|         result: List[abcs.Update] = [] | ||||
|         combined = adapt(updates, chat_hashes) | ||||
| 
 | ||||
|         if __debug__: | ||||
|             self._trace( | ||||
|                 "Processing updates with seq = %r, seq_start = %r, date = %r: %s", | ||||
|                 combined.seq, | ||||
|                 combined.seq_start, | ||||
|                 combined.date, | ||||
|                 updates, | ||||
|             ) | ||||
| 
 | ||||
|         if combined.seq_start != NO_SEQ: | ||||
|             if self.seq + 1 > combined.seq_start: | ||||
|                 if __debug__: | ||||
|                     self._trace( | ||||
|                         "Skipping updates as they should have already been handled" | ||||
|                     ) | ||||
|                 return result, combined.users, combined.chats | ||||
|             elif self.seq + 1 < combined.seq_start: | ||||
|                 self.try_begin_get_diff(ENTRY_ACCOUNT, "detected gap") | ||||
|                 raise Gap | ||||
| 
 | ||||
|         def update_sort_key(update: abcs.Update) -> int: | ||||
|             pts = pts_info_from_update(update) | ||||
|             return pts.pts - pts.pts_count if pts else 0 | ||||
| 
 | ||||
|         sorted_updates = list(sorted(combined.updates, key=update_sort_key)) | ||||
| 
 | ||||
|         any_pts_applied = False | ||||
|         reset_deadlines_for = set() | ||||
|         for update in sorted_updates: | ||||
|             entry, applied = self.apply_pts_info(update) | ||||
|             if entry is not None: | ||||
|                 reset_deadlines_for.add(entry) | ||||
|             if applied is not None: | ||||
|                 result.append(applied) | ||||
|                 any_pts_applied |= entry is not None | ||||
| 
 | ||||
|         self.reset_deadlines(reset_deadlines_for, next_updates_deadline()) | ||||
| 
 | ||||
|         if any_pts_applied: | ||||
|             if __debug__: | ||||
|                 self._trace("Updating seq as local pts was updated too") | ||||
|             self.date = datetime.datetime.fromtimestamp( | ||||
|                 combined.date, tz=datetime.timezone.utc | ||||
|             ) | ||||
|             if combined.seq != NO_SEQ: | ||||
|                 self.seq = combined.seq | ||||
| 
 | ||||
|         if self.possible_gaps: | ||||
|             if __debug__: | ||||
|                 self._trace( | ||||
|                     "Trying to re-apply %r possible gaps", len(self.possible_gaps) | ||||
|                 ) | ||||
| 
 | ||||
|             for key in list(self.possible_gaps.keys()): | ||||
|                 self.possible_gaps[key].updates.sort(key=update_sort_key) | ||||
| 
 | ||||
|                 for _ in range(len(self.possible_gaps[key].updates)): | ||||
|                     update = self.possible_gaps[key].updates.pop(0) | ||||
|                     _, applied = self.apply_pts_info(update) | ||||
|                     if applied is not None: | ||||
|                         result.append(applied) | ||||
|                         if __debug__: | ||||
|                             self._trace( | ||||
|                                 "Resolved gap with %r: %s", | ||||
|                                 pts_info_from_update(applied), | ||||
|                                 applied, | ||||
|                             ) | ||||
| 
 | ||||
|             self.possible_gaps = { | ||||
|                 entry: gap for entry, gap in self.possible_gaps.items() if gap.updates | ||||
|             } | ||||
| 
 | ||||
|         return result, combined.users, combined.chats | ||||
| 
 | ||||
|     def apply_pts_info( | ||||
|         self, | ||||
|         update: abcs.Update, | ||||
|     ) -> Tuple[Optional[Entry], Optional[abcs.Update]]: | ||||
|         if isinstance(update, types.UpdateChannelTooLong): | ||||
|             self.try_begin_get_diff(update.channel_id, "received updateChannelTooLong") | ||||
|             return None, None | ||||
| 
 | ||||
|         pts = pts_info_from_update(update) | ||||
|         if not pts: | ||||
|             if __debug__: | ||||
|                 self._trace( | ||||
|                     "No pts in update, so it can be applied in any order: %s", update | ||||
|                 ) | ||||
|             return None, update | ||||
| 
 | ||||
|         if pts.entry in self.getting_diff_for: | ||||
|             if __debug__: | ||||
|                 self._trace( | ||||
|                     "Skipping update with %r as its difference is being fetched", pts | ||||
|                 ) | ||||
|             return pts.entry, None | ||||
| 
 | ||||
|         if state := self.map.get(pts.entry): | ||||
|             local_pts = state.pts | ||||
|             if local_pts + pts.pts_count > pts.pts: | ||||
|                 if __debug__: | ||||
|                     self._trace( | ||||
|                         "Skipping update since local pts %r > %r: %s", | ||||
|                         local_pts, | ||||
|                         pts, | ||||
|                         update, | ||||
|                     ) | ||||
|                 return pts.entry, None | ||||
|             elif local_pts + pts.pts_count < pts.pts: | ||||
|                 # TODO store chats too? | ||||
|                 if __debug__: | ||||
|                     self._trace( | ||||
|                         "Possible gap since local pts %r < %r: %s", | ||||
|                         local_pts, | ||||
|                         pts, | ||||
|                         update, | ||||
|                     ) | ||||
|                 if pts.entry not in self.possible_gaps: | ||||
|                     self.possible_gaps[pts.entry] = PossibleGap( | ||||
|                         deadline=asyncio.get_running_loop().time() | ||||
|                         + POSSIBLE_GAP_TIMEOUT, | ||||
|                         updates=[], | ||||
|                     ) | ||||
| 
 | ||||
|                 self.possible_gaps[pts.entry].updates.append(update) | ||||
|                 return pts.entry, None | ||||
|             else: | ||||
|                 if __debug__: | ||||
|                     self._trace( | ||||
|                         "Applying update pts since local pts %r = %r: %s", | ||||
|                         local_pts, | ||||
|                         pts, | ||||
|                         update, | ||||
|                     ) | ||||
| 
 | ||||
|         if pts.entry not in self.map: | ||||
|             self.map[pts.entry] = State( | ||||
|                 pts=0, | ||||
|                 deadline=next_updates_deadline(), | ||||
|             ) | ||||
|         self.map[pts.entry].pts = pts.pts | ||||
| 
 | ||||
|         return pts.entry, update | ||||
| 
 | ||||
|     def get_difference(self) -> Optional[Request[abcs.updates.Difference]]: | ||||
|         for entry in (ENTRY_ACCOUNT, ENTRY_SECRET): | ||||
|             if entry in self.getting_diff_for: | ||||
|                 if entry not in self.map: | ||||
|                     raise RuntimeError( | ||||
|                         "Should not try to get difference for an entry without known state" | ||||
|                     ) | ||||
| 
 | ||||
|                 gd = functions.updates.get_difference( | ||||
|                     pts=self.map[ENTRY_ACCOUNT].pts, | ||||
|                     pts_total_limit=None, | ||||
|                     date=int(self.date.timestamp()), | ||||
|                     qts=self.map[ENTRY_SECRET].pts | ||||
|                     if ENTRY_SECRET in self.map | ||||
|                     else NO_SEQ, | ||||
|                 ) | ||||
|                 if __debug__: | ||||
|                     self._trace("Requesting account difference %s", gd) | ||||
|                 return gd | ||||
| 
 | ||||
|         return None | ||||
| 
 | ||||
|     def apply_difference( | ||||
|         self, | ||||
|         diff: abcs.updates.Difference, | ||||
|         chat_hashes: ChatHashCache, | ||||
|     ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: | ||||
|         if __debug__: | ||||
|             self._trace("Applying account difference %s", diff) | ||||
| 
 | ||||
|         finish: bool | ||||
|         result: Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]] | ||||
|         if isinstance(diff, types.updates.DifferenceEmpty): | ||||
|             finish = True | ||||
|             self.date = datetime.datetime.fromtimestamp( | ||||
|                 diff.date, tz=datetime.timezone.utc | ||||
|             ) | ||||
|             self.seq = diff.seq | ||||
|             result = [], [], [] | ||||
|         elif isinstance(diff, types.updates.Difference): | ||||
|             chat_hashes.extend(diff.users, diff.chats) | ||||
|             finish = True | ||||
|             result = self.apply_difference_type(diff, chat_hashes) | ||||
|         elif isinstance(diff, types.updates.DifferenceSlice): | ||||
|             chat_hashes.extend(diff.users, diff.chats) | ||||
|             finish = False | ||||
|             result = self.apply_difference_type( | ||||
|                 types.updates.Difference( | ||||
|                     new_messages=diff.new_messages, | ||||
|                     new_encrypted_messages=diff.new_encrypted_messages, | ||||
|                     other_updates=diff.other_updates, | ||||
|                     chats=diff.chats, | ||||
|                     users=diff.users, | ||||
|                     state=diff.intermediate_state, | ||||
|                 ), | ||||
|                 chat_hashes, | ||||
|             ) | ||||
|         elif isinstance(diff, types.updates.DifferenceTooLong): | ||||
|             finish = True | ||||
|             self.map[ENTRY_ACCOUNT].pts = diff.pts | ||||
|             result = [], [], [] | ||||
|         else: | ||||
|             raise RuntimeError("unexpected case") | ||||
| 
 | ||||
|         if finish: | ||||
|             account = ENTRY_ACCOUNT in self.getting_diff_for | ||||
|             secret = ENTRY_SECRET in self.getting_diff_for | ||||
| 
 | ||||
|             if not account and not secret: | ||||
|                 raise RuntimeError( | ||||
|                     "Should not be applying the difference when neither account or secret was diff was active" | ||||
|                 ) | ||||
| 
 | ||||
|             if account: | ||||
|                 self.end_get_diff(ENTRY_ACCOUNT) | ||||
|             if secret: | ||||
|                 self.end_get_diff(ENTRY_SECRET) | ||||
| 
 | ||||
|         return result | ||||
| 
 | ||||
|     def apply_difference_type( | ||||
|         self, | ||||
|         diff: types.updates.Difference, | ||||
|         chat_hashes: ChatHashCache, | ||||
|     ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: | ||||
|         state = diff.state | ||||
|         assert isinstance(state, types.updates.State) | ||||
|         self.map[ENTRY_ACCOUNT].pts = state.pts | ||||
|         self.map[ENTRY_SECRET].pts = state.qts | ||||
|         self.date = datetime.datetime.fromtimestamp( | ||||
|             state.date, tz=datetime.timezone.utc | ||||
|         ) | ||||
|         self.seq = state.seq | ||||
| 
 | ||||
|         updates, users, chats = self.process_updates( | ||||
|             types.Updates( | ||||
|                 updates=diff.other_updates, | ||||
|                 users=diff.users, | ||||
|                 chats=diff.chats, | ||||
|                 date=int(epoch().timestamp()), | ||||
|                 seq=NO_SEQ, | ||||
|             ), | ||||
|             chat_hashes, | ||||
|         ) | ||||
| 
 | ||||
|         updates.extend( | ||||
|             types.UpdateNewMessage( | ||||
|                 message=m, | ||||
|                 pts=NO_PTS, | ||||
|                 pts_count=0, | ||||
|             ) | ||||
|             for m in diff.new_messages | ||||
|         ) | ||||
|         updates.extend( | ||||
|             types.UpdateNewEncryptedMessage( | ||||
|                 message=m, | ||||
|                 qts=NO_PTS, | ||||
|             ) | ||||
|             for m in diff.new_encrypted_messages | ||||
|         ) | ||||
| 
 | ||||
|         return updates, users, chats | ||||
| 
 | ||||
|     def get_channel_difference( | ||||
|         self, | ||||
|         chat_hashes: ChatHashCache, | ||||
|     ) -> Optional[Request[abcs.updates.ChannelDifference]]: | ||||
|         for entry in self.getting_diff_for: | ||||
|             if entry not in (ENTRY_ACCOUNT, ENTRY_SECRET): | ||||
|                 id = int(entry) | ||||
|                 break | ||||
|         else: | ||||
|             return None | ||||
| 
 | ||||
|         packed = chat_hashes.get(id) | ||||
|         if packed: | ||||
|             assert packed.access_hash is not None | ||||
|             channel = types.InputChannel( | ||||
|                 channel_id=packed.id, | ||||
|                 access_hash=packed.access_hash, | ||||
|             ) | ||||
|             if state := self.map.get(entry): | ||||
|                 gd = functions.updates.get_channel_difference( | ||||
|                     force=False, | ||||
|                     channel=channel, | ||||
|                     filter=types.ChannelMessagesFilterEmpty(), | ||||
|                     pts=state.pts, | ||||
|                     limit=BOT_CHANNEL_DIFF_LIMIT | ||||
|                     if chat_hashes.is_self_bot | ||||
|                     else USER_CHANNEL_DIFF_LIMIT, | ||||
|                 ) | ||||
|                 if __debug__: | ||||
|                     self._trace("Requesting channel difference %s", gd) | ||||
|                 return gd | ||||
|             else: | ||||
|                 raise RuntimeError( | ||||
|                     "Should not try to get difference for an entry without known state" | ||||
|                 ) | ||||
|         else: | ||||
|             self.end_get_diff(entry) | ||||
|             self.map.pop(entry, None) | ||||
|             return None | ||||
| 
 | ||||
|     def apply_channel_difference( | ||||
|         self, | ||||
|         channel_id: int, | ||||
|         diff: abcs.updates.ChannelDifference, | ||||
|         chat_hashes: ChatHashCache, | ||||
|     ) -> Tuple[List[abcs.Update], List[abcs.User], List[abcs.Chat]]: | ||||
|         entry: Entry = channel_id | ||||
|         if __debug__: | ||||
|             self._trace("Applying channel difference for %r: %s", entry, diff) | ||||
| 
 | ||||
|         self.possible_gaps.pop(entry, None) | ||||
| 
 | ||||
|         if isinstance(diff, types.updates.ChannelDifferenceEmpty): | ||||
|             assert diff.final | ||||
|             self.end_get_diff(entry) | ||||
|             self.map[entry].pts = diff.pts | ||||
|             return [], [], [] | ||||
|         elif isinstance(diff, types.updates.ChannelDifferenceTooLong): | ||||
|             chat_hashes.extend(diff.users, diff.chats) | ||||
| 
 | ||||
|             assert diff.final | ||||
|             if isinstance(diff.dialog, types.Dialog): | ||||
|                 assert diff.dialog.pts is not None | ||||
|                 self.map[entry].pts = diff.dialog.pts | ||||
|             else: | ||||
|                 raise RuntimeError("unexpected type on ChannelDifferenceTooLong") | ||||
|             self.reset_channel_deadline(channel_id, diff.timeout) | ||||
|             return [], [], [] | ||||
|         elif isinstance(diff, types.updates.ChannelDifference): | ||||
|             chat_hashes.extend(diff.users, diff.chats) | ||||
| 
 | ||||
|             if diff.final: | ||||
|                 self.end_get_diff(entry) | ||||
| 
 | ||||
|             self.map[entry].pts = diff.pts | ||||
|             updates, users, chats = self.process_updates( | ||||
|                 types.Updates( | ||||
|                     updates=diff.other_updates, | ||||
|                     users=diff.users, | ||||
|                     chats=diff.chats, | ||||
|                     date=int(epoch().timestamp()), | ||||
|                     seq=NO_SEQ, | ||||
|                 ), | ||||
|                 chat_hashes, | ||||
|             ) | ||||
| 
 | ||||
|             updates.extend( | ||||
|                 types.UpdateNewChannelMessage( | ||||
|                     message=m, | ||||
|                     pts=NO_PTS, | ||||
|                     pts_count=0, | ||||
|                 ) | ||||
|                 for m in diff.new_messages | ||||
|             ) | ||||
|             self.reset_channel_deadline(channel_id, None) | ||||
| 
 | ||||
|             return updates, users, chats | ||||
|         else: | ||||
|             raise RuntimeError("unexpected case") | ||||
| 
 | ||||
|     def end_channel_difference( | ||||
|         self, channel_id: int, reason: PrematureEndReason | ||||
|     ) -> None: | ||||
|         entry: Entry = channel_id | ||||
|         if __debug__: | ||||
|             self._trace("Ending channel difference for %r because %s", entry, reason) | ||||
| 
 | ||||
|         if reason == PrematureEndReason.TEMPORARY_SERVER_ISSUES: | ||||
|             self.possible_gaps.pop(entry, None) | ||||
|             self.end_get_diff(entry) | ||||
|         elif reason == PrematureEndReason.BANNED: | ||||
|             self.possible_gaps.pop(entry, None) | ||||
|             self.end_get_diff(entry) | ||||
|             del self.map[entry] | ||||
|         else: | ||||
|             raise RuntimeError("Unknown reason to end channel difference") | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user