Persist session state and usage fixes

Catching up is now an option when creating the client.
This commit is contained in:
Lonami Exo 2022-01-23 19:46:37 +01:00
parent 015acf20c6
commit f547a00da3
4 changed files with 33 additions and 25 deletions

View File

@ -91,6 +91,7 @@ def init(
request_retries: int = 4, request_retries: int = 4,
flood_sleep_threshold: int = 60, flood_sleep_threshold: int = 60,
# Update handling. # Update handling.
catch_up: bool = False,
receive_updates: bool = True, receive_updates: bool = True,
max_queued_updates: int = 100, max_queued_updates: int = 100,
): ):
@ -142,6 +143,7 @@ def init(
self._parse_mode = markdown self._parse_mode = markdown
# Update handling. # Update handling.
self._catch_up = catch_up
self._no_updates = not receive_updates self._no_updates = not receive_updates
self._updates_queue = asyncio.Queue(maxsize=max_queued_updates) self._updates_queue = asyncio.Queue(maxsize=max_queued_updates)
self._updates_handle = None self._updates_handle = None
@ -232,6 +234,8 @@ async def connect(self: 'TelegramClient') -> None:
) )
else: else:
try_fetch_user = self._session_state.user_id == 0 try_fetch_user = self._session_state.user_id == 0
if self._catch_up:
self._message_box.load(self._session_state, await self._session.get_all_channel_states())
dc = all_dcs.get(self._session_state.dc_id) dc = all_dcs.get(self._session_state.dc_id)
if dc is None: if dc is None:
@ -358,6 +362,15 @@ async def _disconnect(self: 'TelegramClient'):
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
await self._session.insert_entities(self._entity_cache.get_all_entities())
session_state, channel_states = self._message_box.session_state()
for channel_id, pts in channel_states.items():
await self._session.insert_channel_state(channel_id, pts)
await self._replace_session_state(**session_state)
async def _switch_dc(self: 'TelegramClient', new_dc): async def _switch_dc(self: 'TelegramClient', new_dc):
""" """
Permanently switches the current connection to the new data center. Permanently switches the current connection to the new data center.

View File

@ -245,7 +245,7 @@ class SQLiteSession(Session):
try: try:
c.executemany( c.executemany(
'insert or replace into entity values (?,?,?)', 'insert or replace into entity values (?,?,?)',
[(e.id, e.access_hash, e.ty) for e in entities] [(e.id, e.access_hash, e.ty.value) for e in entities]
) )
finally: finally:
c.close() c.close()

View File

@ -95,3 +95,6 @@ class EntityCache:
for c in chats for c in chats
if getattr(c, 'access_hash', None) and not getattr(c, 'min', None) if getattr(c, 'access_hash', None) and not getattr(c, 'min', None)
) )
def get_all_entities(self):
return [Entity(ty, id, hash) for id, (hash, ty) in self.hash_map.items()]

View File

@ -141,46 +141,38 @@ class MessageBox:
# region Creation, querying, and setting base state. # region Creation, querying, and setting base state.
@classmethod def load(self, session_state, channel_states):
def load(cls, session_state, channel_states):
""" """
Create a [`MessageBox`] from a previously known update state. Create a [`MessageBox`] from a previously known update state.
""" """
deadline = next_updates_deadline() deadline = next_updates_deadline()
return cls( self.map = {
map={
ENTRY_ACCOUNT: State(pts=session_state.pts, deadline=deadline), ENTRY_ACCOUNT: State(pts=session_state.pts, deadline=deadline),
ENTRY_SECRET: State(pts=session_state.qts, deadline=deadline), ENTRY_SECRET: State(pts=session_state.qts, deadline=deadline),
**{s.channel_id: s.pts for s in channel_states} **{s.channel_id: State(pts=s.pts, deadline=deadline) for s in channel_states}
}, }
date=session_state.date, self.date = session_state.date
seq=session_state.seq, self.seq = session_state.seq
next_deadline=ENTRY_ACCOUNT, self.next_deadline = ENTRY_ACCOUNT
)
@classmethod
def session_state(self): def session_state(self):
""" """
Return the current state in a format that sessions understand. Return the current state.
This should be used for persisting the state. This should be used for persisting the state.
""" """
return SessionState( return dict(
user_id=0, pts=self.map[ENTRY_ACCOUNT].pts if ENTRY_ACCOUNT in self.map else 0,
dc_id=0, qts=self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else 0,
bot=False,
pts=self.map.get(ENTRY_ACCOUNT, 0),
qts=self.map.get(ENTRY_SECRET, 0),
date=self.date, date=self.date,
seq=self.seq, seq=self.seq,
takeout_id=None, ), {id: state.pts for id, state in self.map.items() if isinstance(id, int)}
), [ChannelState(channel_id=id, pts=pts) for id, pts in self.map.items() if isinstance(id, int)]
def is_empty(self) -> bool: def is_empty(self) -> bool:
""" """
Return true if the message box is empty and has no state yet. Return true if the message box is empty and has no state yet.
""" """
return self.map.get(ENTRY_ACCOUNT, NO_SEQ) == NO_SEQ return ENTRY_ACCOUNT not in self.map or self.map[ENTRY_ACCOUNT] == NO_SEQ
def check_deadlines(self): def check_deadlines(self):
""" """
@ -200,7 +192,7 @@ class MessageBox:
if self.possible_gaps: if self.possible_gaps:
deadline = min(deadline, *(gap.deadline for gap in self.possible_gaps.values())) deadline = min(deadline, *(gap.deadline for gap in self.possible_gaps.values()))
elif self.next_deadline in self.map: elif self.next_deadline in self.map:
deadline = min(deadline, self.map[self.next_deadline]) deadline = min(deadline, self.map[self.next_deadline].deadline)
if now > deadline: if now > deadline:
# Check all expired entries and add them to the list that needs getting difference. # Check all expired entries and add them to the list that needs getting difference.