Persist session state and usage fixes

Catching up is now an option when creating the client.
This commit is contained in:
Lonami Exo 2022-01-23 19:46:37 +01:00
parent 015acf20c6
commit f547a00da3
4 changed files with 33 additions and 25 deletions

View File

@ -91,6 +91,7 @@ def init(
request_retries: int = 4,
flood_sleep_threshold: int = 60,
# Update handling.
catch_up: bool = False,
receive_updates: bool = True,
max_queued_updates: int = 100,
):
@ -142,6 +143,7 @@ def init(
self._parse_mode = markdown
# Update handling.
self._catch_up = catch_up
self._no_updates = not receive_updates
self._updates_queue = asyncio.Queue(maxsize=max_queued_updates)
self._updates_handle = None
@ -232,6 +234,8 @@ async def connect(self: 'TelegramClient') -> None:
)
else:
try_fetch_user = self._session_state.user_id == 0
if self._catch_up:
self._message_box.load(self._session_state, await self._session.get_all_channel_states())
dc = all_dcs.get(self._session_state.dc_id)
if dc is None:
@ -358,6 +362,15 @@ async def _disconnect(self: 'TelegramClient'):
except asyncio.CancelledError:
pass
await self._session.insert_entities(self._entity_cache.get_all_entities())
session_state, channel_states = self._message_box.session_state()
for channel_id, pts in channel_states.items():
await self._session.insert_channel_state(channel_id, pts)
await self._replace_session_state(**session_state)
async def _switch_dc(self: 'TelegramClient', new_dc):
"""
Permanently switches the current connection to the new data center.

View File

@ -245,7 +245,7 @@ class SQLiteSession(Session):
try:
c.executemany(
'insert or replace into entity values (?,?,?)',
[(e.id, e.access_hash, e.ty) for e in entities]
[(e.id, e.access_hash, e.ty.value) for e in entities]
)
finally:
c.close()

View File

@ -95,3 +95,6 @@ class EntityCache:
for c in chats
if getattr(c, 'access_hash', None) and not getattr(c, 'min', None)
)
def get_all_entities(self):
return [Entity(ty, id, hash) for id, (hash, ty) in self.hash_map.items()]

View File

@ -141,46 +141,38 @@ class MessageBox:
# region Creation, querying, and setting base state.
@classmethod
def load(cls, session_state, channel_states):
def load(self, session_state, channel_states):
"""
Create a [`MessageBox`] from a previously known update state.
"""
deadline = next_updates_deadline()
return cls(
map={
ENTRY_ACCOUNT: State(pts=session_state.pts, deadline=deadline),
ENTRY_SECRET: State(pts=session_state.qts, deadline=deadline),
**{s.channel_id: s.pts for s in channel_states}
},
date=session_state.date,
seq=session_state.seq,
next_deadline=ENTRY_ACCOUNT,
)
self.map = {
ENTRY_ACCOUNT: State(pts=session_state.pts, deadline=deadline),
ENTRY_SECRET: State(pts=session_state.qts, deadline=deadline),
**{s.channel_id: State(pts=s.pts, deadline=deadline) for s in channel_states}
}
self.date = session_state.date
self.seq = session_state.seq
self.next_deadline = ENTRY_ACCOUNT
@classmethod
def session_state(self):
"""
Return the current state in a format that sessions understand.
Return the current state.
This should be used for persisting the state.
"""
return SessionState(
user_id=0,
dc_id=0,
bot=False,
pts=self.map.get(ENTRY_ACCOUNT, 0),
qts=self.map.get(ENTRY_SECRET, 0),
return dict(
pts=self.map[ENTRY_ACCOUNT].pts if ENTRY_ACCOUNT in self.map else 0,
qts=self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else 0,
date=self.date,
seq=self.seq,
takeout_id=None,
), [ChannelState(channel_id=id, pts=pts) for id, pts in self.map.items() if isinstance(id, int)]
), {id: state.pts for id, state in self.map.items() if isinstance(id, int)}
def is_empty(self) -> bool:
"""
Return true if the message box is empty and has no state yet.
"""
return self.map.get(ENTRY_ACCOUNT, NO_SEQ) == NO_SEQ
return ENTRY_ACCOUNT not in self.map or self.map[ENTRY_ACCOUNT] == NO_SEQ
def check_deadlines(self):
"""
@ -200,7 +192,7 @@ class MessageBox:
if self.possible_gaps:
deadline = min(deadline, *(gap.deadline for gap in self.possible_gaps.values()))
elif self.next_deadline in self.map:
deadline = min(deadline, self.map[self.next_deadline])
deadline = min(deadline, self.map[self.next_deadline].deadline)
if now > deadline:
# Check all expired entries and add them to the list that needs getting difference.