mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-26 03:13:45 +03:00
Persist session state and usage fixes
Catching up is now an option when creating the client.
This commit is contained in:
parent
015acf20c6
commit
f547a00da3
|
@ -91,6 +91,7 @@ def init(
|
|||
request_retries: int = 4,
|
||||
flood_sleep_threshold: int = 60,
|
||||
# Update handling.
|
||||
catch_up: bool = False,
|
||||
receive_updates: bool = True,
|
||||
max_queued_updates: int = 100,
|
||||
):
|
||||
|
@ -142,6 +143,7 @@ def init(
|
|||
self._parse_mode = markdown
|
||||
|
||||
# Update handling.
|
||||
self._catch_up = catch_up
|
||||
self._no_updates = not receive_updates
|
||||
self._updates_queue = asyncio.Queue(maxsize=max_queued_updates)
|
||||
self._updates_handle = None
|
||||
|
@ -232,6 +234,8 @@ async def connect(self: 'TelegramClient') -> None:
|
|||
)
|
||||
else:
|
||||
try_fetch_user = self._session_state.user_id == 0
|
||||
if self._catch_up:
|
||||
self._message_box.load(self._session_state, await self._session.get_all_channel_states())
|
||||
|
||||
dc = all_dcs.get(self._session_state.dc_id)
|
||||
if dc is None:
|
||||
|
@ -358,6 +362,15 @@ async def _disconnect(self: 'TelegramClient'):
|
|||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self._session.insert_entities(self._entity_cache.get_all_entities())
|
||||
|
||||
session_state, channel_states = self._message_box.session_state()
|
||||
for channel_id, pts in channel_states.items():
|
||||
await self._session.insert_channel_state(channel_id, pts)
|
||||
|
||||
await self._replace_session_state(**session_state)
|
||||
|
||||
|
||||
async def _switch_dc(self: 'TelegramClient', new_dc):
|
||||
"""
|
||||
Permanently switches the current connection to the new data center.
|
||||
|
|
|
@ -245,7 +245,7 @@ class SQLiteSession(Session):
|
|||
try:
|
||||
c.executemany(
|
||||
'insert or replace into entity values (?,?,?)',
|
||||
[(e.id, e.access_hash, e.ty) for e in entities]
|
||||
[(e.id, e.access_hash, e.ty.value) for e in entities]
|
||||
)
|
||||
finally:
|
||||
c.close()
|
||||
|
|
|
@ -95,3 +95,6 @@ class EntityCache:
|
|||
for c in chats
|
||||
if getattr(c, 'access_hash', None) and not getattr(c, 'min', None)
|
||||
)
|
||||
|
||||
def get_all_entities(self):
|
||||
return [Entity(ty, id, hash) for id, (hash, ty) in self.hash_map.items()]
|
||||
|
|
|
@ -141,46 +141,38 @@ class MessageBox:
|
|||
|
||||
# region Creation, querying, and setting base state.
|
||||
|
||||
@classmethod
|
||||
def load(cls, session_state, channel_states):
|
||||
def load(self, session_state, channel_states):
|
||||
"""
|
||||
Create a [`MessageBox`] from a previously known update state.
|
||||
"""
|
||||
deadline = next_updates_deadline()
|
||||
return cls(
|
||||
map={
|
||||
ENTRY_ACCOUNT: State(pts=session_state.pts, deadline=deadline),
|
||||
ENTRY_SECRET: State(pts=session_state.qts, deadline=deadline),
|
||||
**{s.channel_id: s.pts for s in channel_states}
|
||||
},
|
||||
date=session_state.date,
|
||||
seq=session_state.seq,
|
||||
next_deadline=ENTRY_ACCOUNT,
|
||||
)
|
||||
self.map = {
|
||||
ENTRY_ACCOUNT: State(pts=session_state.pts, deadline=deadline),
|
||||
ENTRY_SECRET: State(pts=session_state.qts, deadline=deadline),
|
||||
**{s.channel_id: State(pts=s.pts, deadline=deadline) for s in channel_states}
|
||||
}
|
||||
self.date = session_state.date
|
||||
self.seq = session_state.seq
|
||||
self.next_deadline = ENTRY_ACCOUNT
|
||||
|
||||
@classmethod
|
||||
def session_state(self):
|
||||
"""
|
||||
Return the current state in a format that sessions understand.
|
||||
Return the current state.
|
||||
|
||||
This should be used for persisting the state.
|
||||
"""
|
||||
return SessionState(
|
||||
user_id=0,
|
||||
dc_id=0,
|
||||
bot=False,
|
||||
pts=self.map.get(ENTRY_ACCOUNT, 0),
|
||||
qts=self.map.get(ENTRY_SECRET, 0),
|
||||
return dict(
|
||||
pts=self.map[ENTRY_ACCOUNT].pts if ENTRY_ACCOUNT in self.map else 0,
|
||||
qts=self.map[ENTRY_SECRET].pts if ENTRY_SECRET in self.map else 0,
|
||||
date=self.date,
|
||||
seq=self.seq,
|
||||
takeout_id=None,
|
||||
), [ChannelState(channel_id=id, pts=pts) for id, pts in self.map.items() if isinstance(id, int)]
|
||||
), {id: state.pts for id, state in self.map.items() if isinstance(id, int)}
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
Return true if the message box is empty and has no state yet.
|
||||
"""
|
||||
return self.map.get(ENTRY_ACCOUNT, NO_SEQ) == NO_SEQ
|
||||
return ENTRY_ACCOUNT not in self.map or self.map[ENTRY_ACCOUNT] == NO_SEQ
|
||||
|
||||
def check_deadlines(self):
|
||||
"""
|
||||
|
@ -200,7 +192,7 @@ class MessageBox:
|
|||
if self.possible_gaps:
|
||||
deadline = min(deadline, *(gap.deadline for gap in self.possible_gaps.values()))
|
||||
elif self.next_deadline in self.map:
|
||||
deadline = min(deadline, self.map[self.next_deadline])
|
||||
deadline = min(deadline, self.map[self.next_deadline].deadline)
|
||||
|
||||
if now > deadline:
|
||||
# Check all expired entries and add them to the list that needs getting difference.
|
||||
|
|
Loading…
Reference in New Issue
Block a user