Unify setting session state

This commit is contained in:
Lonami Exo 2022-01-15 11:22:33 +01:00
parent be0da9b183
commit 7524b652c8
4 changed files with 31 additions and 21 deletions

View File

@ -1,6 +1,7 @@
import functools
import inspect
import typing
import dataclasses
from contextvars import ContextVar
from .users import _NOT_A_REQUEST
@ -47,7 +48,7 @@ async def begin_takeout(
if takeout_active():
raise ValueError('a previous takeout session was already active')
self._session_state.takeout_id = (await client(
await self._replace_session_state(takeout_id=(await client(
contacts=contacts,
message_users=users,
message_chats=chats,
@ -55,7 +56,7 @@ async def begin_takeout(
message_channels=channels,
files=files,
file_max_size=max_file_size
)).id
)).id)
def takeout_active(self: 'TelegramClient') -> bool:
@ -70,4 +71,4 @@ async def end_takeout(self: 'TelegramClient', success: bool) -> bool:
if not result:
raise ValueError("could not end the active takeout session")
self._session_state.takeout_id = None
await self._replace_session_state(takeout_id=None)

View File

@ -5,6 +5,7 @@ import sys
import typing
import warnings
import functools
import dataclasses
from .._misc import utils, helpers, password as pwd_mod
from .. import errors, _tl
@ -308,6 +309,7 @@ async def sign_up(
return await _update_session_state(self, result.user)
async def _update_session_state(self, user, save=True):
"""
Callback called whenever the login or sign up process completes.
@ -315,20 +317,29 @@ async def _update_session_state(self, user, save=True):
"""
self._authorized = True
self._session_state.user_id = user.id
self._session_state.bot = user.bot
state = await self(_tl.fn.updates.GetState())
self._session_state.pts = state.pts
self._session_state.qts = state.qts
self._session_state.date = int(state.date.timestamp())
self._session_state.seq = state.seq
await _replace_session_state(
self,
save=save,
user_id=user.id,
bot=user.bot,
pts=state.pts,
qts=state.qts,
date=int(state.date.timestamp()),
seq=state.seq,
)
return user
async def _replace_session_state(self, *, save=True, **changes):
new = dataclasses.replace(self._session_state, **changes)
await self.session.set_state(new)
self._session_state = new
await self.session.set_state(self._session_state)
if save:
await self.session.save()
return user
async def send_code_request(
self: 'TelegramClient',

View File

@ -445,10 +445,7 @@ async def _disconnect_coro(self: 'TelegramClient'):
pts, date = self._state_cache[None]
if pts and date:
if self._session_state:
self._session_state.pts = pts
self._session_state.date = date
await self.session.set_state(self._session_state)
await self.session.save()
await self._replace_session_state(pts=pts, date=date)
async def _disconnect(self: 'TelegramClient'):
"""
@ -467,10 +464,7 @@ async def _switch_dc(self: 'TelegramClient', new_dc):
"""
self._log[__name__].info('Reconnecting to new data center %s', new_dc)
self._session_state.dc_id = new_dc
await self.session.set_state(self._session_state)
await self.session.save()
await self._replace_session_state(dc_id=new_dc)
await _disconnect(self)
return await self.connect()

View File

@ -3547,7 +3547,11 @@ class TelegramClient:
pass
@forward_call(auth._update_session_state)
async def _update_session_state(self, user, save=True):
async def _update_session_state(self, user, *, save=True):
pass
@forward_call(auth._replace_session_state)
async def _replace_session_state(self, *, save=True, **changes):
pass
# endregion Private