Better catch_up behaviour when invalid states are present

This commit is contained in:
Lonami Exo 2018-12-06 16:07:11 +01:00
parent 40730e7862
commit f9fc433c0f
4 changed files with 38 additions and 23 deletions

View File

@ -1,3 +1,4 @@
import datetime
import getpass
import hashlib
import inspect
@ -294,10 +295,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
'and a password only if an RPCError was raised before.'
)
self._self_input_peer = utils.get_input_peer(
result.user, allow_self=False
)
return result.user
return self._on_login(result.user)
async def sign_up(self, code, first_name, last_name=''):
"""
@ -346,10 +344,24 @@ class AuthMethods(MessageParseMethods, UserMethods):
await self(
functions.help.AcceptTermsOfServiceRequest(self._tos.id))
self._self_input_peer = utils.get_input_peer(
result.user, allow_self=False
)
return result.user
return self._on_login(result.user)
def _on_login(self, user):
"""
Callback called whenever the login or sign up process completes.
Returns the input user parameter.
"""
self._self_input_peer = utils.get_input_peer(user, allow_self=False)
self._authorized = True
# By setting state.pts = 1 after logging in, the user or bot can
# `catch_up` on all updates (and obtain necessary access hashes)
# if they desire. The date parameter is ignored when pts = 1.
self._state.pts = 1
self._state.date = datetime.datetime.now()
return user
async def send_code_request(self, phone, *, force_sms=False):
"""
@ -403,7 +415,8 @@ class AuthMethods(MessageParseMethods, UserMethods):
return False
self._self_input_peer = None
self._state.pts = -1
self._authorized = False
self._state = types.updates.State(0, 0, datetime.datetime.now(), 0, 0)
self.disconnect()
self.session.delete()
return True

View File

@ -269,9 +269,9 @@ class TelegramBaseClient(abc.ABC):
self._updates_queue = None
self._dispatching_updates_queue = None
# Start with invalid state (-1) so we can have somewhere to store
# the state, but also be able to determine if we are authorized.
self._state = types.updates.State(-1, 0, datetime.now(), 0, -1)
self._authorized = None # None = unknown, False = no, True = yes
self._state = (self.session.get_update_state(0)
or types.updates.State(0, 0, datetime.now(), 0, 0))
# Some further state for subclasses
self._event_builders = []

View File

@ -132,9 +132,10 @@ class UpdateMethods(UserMethods):
This can also be used to forcibly fetch new updates if there are any.
"""
state = self.session.get_update_state(0)
if not state or state.pts <= 0:
state = await self(functions.updates.GetStateRequest())
state = self._state
if state.pts == 0:
# pts = 0 is invalid, pts = 1 will catch up since the beginning
state.pts = 1
self.session.catching_up = True
try:
@ -167,6 +168,7 @@ class UpdateMethods(UserMethods):
state.pts = d.pts
break
finally:
self._state = state
self.session.set_update_state(0, state)
self.session.catching_up = False

View File

@ -5,8 +5,8 @@ import time
from .telegrambaseclient import TelegramBaseClient
from .. import errors, utils
from ..tl import TLObject, TLRequest, types, functions
from ..errors import MultiError, RPCError
from ..tl import TLObject, TLRequest, types, functions
__log__ = logging.getLogger(__name__)
_NOT_A_REQUEST = TypeError('You can only invoke requests, not types!')
@ -123,14 +123,14 @@ class UserMethods(TelegramBaseClient):
"""
Returns ``True`` if the user is authorized.
"""
if self._self_input_peer is not None or self._state.pts != -1:
return True
if self._authorized is None:
try:
self._state = await self(functions.updates.GetStateRequest())
self._authorized = True
except errors.RPCError:
self._authorized = False
try:
self._state = await self(functions.updates.GetStateRequest())
return True
except errors.RPCError:
return False
return self._authorized
async def get_entity(self, entity):
"""