From 716ab2f96d1781eab981c07b7aff5f50dbc45830 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Sat, 4 May 2019 19:29:47 +0200 Subject: [PATCH] Fix getting difference for channels and for the first time --- telethon/client/updates.py | 19 ++++++++++++++++ telethon/statecache.py | 45 +++++++++++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 8 deletions(-) diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 126a754c..00f43e38 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -292,6 +292,9 @@ class UpdateMethods(UserMethods): async def _dispatch_update(self: 'TelegramClient', update, channel_id, pts_date): if not self._entity_cache.ensure_cached(update): + # We could add a lock to not fetch the same pts twice if we are + # already fetching it. However this does not happen in practice, + # which makes sense, because different updates have different pts. await self._get_difference(update, channel_id, pts_date) built = EventBuilderDict(self, update) @@ -357,6 +360,16 @@ class UpdateMethods(UserMethods): try: where = await self.get_input_entity(channel_id) except ValueError: + # There's a high chance that this fails, since + # we are getting the difference to fetch entities. + return + + if not pts_date: + # First-time, can't get difference. Get pts instead. + result = await self(functions.messages.GetPeerDialogsRequest([ + utils.get_input_dialog(where) + ])) + self._state_cache[channel_id] = result.dialogs[0].pts return result = await self(functions.updates.GetChannelDifferenceRequest( @@ -367,6 +380,12 @@ class UpdateMethods(UserMethods): force=True )) else: + if not pts_date[0]: + # First-time, can't get difference. Get pts instead. + result = await self(functions.updates.GetStateRequest()) + self._state_cache[None] = result.pts, result.date + return + result = await self(functions.updates.GetDifferenceRequest( pts=pts_date[0], date=pts_date[1], diff --git a/telethon/statecache.py b/telethon/statecache.py index bea63edd..ea3214ba 100644 --- a/telethon/statecache.py +++ b/telethon/statecache.py @@ -1,8 +1,32 @@ import datetime +import inspect from .tl import types +# Which updates have the following fields? +_has_channel_id = [] + + +# TODO EntityCache does the same. Reuse? +def _fill(): + for name in dir(types): + update = getattr(types, name) + if getattr(update, 'SUBCLASS_OF_ID', None) == 0x9f89304e: + cid = update.CONSTRUCTOR_ID + sig = inspect.signature(update.__init__) + for param in sig.parameters.values(): + if param.name == 'channel_id' and param.annotation == int: + _has_channel_id.append(cid) + + if not _has_channel_id: + raise RuntimeError('FIXME: Did the init signature or updates change?') + + +# We use a function to avoid cluttering the globals (with name/update/cid/doc) +_fill() + + class StateCache: """ In-memory update state cache, defaultdict-like behaviour. @@ -15,11 +39,11 @@ class StateCache: if initial: self._pts_date = initial.pts, initial.date else: - self._pts_date = 1, datetime.datetime.now() + self._pts_date = None, None def reset(self): self.__dict__.clear() - self._pts_date = (1, 1) + self._pts_date = None, None # TODO Call this when receiving responses too...? def update( @@ -90,11 +114,8 @@ class StateCache: def get_channel_id( self, update, - has_channel_id=frozenset(x.CONSTRUCTOR_ID for x in ( - types.UpdateChannelTooLong, - types.UpdateDeleteChannelMessages, - types.UpdateChannelWebPage - )), + has_channel_id=frozenset(_has_channel_id), + # Hardcoded because only some with message are for channels has_message=frozenset(x.CONSTRUCTOR_ID for x in ( types.UpdateNewChannelMessage, types.UpdateEditChannelMessage @@ -122,8 +143,16 @@ class StateCache: If `item` is ``None``, returns the default ``(pts, date)``. If it's an **unmarked** channel ID, returns its ``pts``. + + If no information is known, ``pts`` will be ``None``. """ if item is None: return self._pts_date else: - return self.__dict__.get(item, 1) + return self.__dict__.get(item) + + def __setitem__(self, where, value): + if where is None: + self._pts_date = value + else: + self.__dict__[where] = value