diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index ef9bc866..1d47d564 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -605,6 +605,27 @@ class TelegramBaseClient(abc.ABC): else: connection._proxy = proxy + def _update_state_for(self, channel_id: 'typing.Optional[int]'): + if channel_id is None: + pts, qts, date = self._state_cache[None] + if pts and date: + self.session.set_update_state(0, types.updates.State( + pts=pts, + qts=qts, + date=date, + seq=0, + unread_count=0 + )) + else: + pts = self._state_cache[channel_id] + self.session.set_update_state(channel_id, types.updates.State( + pts=pts, + qts=0, + date=datetime.fromtimestamp(0), + seq=0, + unread_count=0 + )) + async def _disconnect_coro(self: 'TelegramClient'): await self._disconnect() @@ -632,23 +653,9 @@ class TelegramBaseClient(abc.ABC): await asyncio.wait(self._updates_queue) self._updates_queue.clear() - pts, qts, date = self._state_cache[None] - if pts and date: - self.session.set_update_state(0, types.updates.State( - pts=pts, - qts=qts, - date=date, - seq=0, - unread_count=0 - )) - for channel_id, pts in self._state_cache.get_channel_pts().items(): - self.session.set_update_state(channel_id, types.updates.State( - pts=pts, - qts=0, - date=datetime.fromtimestamp(0), - seq=0, - unread_count=0 - )) + self._update_state_for(None) + for channel_id in self._state_cache.get_channel_pts(): + self._update_state_for(channel_id) self.session.close() diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 031882bf..af423b0a 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -220,8 +220,8 @@ class UpdateMethods: if not pts: # First-time, can't get difference. Get pts instead. result = await self(functions.channels.GetFullChannelRequest(channel_id)) - pts = self._state_cache[channel_id] = result.full_chat.pts - self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0)) + self._state_cache[channel_id] = result.full_chat.pts + self._update_state_for(channel_id) return try: while True: @@ -249,7 +249,8 @@ class UpdateMethods: # there is no way to get them without raising limit or GetHistoryRequest, so just break break finally: - self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0)) + self._state_cache[channel_id] = pts + self._update_state_for(channel_id) async def catch_up(self: 'TelegramClient', pts_total_limit=None, limit=None): """ @@ -320,7 +321,7 @@ class UpdateMethods: pass finally: self._state_cache[None] = (pts, qts, date) - self.session.set_update_state(0, types.updates.State(pts, qts, date, seq=0, unread_count=0)) + self._update_state_for(None) self.session.catching_up = False # endregion @@ -404,10 +405,10 @@ class UpdateMethods: except (ConnectionError, asyncio.CancelledError): return - # Entities and cached files are not saved when they are - # inserted because this is a rather expensive operation - # (default's sqlite3 takes ~0.1s to commit changes). Do - # it every minute instead. No-op if there's nothing new. + # Entities, cached files and update states are not saved + # when they are inserted because this is a rather expensive + # operation (default's sqlite3 takes ~0.1s to commit changes). + # Do it every minute instead. No-op if there's nothing new. self.session.save() # We need to send some content-related request at least hourly @@ -589,6 +590,7 @@ class UpdateMethods: utils.get_input_channel(where) )) self._state_cache[channel_id] = result.full_chat.pts + self._update_state_for(channel_id) return result = await self(functions.updates.GetChannelDifferenceRequest( @@ -603,6 +605,7 @@ class UpdateMethods: # First-time, can't get difference. Get pts instead. result = await self(functions.updates.GetStateRequest()) self._state_cache[None] = result.pts, result.qts, result.date + self._update_state_for(None) return result = await self(functions.updates.GetDifferenceRequest(