From 9965cda96826bef111788711d0d63b906fbd3eca Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 10 Apr 2019 21:09:15 +0400 Subject: [PATCH] Save pts and date in a tuple for immutability This way it is easy and cheap to copy the two required values to all incoming updates in case we need to getDifference since the previous pts/date to fetch entities. This is still a work in progress. --- telethon/client/auth.py | 13 ++-- telethon/client/telegrambaseclient.py | 28 +++---- telethon/client/updates.py | 101 ++++++++++++-------------- telethon/network/mtprotosender.py | 2 +- 4 files changed, 68 insertions(+), 76 deletions(-) diff --git a/telethon/client/auth.py b/telethon/client/auth.py index b1a932e7..a035833a 100644 --- a/telethon/client/auth.py +++ b/telethon/client/auth.py @@ -375,11 +375,10 @@ class AuthMethods(MessageParseMethods, UserMethods): self._self_input_peer = utils.get_input_peer(user, allow_self=False) self._authorized = True - # By setting state.pts = 1 after logging in, the user or bot can - # `catch_up` on all updates (and obtain necessary access hashes) - # if they desire. The date parameter is ignored when pts = 1. - self._old_state = types.updates.State( - 1, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0) + # `catch_up` will getDifference from pts = 1, date = 1 (ignored) + # to fetch all updates (and obtain necessary access hashes) if + # the ``pts is None``. + self._old_pts_date = (None, None) return user @@ -437,8 +436,8 @@ class AuthMethods(MessageParseMethods, UserMethods): self._bot = None self._self_input_peer = None self._authorized = False - self._old_state = None - self._new_state = None + self._old_pts_date = (None, None) + self._new_pts_date = (None, None) await self.disconnect() self.session.delete() diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index d58192ba..5e8004ae 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -306,15 +306,13 @@ class TelegramBaseClient(abc.ABC): self._authorized = None # None = unknown, False = no, True = yes # Update state (for catching up after a disconnection) - self._old_state = self.session.get_update_state(0) - self._new_state = None - - # If we catch up, while we don't get disconnected, - # the old state will be the same as the new one. # - # If we do get disconnected, then the old and new - # state may differ. - self._old_state_is_new = False + # We only care about the pts and the date. By using a tuple which + # is lightweight and immutable we can easily copy them around to + # each update in case they need to fetch missing entities. + state = self.session.get_update_state(0) + self._old_pts_date = state.pts, state.date + self._new_pts_date = (None, None) # Some further state for subclasses self._event_builders = [] @@ -397,11 +395,15 @@ class TelegramBaseClient(abc.ABC): async def _disconnect_coro(self): await self._disconnect() - # If we disconnect, the old state is the last one we are aware of - self._old_state_is_new = True - - if self._new_state: - self.session.set_update_state(0, self._new_state) + pts, date = self._new_pts_date + if pts: + self.session.set_update_state(0, types.updates.State( + pts=pts, + qts=0, + date=date or datetime.now(), + seq=0, + unread_count=0 + )) self.session.close() diff --git a/telethon/client/updates.py b/telethon/client/updates.py index d836585c..03a40074 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -135,23 +135,19 @@ class UpdateMethods(UserMethods): This can also be used to forcibly fetch new updates if there are any. """ - state = self._new_state if self._old_state_is_new else self._old_state - if not self._old_state_is_new and self._new_state: - max_pts = self._new_state.pts + # TODO Since which state should we catch up? + if all(self._new_pts_date): + pts, date = self._new_pts_date + elif all(self._old_pts_date): + pts, date = self._new_pts_date else: - max_pts = float('inf') - - # No known state -> catch up since the beginning (date is ignored). - # Note: pts = 0 is invalid (and so is no date/unix timestamp = 0). - if not state: - state = types.updates.State( - 1, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0) + return self.session.catching_up = True try: while True: d = await self(functions.updates.GetDifferenceRequest( - state.pts, state.date, state.qts + pts, date, 0 )) if isinstance(d, (types.updates.DifferenceSlice, types.updates.Difference)): @@ -160,7 +156,8 @@ class UpdateMethods(UserMethods): else: state = d.intermediate_state - await self._handle_update(types.Updates( + pts, date = state.pts, state.date + self._handle_update(types.Updates( users=d.users, chats=d.chats, date=state.date, @@ -171,6 +168,7 @@ class UpdateMethods(UserMethods): ] )) + # TODO Implement upper limit (max_pts) # We don't want to fetch updates we already know about. # # We may still get duplicates because the Difference @@ -184,29 +182,27 @@ class UpdateMethods(UserMethods): # there would be duplicate updates since we know about # some). This can be used to detect collisions (i.e. # it would return an update we have already seen). - if state.pts >= max_pts: - break else: if isinstance(d, types.updates.DifferenceEmpty): - state.date = d.date - state.seq = d.seq + date = d.date elif isinstance(d, types.updates.DifferenceTooLong): - state.pts = d.pts + pts = d.pts break except (ConnectionError, asyncio.CancelledError): pass finally: - self._old_state = None - self._new_state = state - self._old_state_is_new = True - self.session.set_update_state(0, state) + # TODO Save new pts to session + self._new_pts_date = (pts, date) self.session.catching_up = False # endregion # region Private methods - async def _handle_update(self, update): + # It is important to not make _handle_update async because we rely on + # the order that the updates arrive in to update the pts and date to + # be always-increasing. There is also no need to make this async. + def _handle_update(self, update): self.session.process_entities(update) self._entity_cache.add(update) @@ -214,40 +210,39 @@ class UpdateMethods(UserMethods): entities = {utils.get_peer_id(x): x for x in itertools.chain(update.users, update.chats)} for u in update.updates: - u._entities = entities - await self._handle_update(u) + self._process_update(u, entities) + + self._new_pts_date = (self._new_pts_date[0], update.date) elif isinstance(update, types.UpdateShort): - await self._handle_update(update.update) + self._process_update(update.update) + self._new_pts_date = (self._new_pts_date[0], update.date) else: - update._entities = getattr(update, '_entities', {}) - if self._updates_queue is None: - self._loop.create_task(self._dispatch_update(update)) - else: - self._updates_queue.put_nowait(update) - if not self._dispatching_updates_queue.is_set(): - self._dispatching_updates_queue.set() - self._loop.create_task(self._dispatch_queue_updates()) + self._process_update(update) - # TODO make use of need_diff - need_diff = False + # TODO Should this be done before or after? + self._update_pts_date(update) + + def _process_update(self, update, entities=None): + update._entities = entities or {} + if self._updates_queue is None: + self._loop.create_task(self._dispatch_update(update)) + else: + self._updates_queue.put_nowait(update) + if not self._dispatching_updates_queue.is_set(): + self._dispatching_updates_queue.set() + self._loop.create_task(self._dispatch_queue_updates()) + + self._update_pts_date(update) + + def _update_pts_date(self, update): + pts, date = self._new_pts_date if getattr(update, 'pts', None): - if not self._new_state: - self._new_state = types.updates.State( - update.pts, - 0, - getattr(update, 'date', datetime.datetime.now(tz=datetime.timezone.utc)), - getattr(update, 'seq', 0), - 0 - ) - else: - if self._new_state.pts and (update.pts - self._new_state.pts) > 1: - need_diff = True + pts = update.pts - self._new_state.pts = update.pts - if hasattr(update, 'date'): - self._new_state.date = update.date - if hasattr(update, 'seq'): - self._new_state.seq = update.seq + if getattr(update, 'date', None): + date = update.date + + self._new_pts_date = (pts, date) async def _update_loop(self): # Pings' ID don't really need to be secure, just "random" @@ -368,10 +363,6 @@ class UpdateMethods(UserMethods): # If a disconnection occurs, the old known state will be # the latest one we were aware of, so we can catch up since # the most recent state we were aware of. - # TODO Ideally we set _old_state = _new_state *on* disconnect, - # not *after* we managed to reconnect since perhaps an update - # arrives just before we can get started. - self._old_state_is_new = True await self.catch_up() self._log[__name__].info('Successfully fetched missed updates') diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index 7f968197..9672c5a3 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -547,7 +547,7 @@ class MTProtoSender: self._log.debug('Handling update {}' .format(message.obj.__class__.__name__)) if self._update_callback: - await self._update_callback(message.obj) + self._update_callback(message.obj) async def _handle_pong(self, message): """