From f9fc433c0f655876eed1edc158bac12404adabec Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Thu, 6 Dec 2018 16:07:11 +0100 Subject: [PATCH] Better catch_up behaviour when invalid states are present --- telethon/client/auth.py | 31 +++++++++++++++++++-------- telethon/client/telegrambaseclient.py | 6 +++--- telethon/client/updates.py | 8 ++++--- telethon/client/users.py | 16 +++++++------- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/telethon/client/auth.py b/telethon/client/auth.py index 03d31e59..86145240 100644 --- a/telethon/client/auth.py +++ b/telethon/client/auth.py @@ -1,3 +1,4 @@ +import datetime import getpass import hashlib import inspect @@ -294,10 +295,7 @@ class AuthMethods(MessageParseMethods, UserMethods): 'and a password only if an RPCError was raised before.' ) - self._self_input_peer = utils.get_input_peer( - result.user, allow_self=False - ) - return result.user + return self._on_login(result.user) async def sign_up(self, code, first_name, last_name=''): """ @@ -346,10 +344,24 @@ class AuthMethods(MessageParseMethods, UserMethods): await self( functions.help.AcceptTermsOfServiceRequest(self._tos.id)) - self._self_input_peer = utils.get_input_peer( - result.user, allow_self=False - ) - return result.user + return self._on_login(result.user) + + def _on_login(self, user): + """ + Callback called whenever the login or sign up process completes. + + Returns the input user parameter. + """ + self._self_input_peer = utils.get_input_peer(user, allow_self=False) + self._authorized = True + + # By setting state.pts = 1 after logging in, the user or bot can + # `catch_up` on all updates (and obtain necessary access hashes) + # if they desire. The date parameter is ignored when pts = 1. + self._state.pts = 1 + self._state.date = datetime.datetime.now() + + return user async def send_code_request(self, phone, *, force_sms=False): """ @@ -403,7 +415,8 @@ class AuthMethods(MessageParseMethods, UserMethods): return False self._self_input_peer = None - self._state.pts = -1 + self._authorized = False + self._state = types.updates.State(0, 0, datetime.datetime.now(), 0, 0) self.disconnect() self.session.delete() return True diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 7fd60dd2..0f111f9a 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -269,9 +269,9 @@ class TelegramBaseClient(abc.ABC): self._updates_queue = None self._dispatching_updates_queue = None - # Start with invalid state (-1) so we can have somewhere to store - # the state, but also be able to determine if we are authorized. - self._state = types.updates.State(-1, 0, datetime.now(), 0, -1) + self._authorized = None # None = unknown, False = no, True = yes + self._state = (self.session.get_update_state(0) + or types.updates.State(0, 0, datetime.now(), 0, 0)) # Some further state for subclasses self._event_builders = [] diff --git a/telethon/client/updates.py b/telethon/client/updates.py index 5bb47b24..619e73d0 100644 --- a/telethon/client/updates.py +++ b/telethon/client/updates.py @@ -132,9 +132,10 @@ class UpdateMethods(UserMethods): This can also be used to forcibly fetch new updates if there are any. """ - state = self.session.get_update_state(0) - if not state or state.pts <= 0: - state = await self(functions.updates.GetStateRequest()) + state = self._state + if state.pts == 0: + # pts = 0 is invalid, pts = 1 will catch up since the beginning + state.pts = 1 self.session.catching_up = True try: @@ -167,6 +168,7 @@ class UpdateMethods(UserMethods): state.pts = d.pts break finally: + self._state = state self.session.set_update_state(0, state) self.session.catching_up = False diff --git a/telethon/client/users.py b/telethon/client/users.py index e198303f..9e2029f0 100644 --- a/telethon/client/users.py +++ b/telethon/client/users.py @@ -5,8 +5,8 @@ import time from .telegrambaseclient import TelegramBaseClient from .. import errors, utils -from ..tl import TLObject, TLRequest, types, functions from ..errors import MultiError, RPCError +from ..tl import TLObject, TLRequest, types, functions __log__ = logging.getLogger(__name__) _NOT_A_REQUEST = TypeError('You can only invoke requests, not types!') @@ -123,14 +123,14 @@ class UserMethods(TelegramBaseClient): """ Returns ``True`` if the user is authorized. """ - if self._self_input_peer is not None or self._state.pts != -1: - return True + if self._authorized is None: + try: + self._state = await self(functions.updates.GetStateRequest()) + self._authorized = True + except errors.RPCError: + self._authorized = False - try: - self._state = await self(functions.updates.GetStateRequest()) - return True - except errors.RPCError: - return False + return self._authorized async def get_entity(self, entity): """