From db09a92bc5f46daee6c93785ffd71c6693b7bdbf Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Fri, 13 May 2022 13:17:16 +0200 Subject: [PATCH] Make use of the new MessageBox --- telethon/client/telegrambaseclient.py | 62 +++--- telethon/client/updates.py | 262 +++++++------------------- telethon/network/mtprotosender.py | 7 +- 3 files changed, 101 insertions(+), 230 deletions(-) diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 1305ab9b..c0375eab 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -15,6 +15,7 @@ from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy from ..sessions import Session, SQLiteSession, MemorySession from ..tl import functions, types from ..tl.alltlobjects import LAYER +from .._updates import MessageBox, EntityCache as MbEntityCache DEFAULT_DC_ID = 2 DEFAULT_IPV4_IP = '149.154.167.51' @@ -376,18 +377,6 @@ class TelegramBaseClient(abc.ABC): proxy=init_proxy ) - self._sender = MTProtoSender( - self.session.auth_key, - loggers=self._log, - retries=self._connection_retries, - delay=self._retry_delay, - auto_reconnect=self._auto_reconnect, - connect_timeout=self._timeout, - auth_key_callback=self._auth_key_callback, - update_callback=self._handle_update, - auto_reconnect_callback=self._handle_auto_reconnect - ) - # Remember flood-waited requests to avoid making them again self._flood_waited_requests = {} @@ -396,18 +385,14 @@ class TelegramBaseClient(abc.ABC): self._borrow_sender_lock = asyncio.Lock() self._updates_handle = None + self._keepalive_handle = None self._last_request = time.time() self._channel_pts = {} self._no_updates = not receive_updates - if sequential_updates: - self._updates_queue = asyncio.Queue() - self._dispatching_updates_queue = asyncio.Event() - else: - # Use a set of pending instead of a queue so we can properly - # terminate all pending updates on disconnect. - self._updates_queue = set() - self._dispatching_updates_queue = None + # Used for non-sequential updates, in order to terminate all pending tasks on disconnect. + self._sequential_updates = sequential_updates + self._event_handler_tasks = set() self._authorized = None # None = unknown, False = no, True = yes @@ -442,6 +427,26 @@ class TelegramBaseClient(abc.ABC): # A place to store if channels are a megagroup or not (see `edit_admin`) self._megagroup_cache = {} + # This is backported from v2 in a very ad-hoc way just to get proper update handling + self._catch_up = True + self._updates_queue = asyncio.Queue() + self._message_box = MessageBox() + # This entity cache is tailored for the messagebox and is not used for absolutely everything like _entity_cache + self._mb_entity_cache = MbEntityCache() # required for proper update handling (to know when to getDifference) + + self._sender = MTProtoSender( + self.session.auth_key, + loggers=self._log, + retries=self._connection_retries, + delay=self._retry_delay, + auto_reconnect=self._auto_reconnect, + connect_timeout=self._timeout, + auth_key_callback=self._auth_key_callback, + updates_queue=self._updates_queue, + auto_reconnect_callback=self._handle_auto_reconnect + ) + + # endregion # region Properties @@ -537,6 +542,7 @@ class TelegramBaseClient(abc.ABC): )) self._updates_handle = self.loop.create_task(self._update_loop()) + self._keepalive_handle = self.loop.create_task(self._keepalive_loop()) def is_connected(self: 'TelegramClient') -> bool: """ @@ -629,13 +635,12 @@ class TelegramBaseClient(abc.ABC): # trio's nurseries would handle this for us, but this is asyncio. # All tasks spawned in the background should properly be terminated. - if self._dispatching_updates_queue is None and self._updates_queue: - for task in self._updates_queue: + if self._event_handler_tasks: + for task in self._event_handler_tasks: task.cancel() - await asyncio.wait(self._updates_queue) - self._updates_queue.clear() - + await asyncio.wait(self._event_handler_tasks) + self._event_handler_tasks.clear() await self.session.close() @@ -648,7 +653,8 @@ class TelegramBaseClient(abc.ABC): """ await self._sender.disconnect() await helpers._cancel(self._log[__name__], - updates_handle=self._updates_handle) + updates_handle=self._updates_handle, + keepalive_handle=self._keepalive_handle) async def _switch_dc(self: 'TelegramClient', new_dc): """ @@ -845,10 +851,6 @@ class TelegramBaseClient(abc.ABC): """ raise NotImplementedError - @abc.abstractmethod - def _handle_update(self: 'TelegramClient', update): - raise NotImplementedError - @abc.abstractmethod def _update_loop(self: 'TelegramClient'): raise NotImplementedError diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 65cff607..c100bef3 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -7,10 +7,12 @@ import time import traceback import typing import logging +from collections import deque from .. import events, utils, errors from ..events.common import EventBuilder, EventCommon from ..tl import types, functions +from .._updates import GapError if typing.TYPE_CHECKING: from .telegramclient import TelegramClient @@ -237,106 +239,76 @@ class UpdateMethods: await client.catch_up() """ - pts, date = self._state_cache[None] - if not pts: - return - - self.session.catching_up = True - try: - while True: - d = await self(functions.updates.GetDifferenceRequest( - pts, date, 0 - )) - if isinstance(d, (types.updates.DifferenceSlice, - types.updates.Difference)): - if isinstance(d, types.updates.Difference): - state = d.state - else: - state = d.intermediate_state - - pts, date = state.pts, state.date - await self._handle_update(types.Updates( - users=d.users, - chats=d.chats, - date=state.date, - seq=state.seq, - updates=d.other_updates + [ - types.UpdateNewMessage(m, 0, 0) - for m in d.new_messages - ] - )) - - # TODO Implement upper limit (max_pts) - # We don't want to fetch updates we already know about. - # - # We may still get duplicates because the Difference - # contains a lot of updates and presumably only has - # the state for the last one, but at least we don't - # unnecessarily fetch too many. - # - # updates.getDifference's pts_total_limit seems to mean - # "how many pts is the request allowed to return", and - # if there is more than that, it returns "too long" (so - # there would be duplicate updates since we know about - # some). This can be used to detect collisions (i.e. - # it would return an update we have already seen). - else: - if isinstance(d, types.updates.DifferenceEmpty): - date = d.date - elif isinstance(d, types.updates.DifferenceTooLong): - pts = d.pts - break - except (ConnectionError, asyncio.CancelledError): - pass - finally: - # TODO Save new pts to session - self._state_cache._pts_date = (pts, date) - self.session.catching_up = False + await self._updates_queue.put(types.UpdatesTooLong()) # endregion # region Private methods - # It is important to not make _handle_update async because we rely on - # the order that the updates arrive in to update the pts and date to - # be always-increasing. There is also no need to make this async. - async def _handle_update(self: 'TelegramClient', update): + async def _update_loop(self: 'TelegramClient'): + try: + updates_to_dispatch = deque() + + while self.is_connected(): + if updates_to_dispatch: + if self._sequential_updates: + await self._dispatch_update(updates_to_dispatch.popleft()) + else: + while updates_to_dispatch: + task = self.loop.create_task(self._dispatch_update(updates_to_dispatch.popleft())) + self._event_handler_tasks.add(task) + task.add_done_callback(lambda _: self._event_handler_tasks.discard(task)) + + continue + + get_diff = self._message_box.get_difference() + if get_diff: + self._log[__name__].info('Getting difference for account updates') + diff = await self(get_diff) + updates, users, chats = self._message_box.apply_difference(diff, self._mb_entity_cache) + updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats)) + continue + + get_diff = self._message_box.get_channel_difference(self._mb_entity_cache) + if get_diff: + self._log[__name__].info('Getting difference for channel updates') + diff = await self(get_diff) + updates, users, chats = self._message_box.apply_channel_difference(get_diff, diff, self._mb_entity_cache) + updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats)) + continue + + deadline = self._message_box.check_deadlines() + try: + updates = await asyncio.wait_for( + self._updates_queue.get(), + deadline - asyncio.get_running_loop().time() + ) + except asyncio.TimeoutError: + self._log[__name__].info('Timeout waiting for updates expired') + continue + + processed = [] + try: + users, chats = self._message_box.process_updates(updates, self._mb_entity_cache, processed) + except GapError: + continue # get(_channel)_difference will start returning requests + + updates_to_dispatch.extend(await self._preprocess_updates(processed, users, chats)) + except Exception: + self._log[__name__].exception('Fatal error handling updates (this is a bug in Telethon, please report it)') + + async def _preprocess_updates(self, updates, users, chats): await self.session.process_entities(update) self._entity_cache.add(update) - if isinstance(update, (types.Updates, types.UpdatesCombined)): - entities = {utils.get_peer_id(x): x for x in - itertools.chain(update.users, update.chats)} - for u in update.updates: - self._process_update(u, update.updates, entities=entities) - elif isinstance(update, types.UpdateShort): - self._process_update(update.update, None) - else: - self._process_update(update, None) + self._mb_entity_cache.extend(users, chats) + entities = {utils.get_peer_id(x): x + for x in itertools.chain(users, chats)} + for u in updates: + u._entities = entities + return updates - self._state_cache.update(update) - - def _process_update(self: 'TelegramClient', update, others, entities=None): - update._entities = entities or {} - - # This part is somewhat hot so we don't bother patching - # update with channel ID/its state. Instead we just pass - # arguments which is faster. - channel_id = self._state_cache.get_channel_id(update) - args = (update, others, channel_id, self._state_cache[channel_id]) - if self._dispatching_updates_queue is None: - task = self.loop.create_task(self._dispatch_update(*args)) - self._updates_queue.add(task) - task.add_done_callback(lambda _: self._updates_queue.discard(task)) - else: - self._updates_queue.put_nowait(args) - if not self._dispatching_updates_queue.is_set(): - self._dispatching_updates_queue.set() - self.loop.create_task(self._dispatch_queue_updates()) - - self._state_cache.update(update) - - async def _update_loop(self: 'TelegramClient'): + async def _keepalive_loop(self: 'TelegramClient'): # Pings' ID don't really need to be secure, just "random" rnd = lambda: random.randrange(-2**63, 2**63) while self.is_connected(): @@ -374,50 +346,9 @@ class UpdateMethods: # it every minute instead. No-op if there's nothing new. await self.session.save() - # We need to send some content-related request at least hourly - # for Telegram to keep delivering updates, otherwise they will - # just stop even if we're connected. Do so every 30 minutes. - # - # TODO Call getDifference instead since it's more relevant - if time.time() - self._last_request > 30 * 60: - if not await self.is_user_authorized(): - # What can be the user doing for so - # long without being logged in...? - continue - - try: - await self(functions.updates.GetStateRequest()) - except (ConnectionError, asyncio.CancelledError): - return - - async def _dispatch_queue_updates(self: 'TelegramClient'): - while not self._updates_queue.empty(): - await self._dispatch_update(*self._updates_queue.get_nowait()) - - self._dispatching_updates_queue.clear() - - async def _dispatch_update(self: 'TelegramClient', update, others, channel_id, pts_date): - if not self._entity_cache.ensure_cached(update): - # We could add a lock to not fetch the same pts twice if we are - # already fetching it. However this does not happen in practice, - # which makes sense, because different updates have different pts. - if self._state_cache.update(update, check_only=True): - # If the update doesn't have pts, fetching won't do anything. - # For example, UpdateUserStatus or UpdateChatUserTyping. - try: - await self._get_difference(update, channel_id, pts_date) - except OSError: - pass # We were disconnected, that's okay - except errors.RPCError: - # There's a high chance the request fails because we lack - # the channel. Because these "happen sporadically" (#1428) - # we should be okay (no flood waits) even if more occur. - pass - except ValueError: - # There is a chance that GetFullChannelRequest and GetDifferenceRequest - # inside the _get_difference() function will end up with - # ValueError("Request was unsuccessful N time(s)") for whatever reasons. - pass + async def _dispatch_update(self: 'TelegramClient', update): + # TODO only used for AlbumHack, and MessageBox is not really designed for this + others = None if not self._self_input_peer: # Some updates require our own ID, so we must make sure @@ -523,67 +454,6 @@ class UpdateMethods: name = getattr(callback, '__name__', repr(callback)) self._log[__name__].exception('Unhandled exception on %s', name) - async def _get_difference(self: 'TelegramClient', update, channel_id, pts_date): - """ - Get the difference for this `channel_id` if any, then load entities. - - Calls :tl:`updates.getDifference`, which fills the entities cache - (always done by `__call__`) and lets us know about the full entities. - """ - # Fetch since the last known pts/date before this update arrived, - # in order to fetch this update at full, including its entities. - self._log[__name__].debug('Getting difference for entities ' - 'for %r', update.__class__) - if channel_id: - # There are reports where we somehow call get channel difference - # with `InputPeerEmpty`. Check our assumptions to better debug - # this when it happens. - assert isinstance(channel_id, int), 'channel_id was {}, not int in {}'.format(type(channel_id), update) - try: - # Wrap the ID inside a peer to ensure we get a channel back. - where = await self.get_input_entity(types.PeerChannel(channel_id)) - except ValueError: - # There's a high chance that this fails, since - # we are getting the difference to fetch entities. - return - - if not pts_date: - # First-time, can't get difference. Get pts instead. - result = await self(functions.channels.GetFullChannelRequest( - utils.get_input_channel(where) - )) - self._state_cache[channel_id] = result.full_chat.pts - return - - result = await self(functions.updates.GetChannelDifferenceRequest( - channel=where, - filter=types.ChannelMessagesFilterEmpty(), - pts=pts_date, # just pts - limit=100, - force=True - )) - else: - if not pts_date[0]: - # First-time, can't get difference. Get pts instead. - result = await self(functions.updates.GetStateRequest()) - self._state_cache[None] = result.pts, result.date - return - - result = await self(functions.updates.GetDifferenceRequest( - pts=pts_date[0], - date=pts_date[1], - qts=0 - )) - - if isinstance(result, (types.updates.Difference, - types.updates.DifferenceSlice, - types.updates.ChannelDifference, - types.updates.ChannelDifferenceTooLong)): - update._entities.update({ - utils.get_peer_id(x): x for x in - itertools.chain(result.users, result.chats) - }) - async def _handle_auto_reconnect(self: 'TelegramClient'): # TODO Catch-up # For now we make a high-level request to let Telegram diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index 80c0c528..a1a02909 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -44,7 +44,7 @@ class MTProtoSender: def __init__(self, auth_key, *, loggers, retries=5, delay=1, auto_reconnect=True, connect_timeout=None, auth_key_callback=None, - update_callback=None, auto_reconnect_callback=None): + updates_queue=None, auto_reconnect_callback=None): self._connection = None self._loggers = loggers self._log = loggers[__name__] @@ -53,7 +53,7 @@ class MTProtoSender: self._auto_reconnect = auto_reconnect self._connect_timeout = connect_timeout self._auth_key_callback = auth_key_callback - self._update_callback = update_callback + self._updates_queue = updates_queue self._auto_reconnect_callback = auto_reconnect_callback self._connect_lock = asyncio.Lock() self._ping = None @@ -645,8 +645,7 @@ class MTProtoSender: return self._log.debug('Handling update %s', message.obj.__class__.__name__) - if self._update_callback: - await self._update_callback(message.obj) + self._updates_queue.put_nowait(message.obj) async def _handle_pong(self, message): """