Move UpdatesThread from MtProtoSender to TelegramClient

This makes it easier to perform a proper reconnection
This commit is contained in:
Lonami Exo 2017-05-29 21:24:47 +02:00
parent ebe4232b32
commit 042e3069a9
2 changed files with 143 additions and 166 deletions

View File

@ -1,15 +1,12 @@
import gzip import gzip
from datetime import timedelta from datetime import timedelta
from threading import Event, RLock, Thread from threading import RLock
from time import sleep, time
from .. import helpers as utils from .. import helpers as utils
from ..crypto import AES from ..crypto import AES
from ..errors import (BadMessageError, FloodWaitError, RPCError, from ..errors import (BadMessageError, FloodWaitError,
InvalidDCError, ReadCancelledError) RPCError, InvalidDCError)
from ..tl.all_tlobjects import tlobjects from ..tl.all_tlobjects import tlobjects
from ..tl.functions import PingRequest
from ..tl.functions.updates import GetStateRequest
from ..tl.types import MsgsAck from ..tl.types import MsgsAck
from ..utils import BinaryReader, BinaryWriter from ..utils import BinaryReader, BinaryWriter
@ -26,64 +23,22 @@ class MtProtoSender:
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
self._need_confirmation = [] # Message IDs that need confirmation self._need_confirmation = [] # Message IDs that need confirmation
self._on_update_handlers = []
# Store an RLock instance to make this class safely multi-threaded # Store an RLock instance to make this class safely multi-threaded
self._lock = RLock() self._lock = RLock()
# Flag used to determine whether we've received a sent request yet or not
# We need this to avoid using the updates thread if we're waiting to read
self._waiting_receive = Event()
# Used when logging out, the only request that seems to use 'ack' requests # Used when logging out, the only request that seems to use 'ack' requests
# TODO There might be a better way to handle msgs_ack requests # TODO There might be a better way to handle msgs_ack requests
self.logging_out = False self.logging_out = False
self.ping_interval = 60
self._ping_time_last = time()
# Flags used to determine the status of the updates thread.
self._updates_thread_running = Event()
self._updates_thread_receiving = Event()
# Sleep amount on "must sleep" error for the updates thread to sleep too
self._updates_thread_sleep = None
self._updates_thread = None # Set later
def connect(self): def connect(self):
"""Connects to the server""" """Connects to the server"""
self._transport.connect() self._transport.connect()
def disconnect(self): def disconnect(self):
"""Disconnects and **stops all the running threads** if any""" """Disconnects from the server"""
self._set_updates_thread(running=False)
self._transport.close() self._transport.close()
def reconnect(self):
"""Disconnects and connects again (effectively reconnecting)"""
self.disconnect()
self.connect()
def setup_ping_thread(self):
"""Sets up the Ping's thread, so that a connection can be kept
alive for a longer time without Telegram disconnecting us"""
self._updates_thread = Thread(
name='UpdatesThread', daemon=True,
target=self._updates_thread_method)
self._set_updates_thread(running=True)
def add_update_handler(self, handler):
"""Adds an update handler (a method with one argument, the received
TLObject) that is fired when there are updates available"""
# The updates thread is already running for periodic ping requests,
# so there is no need to start it when adding update handlers.
self._on_update_handlers.append(handler)
def remove_update_handler(self, handler):
self._on_update_handlers.remove(handler)
def _generate_sequence(self, confirmed): def _generate_sequence(self, confirmed):
"""Generates the next sequence number, based on whether it """Generates the next sequence number, based on whether it
was confirmed yet or not""" was confirmed yet or not"""
@ -96,28 +51,13 @@ class MtProtoSender:
# region Send and receive # region Send and receive
def send_ping(self):
"""Sends PingRequest"""
request = PingRequest(utils.generate_random_long())
self.send(request)
self.receive(request)
def send(self, request): def send(self, request):
"""Sends the specified MTProtoRequest, previously sending any message """Sends the specified MTProtoRequest, previously sending any message
which needed confirmation. This also pauses the updates thread""" which needed confirmation."""
# Only cancel the receive *if* it was the
# updates thread who was receiving. We do
# not want to cancel other pending requests!
if self._updates_thread_receiving.is_set():
self._logger.info('Cancelling updates receive from send()...')
self._transport.cancel_receive()
# Now only us can be using this method # Now only us can be using this method
with self._lock: with self._lock:
self._logger.debug('send() acquired the lock') self._logger.debug('send() acquired the lock')
# Set the flag to true so the updates thread stops trying to receive
self._waiting_receive.set()
# If any message needs confirmation send an AckRequest first # If any message needs confirmation send an AckRequest first
if self._need_confirmation: if self._need_confirmation:
@ -175,9 +115,6 @@ class MtProtoSender:
break # Request, and result read, exit break # Request, and result read, exit
self._logger.info('Request result received') self._logger.info('Request result received')
# We can now set the flag to False thus resuming the updates thread
self._waiting_receive.clear()
self._logger.debug('receive() released the lock') self._logger.debug('receive() released the lock')
def receive_update(self, timeout=timedelta(seconds=5)): def receive_update(self, timeout=timedelta(seconds=5)):
@ -186,6 +123,11 @@ class MtProtoSender:
self.receive(timeout=timeout, updates=updates) self.receive(timeout=timeout, updates=updates)
return updates[0] return updates[0]
def cancel_receive(self):
"""Cancels any pending receive operation
by raising a ReadCancelledError"""
self._transport.cancel_receive()
# endregion # endregion
# region Low level processing # region Low level processing
@ -395,7 +337,8 @@ class MtProtoSender:
if not request: if not request:
raise ValueError( raise ValueError(
'The previously sent request must be resent. ' 'The previously sent request must be resent. '
'However, no request was previously sent (called from updates thread).') 'However, no request was previously sent '
'(possibly called from a different thread).')
request.confirm_received = False request.confirm_received = False
if error.message.startswith('FLOOD_WAIT_'): if error.message.startswith('FLOOD_WAIT_'):
@ -410,7 +353,8 @@ class MtProtoSender:
else: else:
if not request: if not request:
raise ValueError( raise ValueError(
'Cannot receive a request from inside an RPC result from the updates thread.') 'The request needed to read this RPC result was not '
'found (possibly called receive() from another thread).')
self._logger.debug('Reading request response') self._logger.debug('Reading request response')
if inner_code == 0x3072cfa1: # GZip packed if inner_code == 0x3072cfa1: # GZip packed
@ -437,90 +381,3 @@ class MtProtoSender:
request, updates) request, updates)
# endregion # endregion
def _set_updates_thread(self, running):
"""Sets the updates thread status (running or not)"""
if not self._updates_thread or \
running == self._updates_thread_running.is_set():
return
# Different state, update the saved value and behave as required
self._logger.info('Changing updates thread running status to %s', running)
if running:
self._updates_thread_running.set()
self._updates_thread.start()
else:
self._updates_thread_running.clear()
if self._updates_thread_receiving.is_set():
self._transport.cancel_receive()
def _updates_thread_method(self):
"""This method will run until specified and listen for incoming updates"""
# Set a reasonable timeout when checking for updates
timeout = timedelta(minutes=1)
while self._updates_thread_running.is_set():
# Always sleep a bit before each iteration to relax the CPU,
# since it's possible to early 'continue' the loop to reach
# the next iteration, but we still should to sleep.
if self._updates_thread_sleep:
sleep(self._updates_thread_sleep)
self._updates_thread_sleep = None
else:
# Longer sleep if we're not expecting updates (only pings)
sleep(0.1 if self._on_update_handlers else 1)
# Only try to receive updates if we're not waiting to receive a request
if not self._waiting_receive.is_set():
with self._lock:
self._logger.debug('Updates thread acquired the lock')
try:
now = time()
# If ping_interval seconds passed since last ping, send a new one
if now >= self._ping_time_last + self.ping_interval:
self._ping_time_last = now
self.send_ping()
self._logger.debug('Ping sent from the updates thread')
# Exit the loop if we're not expecting to receive any updates
if not self._on_update_handlers:
self._logger.debug('No updates handlers found, continuing')
continue
self._updates_thread_receiving.set()
self._logger.debug('Trying to receive updates from the updates thread')
result = self.receive_update(timeout=timeout)
self._logger.info('Received update from the updates thread')
for handler in self._on_update_handlers:
handler(result)
except TimeoutError:
self._logger.debug('Receiving updates timed out')
# TODO Workaround for issue #50
r = GetStateRequest()
try:
self._logger.debug('Sending GetStateRequest (workaround for issue #50)')
self.send(r)
self.receive(r)
except TimeoutError:
self._logger.warning('Timed out inside a timeout, trying to reconnect...')
self.reconnect()
self.send(r)
self.receive(r)
except ReadCancelledError:
self._logger.info('Receiving updates cancelled')
except OSError:
self._logger.warning('OSError on updates thread, %s logging out',
'was' if self.logging_out else 'was not')
if self.logging_out:
# This error is okay when logging out, means we got disconnected
# TODO Not sure why this happens because we call disconnect()…
self._set_updates_thread(running=False)
else:
raise
self._logger.debug('Updates thread released the lock')
self._updates_thread_receiving.clear()

View File

@ -3,16 +3,23 @@ from datetime import timedelta
from hashlib import md5 from hashlib import md5
from mimetypes import guess_type from mimetypes import guess_type
from os import listdir, path from os import listdir, path
from threading import Event, RLock, Thread
from time import time, sleep
import logging
# Import some externalized utilities to work with the Telegram types and more # Import some externalized utilities to work with the Telegram types and more
from . import helpers as utils from . import helpers as utils
from .errors import RPCError, InvalidDCError, InvalidParameterError from .errors import (RPCError, InvalidDCError, FloodWaitError,
InvalidParameterError, ReadCancelledError)
from .network import authenticator, MtProtoSender, TcpTransport from .network import authenticator, MtProtoSender, TcpTransport
from .parser.markdown_parser import parse_message_entities from .parser.markdown_parser import parse_message_entities
# For sending and receiving requests # For sending and receiving requests
from .tl import MTProtoRequest, Session from .tl import MTProtoRequest, Session
from .tl.all_tlobjects import layer from .tl.all_tlobjects import layer
from .tl.functions import InitConnectionRequest, InvokeWithLayerRequest from .tl.functions import (InitConnectionRequest, InvokeWithLayerRequest,
PingRequest)
# The following is required to get the password salt # The following is required to get the password salt
from .tl.functions.account import GetPasswordRequest from .tl.functions.account import GetPasswordRequest
from .tl.functions.auth import (CheckPasswordRequest, LogOutRequest, from .tl.functions.auth import (CheckPasswordRequest, LogOutRequest,
@ -72,7 +79,19 @@ class TelegramClient:
self.transport = None self.transport = None
self.proxy = proxy # Will be used when a TcpTransport is created self.proxy = proxy # Will be used when a TcpTransport is created
# Safety across multiple threads (for the updates thread)
self._lock = RLock()
self._logger = logging.getLogger(__name__)
# Methods to be called when an update is received
self.update_handlers = []
self.ping_interval = 60
self._ping_time_last = time()
self._updates_thread_running = Event()
self._updates_thread_receiving = Event()
# These will be set later # These will be set later
self._updates_thread = None
self.dc_options = None self.dc_options = None
self.sender = None self.sender = None
self.phone_code_hashes = {} self.phone_code_hashes = {}
@ -119,7 +138,7 @@ class TelegramClient:
# Once we know we're authorized, we can setup the ping thread # Once we know we're authorized, we can setup the ping thread
if self.is_user_authorized(): if self.is_user_authorized():
self.sender.setup_ping_thread() self._setup_ping_thread()
return True return True
except RPCError as error: except RPCError as error:
@ -143,7 +162,8 @@ class TelegramClient:
self.connect(reconnect=True) self.connect(reconnect=True)
def disconnect(self): def disconnect(self):
"""Disconnects from the Telegram server **and pauses all the spawned threads**""" """Disconnects from the Telegram server and stops all the spawned threads"""
self._set_updates_thread(running=False)
if self.sender: if self.sender:
self.sender.disconnect() self.sender.disconnect()
self.sender = None self.sender = None
@ -151,6 +171,11 @@ class TelegramClient:
self.transport.close() self.transport.close()
self.transport = None self.transport = None
def reconnect(self):
"""Disconnects and connects again (effectively reconnecting)"""
self.disconnect()
self.connect()
# endregion # endregion
# region Telegram requests functions # region Telegram requests functions
@ -168,12 +193,16 @@ class TelegramClient:
if not self.sender: if not self.sender:
raise ValueError('You must be connected to invoke requests!') raise ValueError('You must be connected to invoke requests!')
if self._updates_thread_receiving.is_set():
self.sender.cancel_receive()
try: try:
self._lock.acquire()
updates = [] updates = []
self.sender.send(request) self.sender.send(request)
self.sender.receive(request, timeout, updates=updates) self.sender.receive(request, timeout, updates=updates)
for update in updates: for update in updates:
for handler in self.sender._on_update_handlers: for handler in self.update_handlers:
handler(update) handler(update)
return request.result return request.result
@ -185,6 +214,14 @@ class TelegramClient:
self._reconnect_to_dc(error.new_dc) self._reconnect_to_dc(error.new_dc)
return self.invoke(request, timeout=timeout, throw_invalid_dc=True) return self.invoke(request, timeout=timeout, throw_invalid_dc=True)
except FloodWaitError:
# TODO Write somewhere that FloodWaitError disconnects the client
self.disconnect()
raise
finally:
self._lock.release()
# region Authorization requests # region Authorization requests
def is_user_authorized(self): def is_user_authorized(self):
@ -247,7 +284,7 @@ class TelegramClient:
# to start the pings thread once we're already authorized and not # to start the pings thread once we're already authorized and not
# before to avoid the updates thread trying to read anything while # before to avoid the updates thread trying to read anything while
# we haven't yet connected. # we haven't yet connected.
self.sender.setup_ping_thread() self._setup_ping_thread()
return True return True
@ -736,12 +773,95 @@ class TelegramClient:
raise RuntimeError( raise RuntimeError(
"You should connect at least once to add update handlers.") "You should connect at least once to add update handlers.")
self.sender.add_update_handler(handler) # TODO Eventually remove these methods, the user
# can access self.update_handlers manually
self.update_handlers.append(handler)
def remove_update_handler(self, handler): def remove_update_handler(self, handler):
self.sender.remove_update_handler(handler) self.update_handlers.remove(handler)
def list_update_handlers(self): def list_update_handlers(self):
return [handler.__name__ for handler in self.sender.on_update_handlers] return self.update_handlers[:]
def _setup_ping_thread(self):
"""Sets up the Ping's thread, so that a connection can be kept
alive for a longer time without Telegram disconnecting us"""
self._updates_thread = Thread(
name='UpdatesThread', daemon=True,
target=self._updates_thread_method)
self._set_updates_thread(running=True)
def _set_updates_thread(self, running):
"""Sets the updates thread status (running or not)"""
if not self._updates_thread or \
running == self._updates_thread_running.is_set():
return
# Different state, update the saved value and behave as required
self._logger.info('Changing updates thread running status to %s', running)
if running:
self._updates_thread_running.set()
self._updates_thread.start()
else:
self._updates_thread_running.clear()
if self._updates_thread_receiving.is_set():
self.sender.cancel_receive()
def _updates_thread_method(self):
"""This method will run until specified and listen for incoming updates"""
# Set a reasonable timeout when checking for updates
timeout = timedelta(minutes=1)
while self._updates_thread_running.is_set():
# Always sleep a bit before each iteration to relax the CPU,
# since it's possible to early 'continue' the loop to reach
# the next iteration, but we still should to sleep.
# Longer sleep if we're not expecting updates (only pings)
sleep(0.1 if self.update_handlers else 1)
with self._lock:
self._logger.debug('Updates thread acquired the lock')
try:
now = time()
# If ping_interval seconds passed since last ping, send a new one
if now >= self._ping_time_last + self.ping_interval:
self._ping_time_last = now
self.invoke(PingRequest(utils.generate_random_long()))
self._logger.debug('Ping sent from the updates thread')
# Exit the loop if we're not expecting to receive any updates
if not self.update_handlers:
self._logger.debug('No updates handlers found, continuing')
continue
self._updates_thread_receiving.set()
self._logger.debug('Trying to receive updates from the updates thread')
result = self.sender.receive_update(timeout=timeout)
self._logger.info('Received update from the updates thread')
for handler in self.update_handlers:
handler(result)
except TimeoutError:
self._logger.debug('Receiving updates timed out')
self.reconnect()
except ReadCancelledError:
self._logger.info('Receiving updates cancelled')
except OSError:
self._logger.warning('OSError on updates thread, %s logging out',
'was' if self.sender.logging_out else 'was not')
if self.sender.logging_out:
# This error is okay when logging out, means we got disconnected
# TODO Not sure why this happens because we call disconnect()…
self._set_updates_thread(running=False)
else:
raise
self._logger.debug('Updates thread released the lock')
self._updates_thread_receiving.clear()
# endregion # endregion