Fix getting difference for channels and for the first time

This commit is contained in:
Lonami Exo 2019-05-04 19:29:47 +02:00
parent adc9b4c9f1
commit 716ab2f96d
2 changed files with 56 additions and 8 deletions

View File

@ -292,6 +292,9 @@ class UpdateMethods(UserMethods):
async def _dispatch_update(self: 'TelegramClient', update, channel_id, pts_date): async def _dispatch_update(self: 'TelegramClient', update, channel_id, pts_date):
if not self._entity_cache.ensure_cached(update): if not self._entity_cache.ensure_cached(update):
# We could add a lock to not fetch the same pts twice if we are
# already fetching it. However this does not happen in practice,
# which makes sense, because different updates have different pts.
await self._get_difference(update, channel_id, pts_date) await self._get_difference(update, channel_id, pts_date)
built = EventBuilderDict(self, update) built = EventBuilderDict(self, update)
@ -357,6 +360,16 @@ class UpdateMethods(UserMethods):
try: try:
where = await self.get_input_entity(channel_id) where = await self.get_input_entity(channel_id)
except ValueError: except ValueError:
# There's a high chance that this fails, since
# we are getting the difference to fetch entities.
return
if not pts_date:
# First-time, can't get difference. Get pts instead.
result = await self(functions.messages.GetPeerDialogsRequest([
utils.get_input_dialog(where)
]))
self._state_cache[channel_id] = result.dialogs[0].pts
return return
result = await self(functions.updates.GetChannelDifferenceRequest( result = await self(functions.updates.GetChannelDifferenceRequest(
@ -367,6 +380,12 @@ class UpdateMethods(UserMethods):
force=True force=True
)) ))
else: else:
if not pts_date[0]:
# First-time, can't get difference. Get pts instead.
result = await self(functions.updates.GetStateRequest())
self._state_cache[None] = result.pts, result.date
return
result = await self(functions.updates.GetDifferenceRequest( result = await self(functions.updates.GetDifferenceRequest(
pts=pts_date[0], pts=pts_date[0],
date=pts_date[1], date=pts_date[1],

View File

@ -1,8 +1,32 @@
import datetime import datetime
import inspect
from .tl import types from .tl import types
# Which updates have the following fields?
_has_channel_id = []
# TODO EntityCache does the same. Reuse?
def _fill():
for name in dir(types):
update = getattr(types, name)
if getattr(update, 'SUBCLASS_OF_ID', None) == 0x9f89304e:
cid = update.CONSTRUCTOR_ID
sig = inspect.signature(update.__init__)
for param in sig.parameters.values():
if param.name == 'channel_id' and param.annotation == int:
_has_channel_id.append(cid)
if not _has_channel_id:
raise RuntimeError('FIXME: Did the init signature or updates change?')
# We use a function to avoid cluttering the globals (with name/update/cid/doc)
_fill()
class StateCache: class StateCache:
""" """
In-memory update state cache, defaultdict-like behaviour. In-memory update state cache, defaultdict-like behaviour.
@ -15,11 +39,11 @@ class StateCache:
if initial: if initial:
self._pts_date = initial.pts, initial.date self._pts_date = initial.pts, initial.date
else: else:
self._pts_date = 1, datetime.datetime.now() self._pts_date = None, None
def reset(self): def reset(self):
self.__dict__.clear() self.__dict__.clear()
self._pts_date = (1, 1) self._pts_date = None, None
# TODO Call this when receiving responses too...? # TODO Call this when receiving responses too...?
def update( def update(
@ -90,11 +114,8 @@ class StateCache:
def get_channel_id( def get_channel_id(
self, self,
update, update,
has_channel_id=frozenset(x.CONSTRUCTOR_ID for x in ( has_channel_id=frozenset(_has_channel_id),
types.UpdateChannelTooLong, # Hardcoded because only some with message are for channels
types.UpdateDeleteChannelMessages,
types.UpdateChannelWebPage
)),
has_message=frozenset(x.CONSTRUCTOR_ID for x in ( has_message=frozenset(x.CONSTRUCTOR_ID for x in (
types.UpdateNewChannelMessage, types.UpdateNewChannelMessage,
types.UpdateEditChannelMessage types.UpdateEditChannelMessage
@ -122,8 +143,16 @@ class StateCache:
If `item` is ``None``, returns the default ``(pts, date)``. If `item` is ``None``, returns the default ``(pts, date)``.
If it's an **unmarked** channel ID, returns its ``pts``. If it's an **unmarked** channel ID, returns its ``pts``.
If no information is known, ``pts`` will be ``None``.
""" """
if item is None: if item is None:
return self._pts_date return self._pts_date
else: else:
return self.__dict__.get(item, 1) return self.__dict__.get(item)
def __setitem__(self, where, value):
if where is None:
self._pts_date = value
else:
self.__dict__[where] = value