diff --git a/telethon/_client/telegrambaseclient.py b/telethon/_client/telegrambaseclient.py index b6872dc8..5c32d086 100644 --- a/telethon/_client/telegrambaseclient.py +++ b/telethon/_client/telegrambaseclient.py @@ -91,6 +91,7 @@ def init( request_retries: int = 4, flood_sleep_threshold: int = 60, # Update handling. + catch_up: bool = False, receive_updates: bool = True, max_queued_updates: int = 100, ): @@ -142,6 +143,7 @@ def init( self._parse_mode = markdown # Update handling. + self._catch_up = catch_up self._no_updates = not receive_updates self._updates_queue = asyncio.Queue(maxsize=max_queued_updates) self._updates_handle = None @@ -232,6 +234,8 @@ async def connect(self: 'TelegramClient') -> None: ) else: try_fetch_user = self._session_state.user_id == 0 + if self._catch_up: + self._message_box.load(self._session_state, await self._session.get_all_channel_states()) dc = all_dcs.get(self._session_state.dc_id) if dc is None: @@ -358,6 +362,15 @@ async def _disconnect(self: 'TelegramClient'): except asyncio.CancelledError: pass + await self._session.insert_entities(self._entity_cache.get_all_entities()) + + session_state, channel_states = self._message_box.session_state() + for channel_id, pts in channel_states.items(): + await self._session.insert_channel_state(channel_id, pts) + + await self._replace_session_state(**session_state) + + async def _switch_dc(self: 'TelegramClient', new_dc): """ Permanently switches the current connection to the new data center. diff --git a/telethon/_sessions/sqlite.py b/telethon/_sessions/sqlite.py index 2ea419be..b41975fb 100644 --- a/telethon/_sessions/sqlite.py +++ b/telethon/_sessions/sqlite.py @@ -245,7 +245,7 @@ class SQLiteSession(Session): try: c.executemany( 'insert or replace into entity values (?,?,?)', - [(e.id, e.access_hash, e.ty) for e in entities] + [(e.id, e.access_hash, e.ty.value) for e in entities] ) finally: c.close() diff --git a/telethon/_updates/entitycache.py b/telethon/_updates/entitycache.py index 176d2013..ce89eb4f 100644 --- a/telethon/_updates/entitycache.py +++ b/telethon/_updates/entitycache.py @@ -95,3 +95,6 @@ class EntityCache: for c in chats if getattr(c, 'access_hash', None) and not getattr(c, 'min', None) ) + + def get_all_entities(self): + return [Entity(ty, id, hash) for id, (hash, ty) in self.hash_map.items()] diff --git a/telethon/_updates/messagebox.py b/telethon/_updates/messagebox.py index 235247f3..232ef446 100644 --- a/telethon/_updates/messagebox.py +++ b/telethon/_updates/messagebox.py @@ -141,46 +141,38 @@ class MessageBox: # region Creation, querying, and setting base state. - @classmethod - def load(cls, session_state, channel_states): + def load(self, session_state, channel_states): """ Create a [`MessageBox`] from a previously known update state. """ deadline = next_updates_deadline() - return cls( - map={ - ENTRY_ACCOUNT: State(pts=session_state.pts, deadline=deadline), - ENTRY_SECRET: State(pts=session_state.qts, deadline=deadline), - **{s.channel_id: s.pts for s in channel_states} - }, - date=session_state.date, - seq=session_state.seq, - next_deadline=ENTRY_ACCOUNT, - ) + self.map = { + ENTRY_ACCOUNT: State(pts=session_state.pts, deadline=deadline), + ENTRY_SECRET: State(pts=session_state.qts, deadline=deadline), + **{s.channel_id: State(pts=s.pts, deadline=deadline) for s in channel_states} + } + self.date = session_state.date + self.seq = session_state.seq + self.next_deadline = ENTRY_ACCOUNT - @classmethod def session_state(self): """ - Return the current state in a format that sessions understand. + Return the current state. This should be used for persisting the state. """ - return SessionState( - user_id=0, - dc_id=0, - bot=False, - pts=self.map.get(ENTRY_ACCOUNT, 0), - qts=self.map.get(ENTRY_SECRET, 0), + return dict( + pts=self.map[ENTRY_ACCOUNT].pts if ENTRY_ACCOUNT in self.map else 0, + qts=self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else 0, date=self.date, seq=self.seq, - takeout_id=None, - ), [ChannelState(channel_id=id, pts=pts) for id, pts in self.map.items() if isinstance(id, int)] + ), {id: state.pts for id, state in self.map.items() if isinstance(id, int)} def is_empty(self) -> bool: """ Return true if the message box is empty and has no state yet. """ - return self.map.get(ENTRY_ACCOUNT, NO_SEQ) == NO_SEQ + return ENTRY_ACCOUNT not in self.map or self.map[ENTRY_ACCOUNT] == NO_SEQ def check_deadlines(self): """ @@ -200,7 +192,7 @@ class MessageBox: if self.possible_gaps: deadline = min(deadline, *(gap.deadline for gap in self.possible_gaps.values())) elif self.next_deadline in self.map: - deadline = min(deadline, self.map[self.next_deadline]) + deadline = min(deadline, self.map[self.next_deadline].deadline) if now > deadline: # Check all expired entries and add them to the list that needs getting difference.