From 34bb2b8fc38216eaaa0d6fdbbb438b7d38ecc37c Mon Sep 17 00:00:00 2001 From: vanutp Date: Thu, 27 May 2021 11:31:55 +0300 Subject: [PATCH] Fix catch_up updates.py: - Fix catch_up getting stuck in an infinite loop by saving qts - Save new state to session after catch_up - Save initial state immediately if catch_up was run for the first time - If DifferenceTooLong was received, continue fetching with new pts instead of exiting - Catch up channels - Implement pts limit (pts_total_limit in chats and limit in channels) - Check if update was already processed before processing telegrambaseclient.py, sessions: - Get saved state from channels on startup and save channels state statecache.py: - Create separate variable _store in StateCache, because __dict__ also stores _logger variable - Move has_pts, has_qts, has_date, has_channel_pts outside of function arguments Fixes #3041 --- telethon/client/telegrambaseclient.py | 16 ++- telethon/client/updates.py | 109 ++++++++++++++------ telethon/sessions/abstract.py | 7 ++ telethon/sessions/memory.py | 3 + telethon/sessions/sqlite.py | 8 ++ telethon/statecache.py | 141 ++++++++++++++++---------- 6 files changed, 198 insertions(+), 86 deletions(-) diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index bea10cbe..ef9bc866 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -6,6 +6,7 @@ import logging import platform import time import typing +from datetime import datetime from .. 
import version, helpers, __name__ as __base_name__ from ..crypto import rsa @@ -401,9 +402,10 @@ class TelegramBaseClient(abc.ABC): self._authorized = None # None = unknown, False = no, True = yes # Update state (for catching up after a disconnection) - # TODO Get state from channels too self._state_cache = StateCache( self.session.get_update_state(0), self._log) + for k, v in self.session.get_channel_pts().items(): + self._state_cache[k] = v # Some further state for subclasses self._event_builders = [] @@ -630,15 +632,23 @@ class TelegramBaseClient(abc.ABC): await asyncio.wait(self._updates_queue) self._updates_queue.clear() - pts, date = self._state_cache[None] + pts, qts, date = self._state_cache[None] if pts and date: self.session.set_update_state(0, types.updates.State( pts=pts, - qts=0, + qts=qts, date=date, seq=0, unread_count=0 )) + for channel_id, pts in self._state_cache.get_channel_pts().items(): + self.session.set_update_state(channel_id, types.updates.State( + pts=pts, + qts=0, + date=datetime.fromtimestamp(0), + seq=0, + unread_count=0 + )) self.session.close() diff --git a/telethon/client/updates.py b/telethon/client/updates.py index a9d6344e..031882bf 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -7,10 +7,12 @@ import time import traceback import typing import logging +from datetime import datetime from .. import events, utils, errors from ..events.common import EventBuilder, EventCommon from ..tl import types, functions +from ..tl.types import UpdateChannelTooLong if typing.TYPE_CHECKING: from .telegramclient import TelegramClient @@ -211,7 +213,45 @@ class UpdateMethods: """ return [(callback, event) for event, callback in self._event_builders] - async def catch_up(self: 'TelegramClient'): + async def _catch_up_channel(self: 'TelegramClient', channel_id: int, pts: int, limit: int = None): + if self._state_cache[channel_id]: + pts = self._state_cache[channel_id] + + if not pts: + # First-time, can't get difference. 
Get pts instead. + result = await self(functions.channels.GetFullChannelRequest(channel_id)) + pts = self._state_cache[channel_id] = result.full_chat.pts + self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0)) + return + try: + while True: + d = await self(functions.updates.GetChannelDifferenceRequest( + channel=channel_id, + filter=types.ChannelMessagesFilterEmpty(), + pts=pts, + limit=limit or 100000 # If the limit isn't set, fetch all updates that we can + )) + if isinstance(d, types.updates.ChannelDifference): + pts = d.pts + self._handle_update(types.Updates( + users=d.users, + chats=d.chats, + date=None, + seq=0, + updates=d.other_updates + [ + types.UpdateNewChannelMessage(m, 0, 0) + for m in d.new_messages + ] + )) + elif isinstance(d, (types.updates.ChannelDifferenceTooLong, + types.updates.ChannelDifferenceEmpty)): + # If there is too much updates (ChannelDifferenceTooLong), + # there is no way to get them without raising limit or GetHistoryRequest, so just break + break + finally: + self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0)) + + async def catch_up(self: 'TelegramClient', pts_total_limit=None, limit=None): """ "Catches up" on the missed updates while the client was offline. 
You should call this method after registering the event handlers @@ -224,15 +264,22 @@ class UpdateMethods: await client.catch_up() """ - pts, date = self._state_cache[None] - if not pts: - return + pts, qts, date = self._state_cache[None] self.session.catching_up = True try: + if not pts: + # Ran first time, get initial pts, qts and date and return + result = await self(functions.updates.GetStateRequest()) + pts, qts, date = result.pts, result.qts, result.date + return + if not qts: + qts = 0 + + channels_to_fetch = [] while True: d = await self(functions.updates.GetDifferenceRequest( - pts, date, 0 + pts, date, qts, pts_total_limit )) if isinstance(d, (types.updates.DifferenceSlice, types.updates.Difference)): @@ -241,43 +288,39 @@ class UpdateMethods: else: state = d.intermediate_state - pts, date = state.pts, state.date + updates = [] + for update in d.other_updates: + if isinstance(update, UpdateChannelTooLong): + channels_to_fetch.append((update.channel_id, update.pts)) + else: + updates.append(update) + + pts, qts, date = state.pts, state.qts, state.date self._handle_update(types.Updates( users=d.users, chats=d.chats, date=state.date, seq=state.seq, - updates=d.other_updates + [ + updates=updates + [ types.UpdateNewMessage(m, 0, 0) for m in d.new_messages ] )) - - # TODO Implement upper limit (max_pts) - # We don't want to fetch updates we already know about. - # - # We may still get duplicates because the Difference - # contains a lot of updates and presumably only has - # the state for the last one, but at least we don't - # unnecessarily fetch too many. - # - # updates.getDifference's pts_total_limit seems to mean - # "how many pts is the request allowed to return", and - # if there is more than that, it returns "too long" (so - # there would be duplicate updates since we know about - # some). This can be used to detect collisions (i.e. - # it would return an update we have already seen). 
- else: - if isinstance(d, types.updates.DifferenceEmpty): - date = d.date - elif isinstance(d, types.updates.DifferenceTooLong): - pts = d.pts + elif isinstance(d, types.updates.DifferenceTooLong): + pts = d.pts + # If the limit isn't set, fetch all updates that we can + if pts_total_limit is not None: + break + elif isinstance(d, types.updates.DifferenceEmpty): + date = d.date break + for channel_id, channel_pts in channels_to_fetch: + await self._catch_up_channel(channel_id, channel_pts, limit) except (ConnectionError, asyncio.CancelledError): pass finally: - # TODO Save new pts to session - self._state_cache._pts_date = (pts, date) + self._state_cache[None] = (pts, qts, date) + self.session.set_update_state(0, types.updates.State(pts, qts, date, seq=0, unread_count=0)) self.session.catching_up = False # endregion @@ -288,6 +331,9 @@ class UpdateMethods: # the order that the updates arrive in to update the pts and date to # be always-increasing. There is also no need to make this async. def _handle_update(self: 'TelegramClient', update): + if self._state_cache.update_already_processed(update): + return + self.session.process_entities(update) self._entity_cache.add(update) @@ -304,6 +350,9 @@ class UpdateMethods: self._state_cache.update(update) def _process_update(self: 'TelegramClient', update, others, entities=None): + if self._state_cache.update_already_processed(update): + return + update._entities = entities or {} # This part is somewhat hot so we don't bother patching @@ -553,7 +602,7 @@ class UpdateMethods: if not pts_date[0]: # First-time, can't get difference. Get pts instead. 
result = await self(functions.updates.GetStateRequest()) - self._state_cache[None] = result.pts, result.date + self._state_cache[None] = result.pts, result.qts, result.date return result = await self(functions.updates.GetDifferenceRequest( diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py index 5fda1c18..6cb11e66 100644 --- a/telethon/sessions/abstract.py +++ b/telethon/sessions/abstract.py @@ -88,6 +88,13 @@ class Session(ABC): """ raise NotImplementedError + @abstractmethod + def get_channel_pts(self): + """ + Returns the ``Dict[int, int]`` with pts for all saved channels + """ + raise NotImplementedError + @abstractmethod def set_update_state(self, entity_id, state): """ diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py index 881622a6..5afeeafd 100644 --- a/telethon/sessions/memory.py +++ b/telethon/sessions/memory.py @@ -74,6 +74,9 @@ class MemorySession(Session): def get_update_state(self, entity_id): return self._update_states.get(entity_id, None) + def get_channel_pts(self): + return {x[0]: x[1] for x in self._update_states.items() if x[0] != 0} + def set_update_state(self, entity_id, state): self._update_states[entity_id] = state diff --git a/telethon/sessions/sqlite.py b/telethon/sessions/sqlite.py index 82b2129d..ab34dbf3 100644 --- a/telethon/sessions/sqlite.py +++ b/telethon/sessions/sqlite.py @@ -210,6 +210,14 @@ class SQLiteSession(MemorySession): date, tz=datetime.timezone.utc) return types.updates.State(pts, qts, date, seq, unread_count=0) + def get_channel_pts(self): + c = self._cursor() + try: + rows = c.execute('select id, pts from update_state').fetchall() + finally: + c.close() + return {x[0]: x[1] for x in rows if x[0] != 0} + def set_update_state(self, entity_id, state): self._execute('insert or replace into update_state values (?,?,?,?,?)', entity_id, state.pts, state.qts, diff --git a/telethon/statecache.py b/telethon/statecache.py index 0e02bbd4..feb63460 100644 --- a/telethon/statecache.py +++ 
b/telethon/statecache.py @@ -2,7 +2,6 @@ import inspect from .tl import types - # Which updates have the following fields? _has_channel_id = [] @@ -26,23 +25,70 @@ def _fill(): _fill() +has_pts = frozenset(x.CONSTRUCTOR_ID for x in ( + types.UpdateNewMessage, + types.UpdateDeleteMessages, + types.UpdateReadHistoryInbox, + types.UpdateReadHistoryOutbox, + types.UpdateWebPage, + types.UpdateReadMessagesContents, + types.UpdateEditMessage, + types.updates.State, + types.updates.DifferenceTooLong, + types.UpdateShortMessage, + types.UpdateShortChatMessage, + types.UpdateShortSentMessage +)) +has_qts = frozenset(x.CONSTRUCTOR_ID for x in ( + types.UpdateBotStopped, + types.UpdateNewEncryptedMessage, + types.updates.State +)) +has_date = frozenset(x.CONSTRUCTOR_ID for x in ( + types.UpdateUserPhoto, + types.UpdateEncryption, + types.UpdateEncryptedMessagesRead, + types.UpdateChatParticipantAdd, + types.updates.DifferenceEmpty, + types.UpdateShortMessage, + types.UpdateShortChatMessage, + types.UpdateShort, + types.UpdatesCombined, + types.Updates, + types.UpdateShortSentMessage, +)) +has_channel_pts = frozenset(x.CONSTRUCTOR_ID for x in ( + types.UpdateChannelTooLong, + types.UpdateNewChannelMessage, + types.UpdateDeleteChannelMessages, + types.UpdateEditChannelMessage, + types.UpdateChannelWebPage, + types.updates.ChannelDifferenceEmpty, + types.updates.ChannelDifferenceTooLong, + types.updates.ChannelDifference +)) + + class StateCache: """ In-memory update state cache, defaultdict-like behaviour. """ + _store: dict + def __init__(self, initial, loggers): # We only care about the pts and the date. By using a tuple which # is lightweight and immutable we can easily copy them around to # each update in case they need to fetch missing entities. 
self._logger = loggers[__name__] + self._store = {} if initial: - self._pts_date = initial.pts, initial.date + self._pts_date = initial.pts, initial.qts, initial.date else: - self._pts_date = None, None + self._pts_date = None, None, None def reset(self): - self.__dict__.clear() - self._pts_date = None, None + self._store.clear() + self._pts_date = None, None, None # TODO Call this when receiving responses too...? def update( @@ -50,43 +96,6 @@ class StateCache: update, *, channel_id=None, - has_pts=frozenset(x.CONSTRUCTOR_ID for x in ( - types.UpdateNewMessage, - types.UpdateDeleteMessages, - types.UpdateReadHistoryInbox, - types.UpdateReadHistoryOutbox, - types.UpdateWebPage, - types.UpdateReadMessagesContents, - types.UpdateEditMessage, - types.updates.State, - types.updates.DifferenceTooLong, - types.UpdateShortMessage, - types.UpdateShortChatMessage, - types.UpdateShortSentMessage - )), - has_date=frozenset(x.CONSTRUCTOR_ID for x in ( - types.UpdateUserPhoto, - types.UpdateEncryption, - types.UpdateEncryptedMessagesRead, - types.UpdateChatParticipantAdd, - types.updates.DifferenceEmpty, - types.UpdateShortMessage, - types.UpdateShortChatMessage, - types.UpdateShort, - types.UpdatesCombined, - types.Updates, - types.UpdateShortSentMessage, - )), - has_channel_pts=frozenset(x.CONSTRUCTOR_ID for x in ( - types.UpdateChannelTooLong, - types.UpdateNewChannelMessage, - types.UpdateDeleteChannelMessages, - types.UpdateEditChannelMessage, - types.UpdateChannelWebPage, - types.updates.ChannelDifferenceEmpty, - types.updates.ChannelDifferenceTooLong, - types.updates.ChannelDifference - )), check_only=False ): """ @@ -96,13 +105,20 @@ class StateCache: if check_only: return cid in has_pts or cid in has_date or cid in has_channel_pts + new_pts_date = tuple() if cid in has_pts: - if cid in has_date: - self._pts_date = update.pts, update.date - else: - self._pts_date = update.pts, self._pts_date[1] - elif cid in has_date: - self._pts_date = self._pts_date[0], update.date + 
new_pts_date += update.pts, + else: + new_pts_date += self._pts_date[0], + if cid in has_qts: + new_pts_date += update.qts, + else: + new_pts_date += self._pts_date[1], + if cid in has_date: + new_pts_date += update.date, + else: + new_pts_date += self._pts_date[2], + self._pts_date = new_pts_date if cid in has_channel_pts: if channel_id is None: @@ -112,7 +128,23 @@ class StateCache: self._logger.info( 'Failed to retrieve channel_id from %s', update) else: - self.__dict__[channel_id] = update.pts + self._store[channel_id] = update.pts + + def update_already_processed(self, update): + cid = update.CONSTRUCTOR_ID + # If pts == 0, the update is from catch_up + if cid in has_pts and \ + update.pts != 0 and \ + update.pts >= self._pts_date[0]: + return True + if cid in has_qts and update.qts >= self._pts_date[1]: + return True + if cid in has_channel_pts: + channel_id = self.get_channel_id(update) + if update.pts != 0 and \ + self._store.get(channel_id, 0) >= update.pts: + return True + return False def get_channel_id( self, @@ -120,8 +152,8 @@ class StateCache: has_channel_id=frozenset(_has_channel_id), # Hardcoded because only some with message are for channels has_message=frozenset(x.CONSTRUCTOR_ID for x in ( - types.UpdateNewChannelMessage, - types.UpdateEditChannelMessage + types.UpdateNewChannelMessage, + types.UpdateEditChannelMessage )) ): """ @@ -155,10 +187,13 @@ class StateCache: if item is None: return self._pts_date else: - return self.__dict__.get(item) + return self._store.get(item) def __setitem__(self, where, value): if where is None: self._pts_date = value else: - self.__dict__[where] = value + self._store[where] = value + + def get_channel_pts(self): + return self._store