From f6df5d377c78f0dbc2559b14df159b9b28b0303d Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Tue, 18 Jan 2022 19:46:19 +0100 Subject: [PATCH] Begin reworking update handling Use a fixed-size queue instead of a callback to deal with updates. Port the message box and entity cache from grammers to start off with a clean design. Temporarily get rid of other cruft such as automatic pings or old catch up implementation. --- telethon/_client/telegrambaseclient.py | 32 +- telethon/_client/telegramclient.py | 5 +- telethon/_client/updates.py | 292 +------------ telethon/_network/mtprotosender.py | 26 +- telethon/_updates/__init__.py | 2 + telethon/_updates/entitycache.py | 97 +++++ telethon/_updates/messagebox.py | 565 +++++++++++++++++++++++++ 7 files changed, 704 insertions(+), 315 deletions(-) create mode 100644 telethon/_updates/__init__.py create mode 100644 telethon/_updates/entitycache.py create mode 100644 telethon/_updates/messagebox.py diff --git a/telethon/_client/telegrambaseclient.py b/telethon/_client/telegrambaseclient.py index 62dc09b0..b6872dc8 100644 --- a/telethon/_client/telegrambaseclient.py +++ b/telethon/_client/telegrambaseclient.py @@ -11,10 +11,11 @@ import dataclasses from .. import version, __name__ as __base_name__, _tl from .._crypto import rsa -from .._misc import markdown, statecache, enums, helpers +from .._misc import markdown, enums, helpers from .._network import MTProtoSender, Connection, transports from .._sessions import Session, SQLiteSession, MemorySession from .._sessions.types import DataCenter, SessionState +from .._updates import EntityCache, MessageBox DEFAULT_DC_ID = 2 DEFAULT_IPV4_IP = '149.154.167.51' @@ -91,6 +92,7 @@ def init( flood_sleep_threshold: int = 60, # Update handling. receive_updates: bool = True, + max_queued_updates: int = 100, ): # Logging. if isinstance(base_logger, str): @@ -139,6 +141,13 @@ def init( self._flood_waited_requests = {} # prevent calls that would floodwait entirely self._parse_mode = markdown + # Update handling. + self._no_updates = not receive_updates + self._updates_queue = asyncio.Queue(maxsize=max_queued_updates) + self._updates_handle = None + self._message_box = MessageBox() + self._entity_cache = EntityCache() # required for proper update handling (to know when to getDifference) + # Connection parameters. if not api_id or not api_hash: raise ValueError( @@ -189,16 +198,13 @@ def init( delay=self._connect_retry_delay, auto_reconnect=self._auto_reconnect, connect_timeout=self._connect_timeout, - update_callback=self._handle_update, + updates_queue=self._updates_queue, ) # Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders. self._borrowed_senders = {} self._borrow_sender_lock = asyncio.Lock() - # Update handling. - self._no_updates = not receive_updates - def get_flood_sleep_threshold(self): return self._flood_sleep_threshold @@ -337,15 +343,6 @@ async def _disconnect_coro(self: 'TelegramClient'): # If any was borrowed self._borrowed_senders.clear() - # trio's nurseries would handle this for us, but this is asyncio. - # All tasks spawned in the background should properly be terminated. - if self._dispatching_updates_queue is None and self._updates_queue: - for task in self._updates_queue: - task.cancel() - - await asyncio.wait(self._updates_queue) - self._updates_queue.clear() - async def _disconnect(self: 'TelegramClient'): """ @@ -355,8 +352,11 @@ async def _disconnect(self: 'TelegramClient'): their job with the client is complete and we should clean it up all. """ await self._sender.disconnect() - await helpers._cancel(self._log[__name__], - updates_handle=self._updates_handle) + await helpers._cancel(self._log[__name__], updates_handle=self._updates_handle) + try: + await self._updates_handle + except asyncio.CancelledError: + pass async def _switch_dc(self: 'TelegramClient', new_dc): """ diff --git a/telethon/_client/telegramclient.py b/telethon/_client/telegramclient.py index 412e7b61..4504d589 100644 --- a/telethon/_client/telegramclient.py +++ b/telethon/_client/telegramclient.py @@ -2665,6 +2665,7 @@ class TelegramClient: flood_sleep_threshold: int = 60, # Update handling. receive_updates: bool = True, + max_queued_updates: int = 100, ): telegrambaseclient.init(**locals()) @@ -3509,10 +3510,6 @@ class TelegramClient: async def _clean_exported_senders(self: 'TelegramClient'): pass - @forward_call(updates._handle_update) - def _handle_update(self: 'TelegramClient', update): - pass - @forward_call(auth._update_session_state) async def _update_session_state(self, user, *, save=True): pass diff --git a/telethon/_client/updates.py b/telethon/_client/updates.py index 4b8e29cb..2460c561 100644 --- a/telethon/_client/updates.py +++ b/telethon/_client/updates.py @@ -79,295 +79,9 @@ def list_event_handlers(self: 'TelegramClient')\ return [(callback, event) for event, callback in self._event_builders] async def catch_up(self: 'TelegramClient'): - return - self._catching_up = True - try: - while True: - d = await self(_tl.fn.updates.GetDifference( - pts, date, 0 - )) - if isinstance(d, (_tl.updates.DifferenceSlice, - _tl.updates.Difference)): - if isinstance(d, _tl.updates.Difference): - state = d.state - else: - state = d.intermediate_state - - pts, date = state.pts, state.date - _handle_update(self, _tl.Updates( - users=d.users, - chats=d.chats, - date=state.date, - seq=state.seq, - updates=d.other_updates + [ - _tl.UpdateNewMessage(m, 0, 0) - for m in d.new_messages - ] - )) - - # TODO Implement upper limit (max_pts) - # We don't want to fetch updates we already know about. - # - # We may still get duplicates because the Difference - # contains a lot of updates and presumably only has - # the state for the last one, but at least we don't - # unnecessarily fetch too many. - # - # updates.getDifference's pts_total_limit seems to mean - # "how many pts is the request allowed to return", and - # if there is more than that, it returns "too long" (so - # there would be duplicate updates since we know about - # some). This can be used to detect collisions (i.e. - # it would return an update we have already seen). - else: - if isinstance(d, _tl.updates.DifferenceEmpty): - date = d.date - elif isinstance(d, _tl.updates.DifferenceTooLong): - pts = d.pts - break - except (ConnectionError, asyncio.CancelledError): - pass - finally: - self._catching_up = False - - -# It is important to not make _handle_update async because we rely on -# the order that the updates arrive in to update the pts and date to -# be always-increasing. There is also no need to make this async. -def _handle_update(self: 'TelegramClient', update): - if isinstance(update, (_tl.Updates, _tl.UpdatesCombined)): - entities = {utils.get_peer_id(x): x for x in - itertools.chain(update.users, update.chats)} - for u in update.updates: - _process_update(self, u, entities, update.updates) - elif isinstance(update, _tl.UpdateShort): - _process_update(self, update.update, {}, None) - else: - _process_update(self, update, {}, None) - - -def _process_update(self: 'TelegramClient', update, entities, others): - # This part is somewhat hot so we don't bother patching - # update with channel ID/its state. Instead we just pass - # arguments which is faster. - args = (update, entities, others, channel_id, None) - if self._dispatching_updates_queue is None: - task = asyncio.create_task(_dispatch_update(self, *args)) - self._updates_queue.add(task) - task.add_done_callback(lambda _: self._updates_queue.discard(task)) - else: - self._updates_queue.put_nowait(args) - if not self._dispatching_updates_queue.is_set(): - self._dispatching_updates_queue.set() - asyncio.create_task(_dispatch_queue_updates(self)) + pass async def _update_loop(self: 'TelegramClient'): - # Pings' ID don't really need to be secure, just "random" - rnd = lambda: random.randrange(-2**63, 2**63) while self.is_connected(): - try: - await asyncio.wait_for(self.run_until_disconnected(), timeout=60) - continue # We actually just want to act upon timeout - except asyncio.TimeoutError: - pass - except asyncio.CancelledError: - return - except Exception as e: - # Any disconnected exception should be ignored (or it may hint at - # another problem, leading to an infinite loop, hence the logging call) - self._log[__name__].info('Exception waiting on a disconnect: %s', e) - continue - - # Check if we have any exported senders to clean-up periodically - await self._clean_exported_senders() - - # Don't bother sending pings until the low-level connection is - # ready, otherwise a lot of pings will be batched to be sent upon - # reconnect, when we really don't care about that. - if not self._sender._transport_connected(): - continue - - # We also don't really care about their result. - # Just send them periodically. - try: - self._sender._keepalive_ping(rnd()) - except (ConnectionError, asyncio.CancelledError): - return - - # Entities are not saved when they are inserted because this is a rather expensive - # operation (default's sqlite3 takes ~0.1s to commit changes). Do it every minute - # instead. No-op if there's nothing new. - await self._session.save() - - # We need to send some content-related request at least hourly - # for Telegram to keep delivering updates, otherwise they will - # just stop even if we're connected. Do so every 30 minutes. - # - # TODO Call getDifference instead since it's more relevant - if time.time() - self._last_request > 30 * 60: - if not await self.is_user_authorized(): - # What can be the user doing for so - # long without being logged in...? - continue - - try: - await self(_tl.fn.updates.GetState()) - except (ConnectionError, asyncio.CancelledError): - return - -async def _dispatch_queue_updates(self: 'TelegramClient'): - while not self._updates_queue.empty(): - await _dispatch_update(self, *self._updates_queue.get_nowait()) - - self._dispatching_updates_queue.clear() - -async def _dispatch_update(self: 'TelegramClient', update, entities, others, channel_id, pts_date): - built = EventBuilderDict(self, update, entities, others) - - for builder, callback in self._event_builders: - event = built[type(builder)] - if not event: - continue - - if not builder.resolved: - await builder.resolve(self) - - filter = builder.filter(event) - if inspect.isawaitable(filter): - filter = await filter - if not filter: - continue - - try: - await callback(event) - except StopPropagation: - name = getattr(callback, '__name__', repr(callback)) - self._log[__name__].debug( - 'Event handler "%s" stopped chain of propagation ' - 'for event %s.', name, type(event).__name__ - ) - break - except Exception as e: - if not isinstance(e, asyncio.CancelledError) or self.is_connected(): - name = getattr(callback, '__name__', repr(callback)) - self._log[__name__].exception('Unhandled exception on %s', name) - -async def _dispatch_event(self: 'TelegramClient', event): - """ - Dispatches a single, out-of-order event. Used by `AlbumHack`. - """ - # We're duplicating a most logic from `_dispatch_update`, but all in - # the name of speed; we don't want to make it worse for all updates - # just because albums may need it. - for builder, callback in self._event_builders: - if isinstance(builder, Raw): - continue - if not isinstance(event, builder.Event): - continue - - if not builder.resolved: - await builder.resolve(self) - - filter = builder.filter(event) - if inspect.isawaitable(filter): - filter = await filter - if not filter: - continue - - try: - await callback(event) - except StopPropagation: - name = getattr(callback, '__name__', repr(callback)) - self._log[__name__].debug( - 'Event handler "%s" stopped chain of propagation ' - 'for event %s.', name, type(event).__name__ - ) - break - except Exception as e: - if not isinstance(e, asyncio.CancelledError) or self.is_connected(): - name = getattr(callback, '__name__', repr(callback)) - self._log[__name__].exception('Unhandled exception on %s', name) - -async def _get_difference(self: 'TelegramClient', update, entities, channel_id, pts_date): - """ - Get the difference for this `channel_id` if any, then load entities. - - Calls :tl:`updates.getDifference`, which fills the entities cache - (always done by `__call__`) and lets us know about the full entities. - """ - # Fetch since the last known pts/date before this update arrived, - # in order to fetch this update at full, including its entities. - self._log[__name__].debug('Getting difference for entities ' - 'for %r', update.__class__) - if channel_id: - # There are reports where we somehow call get channel difference - # with `InputPeerEmpty`. Check our assumptions to better debug - # this when it happens. - assert isinstance(channel_id, int), 'channel_id was {}, not int in {}'.format(type(channel_id), update) - try: - # Wrap the ID inside a peer to ensure we get a channel back. - where = await self.get_input_entity(_tl.PeerChannel(channel_id)) - except ValueError: - # There's a high chance that this fails, since - # we are getting the difference to fetch entities. - return - - if not pts_date: - # First-time, can't get difference. Get pts instead. - result = await self(_tl.fn.channels.GetFullChannel( - utils.get_input_channel(where) - )) - return - - result = await self(_tl.fn.updates.GetChannelDifference( - channel=where, - filter=_tl.ChannelMessagesFilterEmpty(), - pts=pts_date, # just pts - limit=100, - force=True - )) - else: - if not pts_date[0]: - # First-time, can't get difference. Get pts instead. - result = await self(_tl.fn.updates.GetState()) - return - - result = await self(_tl.fn.updates.GetDifference( - pts=pts_date[0], - date=pts_date[1], - qts=0 - )) - - if isinstance(result, (_tl.updates.Difference, - _tl.updates.DifferenceSlice, - _tl.updates.ChannelDifference, - _tl.updates.ChannelDifferenceTooLong)): - entities.update({ - utils.get_peer_id(x): x for x in - itertools.chain(result.users, result.chats) - }) - -class EventBuilderDict: - """ - Helper "dictionary" to return events from types and cache them. - """ - def __init__(self, client: 'TelegramClient', update, entities, others): - self.client = client - self.update = update - self.entities = entities - self.others = others - - def __getitem__(self, builder): - try: - return self.__dict__[builder] - except KeyError: - event = self.__dict__[builder] = builder.build( - self.update, self.others, self.client._session_state.user_id, self.entities or {}, self.client) - - if isinstance(event, EventCommon): - # TODO eww - event.original_update = self.update - event._entities = self.entities or {} - event._set_client(self.client) - - return event + updates = await self._updates_queue.get() + updates, users, chats = self._message_box.process_updates(updates, self._entity_cache) diff --git a/telethon/_network/mtprotosender.py b/telethon/_network/mtprotosender.py index b05137e3..fa58240f 100644 --- a/telethon/_network/mtprotosender.py +++ b/telethon/_network/mtprotosender.py @@ -1,6 +1,7 @@ import asyncio import collections import struct +import logging from . import authenticator from .._misc.messagepacker import MessagePacker @@ -20,6 +21,9 @@ from .._misc import helpers, utils from .. import _tl +UPDATE_BUFFER_FULL_WARN_DELAY = 15 * 60 + + class MTProtoSender: """ MTProto Mobile Protocol sender @@ -35,9 +39,8 @@ class MTProtoSender: A new authorization key will be generated on connection if no other key exists yet. """ - def __init__(self, *, loggers, - retries=5, delay=1, auto_reconnect=True, connect_timeout=None, - update_callback=None): + def __init__(self, *, loggers, updates_queue, + retries=5, delay=1, auto_reconnect=True, connect_timeout=None,): self._connection = None self._loggers = loggers self._log = loggers[__name__] @@ -45,7 +48,7 @@ class MTProtoSender: self._delay = delay self._auto_reconnect = auto_reconnect self._connect_timeout = connect_timeout - self._update_callback = update_callback + self._updates_queue = updates_queue self._connect_lock = asyncio.Lock() self._ping = None @@ -83,6 +86,9 @@ class MTProtoSender: # is received, but we may still need to resend their state on bad salts. self._last_acks = collections.deque(maxlen=10) + # Last time we warned about the update buffer being full + self._last_update_warn = -UPDATE_BUFFER_FULL_WARN_DELAY + # Jump table from response ID to method that handles it self._handlers = { RpcResult.CONSTRUCTOR_ID: self._handle_rpc_result, @@ -629,8 +635,16 @@ class MTProtoSender: return self._log.debug('Handling update %s', message.obj.__class__.__name__) - if self._update_callback: - self._update_callback(message.obj) + try: + self._updates_queue.put_nowait(message.obj) + except asyncio.QueueFull: + now = asyncio.get_running_loop().time() + if now - self._last_update_warn >= UPDATE_BUFFER_FULL_WARN_DELAY: + self._log.warning( + 'Cannot dispatch update because the buffer capacity of %d was reached', + self._updates_queue.maxsize + ) + self._last_update_warn = now async def _handle_pong(self, message): """ diff --git a/telethon/_updates/__init__.py b/telethon/_updates/__init__.py new file mode 100644 index 00000000..7951c9aa --- /dev/null +++ b/telethon/_updates/__init__.py @@ -0,0 +1,2 @@ +from .entitycache import EntityCache, PackedChat +from .messagebox import MessageBox diff --git a/telethon/_updates/entitycache.py b/telethon/_updates/entitycache.py new file mode 100644 index 00000000..176d2013 --- /dev/null +++ b/telethon/_updates/entitycache.py @@ -0,0 +1,97 @@ +import inspect +import itertools +from dataclasses import dataclass, field +from collections import namedtuple + +from .._misc import utils +from .. import _tl +from .._sessions.types import EntityType, Entity + + +class PackedChat(namedtuple('PackedChat', 'ty id hash')): + __slots__ = () + + @property + def is_user(self): + return self.ty in (EntityType.USER, EntityType.BOT) + + @property + def is_chat(self): + return self.ty in (EntityType.GROUP,) + + @property + def is_channel(self): + return self.ty in (EntityType.CHANNEL, EntityType.MEGAGROUP, EntityType.GIGAGROUP) + + def to_peer(self): + if self.is_user: + return _tl.PeerUser(user_id=self.id) + elif self.is_chat: + return _tl.PeerChat(chat_id=self.id) + elif self.is_channel: + return _tl.PeerChannel(channel_id=self.id) + + def to_input_peer(self): + if self.is_user: + return _tl.InputPeerUser(user_id=self.id, access_hash=self.hash) + elif self.is_chat: + return _tl.InputPeerChat(chat_id=self.id) + elif self.is_channel: + return _tl.InputPeerChannel(channel_id=self.id, access_hash=self.hash) + + def try_to_input_user(self): + if self.is_user: + return _tl.InputUser(user_id=self.id, access_hash=self.hash) + else: + return None + + def try_to_chat_id(self): + if self.is_chat: + return self.id + else: + return None + + def try_to_input_channel(self): + if self.is_channel: + return _tl.InputChannel(channel_id=self.id, access_hash=self.hash) + else: + return None + + def __str__(self): + return f'{chr(self.ty.value)}.{self.id}.{self.hash}' + + +@dataclass +class EntityCache: + hash_map: dict = field(default_factory=dict) # id -> (hash, ty) + self_id: int = None + self_bot: bool = False + + def set_self_user(self, id, bot): + self.self_id = id + self.self_bot = bot + + def get(self, id): + value = self.hash_map.get(id) + return PackedChat(ty=value[1], id=id, hash=value[0]) if value else None + + def extend(self, users, chats): + # See https://core.telegram.org/api/min for "issues" with "min constructors". + self.hash_map.update( + (u.id, ( + u.access_hash, + EntityType.BOT if u.bot else EntityType.USER, + )) + for u in users + if getattr(u, 'access_hash', None) and not u.min + ) + self.hash_map.update( + (c.id, ( + c.access_hash, + EntityType.MEGAGROUP if c.megagroup else ( + EntityType.GIGAGROUP if getattr(c, 'gigagroup', None) else EntityType.CHANNEL + ), + )) + for c in chats + if getattr(c, 'access_hash', None) and not getattr(c, 'min', None) + ) diff --git a/telethon/_updates/messagebox.py b/telethon/_updates/messagebox.py new file mode 100644 index 00000000..555451ad --- /dev/null +++ b/telethon/_updates/messagebox.py @@ -0,0 +1,565 @@ +""" +This module deals with correct handling of updates, including gaps, and knowing when the code +should "get difference" (the set of updates that the client should know by now minus the set +of updates that it actually knows). + +Each chat has its own [`Entry`] in the [`MessageBox`] (this `struct` is the "entry point"). +At any given time, the message box may be either getting difference for them (entry is in +[`MessageBox::getting_diff_for`]) or not. If not getting difference, a possible gap may be +found for the updates (entry is in [`MessageBox::possible_gaps`]). Otherwise, the entry is +on its happy path. + +Gaps are cleared when they are either resolved on their own (by waiting for a short time) +or because we got the difference for the corresponding entry. + +While there are entries for which their difference must be fetched, +[`MessageBox::check_deadlines`] will always return [`Instant::now`], since "now" is the time +to get the difference. +""" +import asyncio +from dataclasses import dataclass, field +from .._sessions.types import SessionState, ChannelState + + +# Telegram sends `seq` equal to `0` when "it doesn't matter", so we use that value too. +NO_SEQ = 0 + +# See https://core.telegram.org/method/updates.getChannelDifference. +BOT_CHANNEL_DIFF_LIMIT = 100000 +USER_CHANNEL_DIFF_LIMIT = 100 + +# > It may be useful to wait up to 0.5 seconds +POSSIBLE_GAP_TIMEOUT = 0.5 + +# After how long without updates the client will "timeout". +# +# When this timeout occurs, the client will attempt to fetch updates by itself, ignoring all the +# updates that arrive in the meantime. After all updates are fetched when this happens, the +# client will resume normal operation, and the timeout will reset. +# +# Documentation recommends 15 minutes without updates (https://core.telegram.org/api/updates). +NO_UPDATES_TIMEOUT = 15 * 60 + +# Entry "enum". +# Account-wide `pts` includes private conversations (one-to-one) and small group chats. +ENTRY_ACCOUNT = object() +# Account-wide `qts` includes only "secret" one-to-one chats. +ENTRY_SECRET = object() +# Integers will be Channel-specific `pts`, and includes "megagroup", "broadcast" and "supergroup" channels. + + +def next_updates_deadline(): + return asyncio.get_running_loop().time() + NO_UPDATES_TIMEOUT + + +class GapError(ValueError): + pass + + +# Represents the information needed to correctly handle a specific `tl::enums::Update`. +@dataclass +class PtsInfo: + pts: int + pts_count: int + entry: object + + @classmethod + def from_update(cls, update): + pts = getattr(update, 'pts', None) + if pts: + pts_count = getattr(update, 'pts_count', None) or 0 + entry = getattr(update, 'channel_id', None) or ENTRY_ACCOUNT + return cls(pts=pts, pts_count=pts_count, entry=entry) + + qts = getattr(update, 'qts', None) + if qts: + pts_count = 1 if isinstance(update, _tl.UpdateNewEncryptedMessage) else 0 + return cls(pts=qts, pts_count=pts_count, entry=ENTRY_SECRET) + + return None + + +# The state of a particular entry in the message box. +@dataclass +class State: + # Current local persistent timestamp. + pts: int + + # Next instant when we would get the update difference if no updates arrived before then. + deadline: float + + +# > ### Recovering gaps +# > […] Manually obtaining updates is also required in the following situations: +# > • Loss of sync: a gap was found in `seq` / `pts` / `qts` (as described above). +# > It may be useful to wait up to 0.5 seconds in this situation and abort the sync in case a new update +# > arrives, that fills the gap. +# +# This is really easy to trigger by spamming messages in a channel (with as little as 3 members works), because +# the updates produced by the RPC request take a while to arrive (whereas the read update comes faster alone). +@dataclass +class PossibleGap: + deadline: float + # Pending updates (those with a larger PTS, producing the gap which may later be filled). + updates: list # of updates + + +# Represents a "message box" (event `pts` for a specific entry). +# +# See https://core.telegram.org/api/updates#message-related-event-sequences. +@dataclass +class MessageBox: + # Map each entry to their current state. + map: dict = field(default_factory=dict) # entry -> state + + # Additional fields beyond PTS needed by `ENTRY_ACCOUNT`. + date: int = 1 + seq: int = 0 + + # Holds the entry with the closest deadline (optimization to avoid recalculating the minimum deadline). + next_deadline: object = None # entry + + # Which entries have a gap and may soon trigger a need to get difference. + # + # If a gap is found, stores the required information to resolve it (when should it timeout and what updates + # should be held in case the gap is resolved on its own). + # + # Not stored directly in `map` as an optimization (else we would need another way of knowing which entries have + # a gap in them). + possible_gaps: dict = field(default_factory=dict) # entry -> possiblegap + + # For which entries are we currently getting difference. + getting_diff_for: set = field(default_factory=set) # entry + + # Temporarily stores which entries should have their update deadline reset. + # Stored in the message box in order to reuse the allocation. + reset_deadlines_for: set = field(default_factory=set) # entry + + # region Creation, querying, and setting base state. + + @classmethod + def load(cls, session_state, channel_states): + """ + Create a [`MessageBox`] from a previously known update state. + """ + deadline = next_updates_deadline() + return cls( + map={ + ENTRY_ACCOUNT: State(pts=session_state.pts, deadline=deadline), + ENTRY_SECRET: State(pts=session_state.qts, deadline=deadline), + **{s.channel_id: s.pts for s in channel_states} + }, + date=session_state.date, + seq=session_state.seq, + next_deadline=ENTRY_ACCOUNT, + ) + + @classmethod + def session_state(self): + """ + Return the current state in a format that sessions understand. + + This should be used for persisting the state. + """ + return SessionState( + user_id=0, + dc_id=0, + bot=False, + pts=self.map.get(ENTRY_ACCOUNT, 0), + qts=self.map.get(ENTRY_SECRET, 0), + date=self.date, + seq=self.seq, + takeout_id=None, + ), [ChannelState(channel_id=id, pts=pts) for id, pts in self.map.items() if isinstance(id, int)] + + def is_empty(self) -> bool: + """ + Return true if the message box is empty and has no state yet. + """ + return self.map.get(ENTRY_ACCOUNT, NO_SEQ) == NO_SEQ + + def check_deadlines(self): + """ + Return the next deadline when receiving updates should timeout. + + If a deadline expired, the corresponding entries will be marked as needing to get its difference. + While there are entries pending of getting their difference, this method returns the current instant. + """ + now = asyncio.get_running_loop().time() + + if self.getting_diff_for: + return now + + deadline = next_updates_deadline() + + # Most of the time there will be zero or one gap in flight so finding the minimum is cheap. + if self.possible_gaps: + deadline = min(deadline, *self.possible_gaps.values()) + elif self.next_deadline in self.map: + deadline = min(deadline, self.map[self.next_deadline]) + + if now > deadline: + # Check all expired entries and add them to the list that needs getting difference. + self.getting_diff_for.update(entry for entry, gap in self.possible_gaps.items() if now > gap.deadline) + self.getting_diff_for.update(entry for entry, state in self.map.items() if now > state.deadline) + + # When extending `getting_diff_for`, it's important to have the moral equivalent of + # `begin_get_diff` (that is, clear possible gaps if we're now getting difference). + for entry in self.getting_diff_for: + self.possible_gaps.pop(entry, None) + + return deadline + + # Reset the deadline for the periods without updates for a given entry. + # + # It also updates the next deadline time to reflect the new closest deadline. + def reset_deadline(self, entry, deadline): + if entry in self.map: + self.map[entry].deadline = deadline + # TODO figure out why not in map may happen + + if self.next_deadline == entry: + # If the updated deadline was the closest one, recalculate the new minimum. + self.next_deadline = min(self.map.items(), key=lambda entry_state: entry_state[1].deadline)[0] + elif deadline < self.map.get(self.next_deadline, 0): + # If the updated deadline is smaller than the next deadline, change the next deadline to be the new one. + self.next_deadline = entry + # else an unrelated deadline was updated, so the closest one remains unchanged. + + # Convenience to reset a channel's deadline, with optional timeout. + def reset_channel_deadline(self, channel_id, timeout): + self.reset_deadlines(channel_id, asyncio.get_running_loop().time() + (timeout or NO_UPDATES_TIMEOUT)) + + # Reset all the deadlines in `reset_deadlines_for` and then empty the set. + def apply_deadlines_reset(self): + next_deadline = next_updates_deadline() + + reset_deadlines_for = self.reset_deadlines_for + self.reset_deadlines_for = set() # "move" the set to avoid self.reset_deadline() from touching it during iter + + for entry in reset_deadlines_for: + self.reset_deadline(entry, next_deadline) + + reset_deadlines_for.clear() # reuse allocation, the other empty set was a temporary dummy value + self.reset_deadlines_for = reset_deadlines_for + + # Sets the update state. + # + # Should be called right after login if [`MessageBox::new`] was used, otherwise undesirable + # updates will be fetched. + def set_state(self, state): + deadline = next_updates_deadline() + self.map[ENTRY_ACCOUNT] = State(pts=state.pts, deadline=deadline) + self.map[ENTRY_SECRET] = State(pts=state.qts, deadline=deadline) + self.date = state.date + self.seq = state.seq + + # Like [`MessageBox::set_state`], but for channels. Useful when getting dialogs. + # + # The update state will only be updated if no entry was known previously. + def try_set_channel_state(self, id, pts): + if id not in self.map: + self.map[id] = State(pts=pts, deadline=next_updates_deadline()) + + # Begin getting difference for the given entry. + # + # Clears any previous gaps. + def begin_get_diff(self, entry): + self.getting_diff_for.add(entry) + self.possible_gaps.pop(entry, None) + + # Finish getting difference for the given entry. + # + # It also resets the deadline. + def end_get_diff(self, entry): + self.getting_diff_for.pop(entry, None) + self.reset_deadline(entry, next_updates_deadline()) + assert entry not in self.possible_gaps, "gaps shouldn't be created while getting difference" + + # endregion Creation, querying, and setting base state. + + # region "Normal" updates flow (processing and detection of gaps). + + # Process an update and return what should be done with it. + # + # Updates corresponding to entries for which their difference is currently being fetched + # will be ignored. While according to the [updates' documentation]: + # + # > Implementations [have] to postpone updates received via the socket while + # > filling gaps in the event and `Update` sequences, as well as avoid filling + # > gaps in the same sequence. + # + # In practice, these updates should have also been retrieved through getting difference. + # + # [updates documentation] https://core.telegram.org/api/updates + def process_updates( + self, + updates, + chat_hashes, + result, # out list of updates; returns list of user, chat, or raise if gap + ): + # XXX adapt updates and chat hashes into updatescombined, raise gap on too long + date = updates.date + seq_start = updates.seq_start + seq = updates.seq + updates = updates.updates + users = updates.users + chats = updates.chats + + # > For all the other [not `updates` or `updatesCombined`] `Updates` type constructors + # > there is no need to check `seq` or change a local state. + if updates.seq_start != NO_SEQ: + if self.seq + 1 > updates.seq_start: + # Skipping updates that were already handled + return (updates.users, updates.chats) + elif self.seq + 1 < updates.seq_start: + # Gap detected + self.begin_get_diff(ENTRY_ACCOUNT) + raise GapError + # else apply + + self.date = updates.date + if updates.seq != NO_SEQ: + self.seq = updates.seq + + result.extend(filter(None, (self.apply_pts_info(u, reset_deadline=True) for u in updates.updates))) + + self.apply_deadlines_reset() + + def _sort_gaps(update): + pts = PtsInfo.from_update(u) + return pts.pts - pts.pts_count if pts else 0 + + if self.possible_gaps: + # For each update in possible gaps, see if the gap has been resolved already. + for key in list(self.possible_gaps.keys()): + self.possible_gaps[key].updates.sort(key=_sort_gaps) + + for _ in range(len(self.possible_gaps[key].updates)): + update = self.possible_gaps[key].updates.pop(0) + + # If this fails to apply, it will get re-inserted at the end. + # All should fail, so the order will be preserved (it would've cycled once). + update = self.apply_pts_info(update, reset_deadline=False) + if update: + result.append(update) + + # Clear now-empty gaps. + self.possible_gaps = {entry: gap for entry, gap in self.possible_gaps if gap.updates} + + return (updates.users, updates.chats) + + # Tries to apply the input update if its `PtsInfo` follows the correct order. + # + # If the update can be applied, it is returned; otherwise, the update is stored in a + # possible gap (unless it was already handled or would be handled through getting + # difference) and `None` is returned. + def apply_pts_info( + self, + update, + *, + reset_deadline, + ): + pts = PtsInfo.from_update(update) + if not pts: + # No pts means that the update can be applied in any order. + return update + + # As soon as we receive an update of any form related to messages (has `PtsInfo`), + # the "no updates" period for that entry is reset. + # + # Build the `HashSet` to avoid calling `reset_deadline` more than once for the same entry. + if reset_deadline: + self.reset_deadlines_for.insert(pts.entry) + + if pts.entry in self.getting_diff_for: + # Note: early returning here also prevents gap from being inserted (which they should + # not be while getting difference). + return None + + if pts.entry in self.map: + local_pts = self.map[pts.entry].pts + if local_pts + pts.pts_count > pts.pts: + # Ignore + return None + elif local_pts + pts.pts_count < pts.pts: + # Possible gap + # TODO store chats too? + if pts.entry not in self.possible_gaps: + self.possible_gaps[pts.entry] = PossibleGap( + deadline=asyncio.get_running_loop().time() + POSSIBLE_GAP_TIMEOUT, + updates=[] + ) + + self.possible_gaps[pts.entry].updates.append(update) + return None + else: + # Apply + pass + else: + # No previous `pts` known, and because this update has to be "right" (it's the first one) our + # `local_pts` must be one less. + local_pts = pts.pts - 1 + + # For example, when we're in a channel, we immediately receive: + # * ReadChannelInbox (pts = X) + # * NewChannelMessage (pts = X, pts_count = 1) + # + # Notice how both `pts` are the same. If we stored the one from the first, then the second one would + # be considered "already handled" and ignored, which is not desirable. Instead, advance local `pts` + # by `pts_count` (which is 0 for updates not directly related to messages, like reading inbox). + if pts.entry in self.map: + self.map[pts.entry].pts = local_pts + pts.pts_count + else: + self.map[pts.entry] = State(pts=local_pts + pts.pts_count, deadline=next_updates_deadline()) + + return update + + # endregion "Normal" updates flow (processing and detection of gaps). + + # region Getting and applying account difference. + + # Return the request that needs to be made to get the difference, if any. + def get_difference(self): + entry = ENTRY_ACCOUNT + if entry in self.getting_diff_for: + if entry in self.map: + return _tl.fn.updates.GetDifference( + pts=state.pts, + pts_total_limit=None, + date=self.date, + qts=self.map[ENTRY_SECRET].pts, + ) + else: + # TODO investigate when/why/if this can happen + self.end_get_diff(entry) + + return None + + # Similar to [`MessageBox::process_updates`], but using the result from getting difference. + def apply_difference( + self, + diff, + chat_hashes, + ): + if isinstance(diff, _tl.updates.DifferenceEmpty): + self.date = diff.date + self.seq = diff.seq + self.end_get_diff(ENTRY_ACCOUNT) + return [], [], [] + elif isinstance(diff, _tl.updates.Difference): + self.end_get_diff(ENTRY_ACCOUNT) + chat_hashes.extend(diff.users, diff.chats) + return self.apply_difference_type(diff) + elif isinstance(diff, _tl.updates.DifferenceSlice): + chat_hashes.extend(diff.users, diff.chats) + return self.apply_difference_type(diff) + elif isinstance(diff, _tl.updates.DifferenceTooLong): + # TODO when are deadlines reset if we update the map?? + self.map[ENTRY_ACCOUNT].pts = diff.pts + self.end_get_diff(ENTRY_ACCOUNT) + return [], [], [] + + def apply_difference_type( + self, + diff, + ): + state = getattr(diff, 'intermediate_state', None) or diff.state + self.map[ENTRY_ACCOUNT].pts = state.pts + self.map[ENTRY_SECRET].pts = state.qts + self.date = state.date + self.seq = state.seq + + for u in diff.updates: + if isinstance(u, _tl.UpdateChannelTooLong): + self.begin_get_diff(u.channel_id) + + updates.extend(_tl.UpdateNewMessage( + message=m, + pts=NO_SEQ, + pts_count=NO_SEQ, + ) for m in diff.new_messages) + updates.extend(_tl.UpdateNewEncryptedMessage( + message=m, + qts=NO_SEQ, + ) for m in diff.new_encrypted_messages) + + return diff.updates, diff.users, diff.chats + + # endregion Getting and applying account difference. + + # region Getting and applying channel difference. + + # Return the request that needs to be made to get a channel's difference, if any. + def get_channel_difference( + self, + chat_hashes, + ): + entry = next((id for id in self.getting_diff_for if isinstance(id, int)), None) + if not entry: + return None + + packed = chat_hashes.get(entry) + if not packed: + # Cannot get channel difference as we're missing its hash + self.end_get_diff(entry) + # Remove the outdated `pts` entry from the map so that the next update can correct + # it. Otherwise, it will spam that the access hash is missing. + self.map.pop(entry, None) + return None + + state = self.map.get(entry) + if not state: + # TODO investigate when/why/if this can happen + # Cannot get channel difference as we're missing its pts + self.end_get_diff(entry) + return None + + return _tl.fn.updates.GetChannelDifference( + force=False, + channel=channel, + filter=_tl.ChannelMessagesFilterEmpty(), + pts=state.pts, + limit=BOT_CHANNEL_DIFF_LIMIT if chat_hashes.is_self_bot() else USER_CHANNEL_DIFF_LIMIT + ) + + # Similar to [`MessageBox::process_updates`], but using the result from getting difference. + def apply_channel_difference( + self, + request, + diff, + chat_hashes, + ): + entry = request.channel.channel_id + self.possible_gaps.remove(entry) + + if isinstance(diff, _tl.updates.ChannelDifferenceEmpty): + assert diff.final + self.end_get_diff(entry) + self.map[entry].pts = diff.pts + return [], [], [] + elif isinstance(diff, _tl.updates.ChannelDifferenceTooLong): + assert diff.final + self.map[entry].pts = diff.dialog.pts + chat_hashes.extend(diff.users, diff.chats) + self.reset_channel_deadline(channel_id, diff.timeout) + # This `diff` has the "latest messages and corresponding chats", but it would + # be strange to give the user only partial changes of these when they would + # expect all updates to be fetched. Instead, nothing is returned. + return [], [], [] + elif isinstance(diff, _tl.updates.ChannelDifference): + if diff.final: + self.end_get_diff(entry) + + self.map[entry].pts = pts + updates.extend(_tl.UpdateNewMessage( + message=m, + pts=NO_SEQ, + pts_count=NO_SEQ, + ) for m in diff.new_messages) + chat_hashes.extend(diff.users, diff.chats); + self.reset_channel_deadline(channel_id, timeout) + + (diff.updates, diff.users, diff.chats) + + # endregion Getting and applying channel difference.