Better catch_up behaviour when invalid states are present

This commit is contained in:
Lonami Exo 2018-12-06 16:07:11 +01:00
parent 40730e7862
commit f9fc433c0f
4 changed files with 38 additions and 23 deletions

View File

@ -1,3 +1,4 @@
import datetime
import getpass import getpass
import hashlib import hashlib
import inspect import inspect
@ -294,10 +295,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
'and a password only if an RPCError was raised before.' 'and a password only if an RPCError was raised before.'
) )
self._self_input_peer = utils.get_input_peer( return self._on_login(result.user)
result.user, allow_self=False
)
return result.user
async def sign_up(self, code, first_name, last_name=''): async def sign_up(self, code, first_name, last_name=''):
""" """
@ -346,10 +344,24 @@ class AuthMethods(MessageParseMethods, UserMethods):
await self( await self(
functions.help.AcceptTermsOfServiceRequest(self._tos.id)) functions.help.AcceptTermsOfServiceRequest(self._tos.id))
self._self_input_peer = utils.get_input_peer( return self._on_login(result.user)
result.user, allow_self=False
) def _on_login(self, user):
return result.user """
Callback called whenever the login or sign up process completes.
Returns the input user parameter.
"""
self._self_input_peer = utils.get_input_peer(user, allow_self=False)
self._authorized = True
# By setting state.pts = 1 after logging in, the user or bot can
# `catch_up` on all updates (and obtain necessary access hashes)
# if they desire. The date parameter is ignored when pts = 1.
self._state.pts = 1
self._state.date = datetime.datetime.now()
return user
async def send_code_request(self, phone, *, force_sms=False): async def send_code_request(self, phone, *, force_sms=False):
""" """
@ -403,7 +415,8 @@ class AuthMethods(MessageParseMethods, UserMethods):
return False return False
self._self_input_peer = None self._self_input_peer = None
self._state.pts = -1 self._authorized = False
self._state = types.updates.State(0, 0, datetime.datetime.now(), 0, 0)
self.disconnect() self.disconnect()
self.session.delete() self.session.delete()
return True return True

View File

@ -269,9 +269,9 @@ class TelegramBaseClient(abc.ABC):
self._updates_queue = None self._updates_queue = None
self._dispatching_updates_queue = None self._dispatching_updates_queue = None
# Start with invalid state (-1) so we can have somewhere to store self._authorized = None # None = unknown, False = no, True = yes
# the state, but also be able to determine if we are authorized. self._state = (self.session.get_update_state(0)
self._state = types.updates.State(-1, 0, datetime.now(), 0, -1) or types.updates.State(0, 0, datetime.now(), 0, 0))
# Some further state for subclasses # Some further state for subclasses
self._event_builders = [] self._event_builders = []

View File

@ -132,9 +132,10 @@ class UpdateMethods(UserMethods):
This can also be used to forcibly fetch new updates if there are any. This can also be used to forcibly fetch new updates if there are any.
""" """
state = self.session.get_update_state(0) state = self._state
if not state or state.pts <= 0: if state.pts == 0:
state = await self(functions.updates.GetStateRequest()) # pts = 0 is invalid, pts = 1 will catch up since the beginning
state.pts = 1
self.session.catching_up = True self.session.catching_up = True
try: try:
@ -167,6 +168,7 @@ class UpdateMethods(UserMethods):
state.pts = d.pts state.pts = d.pts
break break
finally: finally:
self._state = state
self.session.set_update_state(0, state) self.session.set_update_state(0, state)
self.session.catching_up = False self.session.catching_up = False

View File

@ -5,8 +5,8 @@ import time
from .telegrambaseclient import TelegramBaseClient from .telegrambaseclient import TelegramBaseClient
from .. import errors, utils from .. import errors, utils
from ..tl import TLObject, TLRequest, types, functions
from ..errors import MultiError, RPCError from ..errors import MultiError, RPCError
from ..tl import TLObject, TLRequest, types, functions
__log__ = logging.getLogger(__name__) __log__ = logging.getLogger(__name__)
_NOT_A_REQUEST = TypeError('You can only invoke requests, not types!') _NOT_A_REQUEST = TypeError('You can only invoke requests, not types!')
@ -123,14 +123,14 @@ class UserMethods(TelegramBaseClient):
""" """
Returns ``True`` if the user is authorized. Returns ``True`` if the user is authorized.
""" """
if self._self_input_peer is not None or self._state.pts != -1: if self._authorized is None:
return True try:
self._state = await self(functions.updates.GetStateRequest())
self._authorized = True
except errors.RPCError:
self._authorized = False
try: return self._authorized
self._state = await self(functions.updates.GetStateRequest())
return True
except errors.RPCError:
return False
async def get_entity(self, entity): async def get_entity(self, entity):
""" """