New way to work with updates (#237)

This commit is contained in:
Lonami 2017-09-08 13:11:04 +02:00 committed by GitHub
commit 2dea665721
5 changed files with 167 additions and 88 deletions

View File

@ -28,23 +28,15 @@ class MtProtoSender:
self._need_confirmation = [] # Message IDs that need confirmation self._need_confirmation = [] # Message IDs that need confirmation
self._pending_receive = [] # Requests sent waiting to be received self._pending_receive = [] # Requests sent waiting to be received
# Store an RLock instance to make this class safely multi-threaded # Sending and receiving are independent, but two threads cannot
self._lock = RLock() # send or receive at the same time no matter what.
self._send_lock = RLock()
self._recv_lock = RLock()
# Used when logging out, the only request that seems to use 'ack' # Used when logging out, the only request that seems to use 'ack'
# TODO There might be a better way to handle msgs_ack requests # TODO There might be a better way to handle msgs_ack requests
self.logging_out = False self.logging_out = False
# Every unhandled result gets passed to these callbacks, which
# should be functions accepting a single parameter: a TLObject.
# This should only be Update(s), although it can actually be any type.
#
# The thread from which these callbacks are called can be any.
#
# The creator of the MtProtoSender is responsible for setting this
# to point to the list wherever their callbacks reside.
self.unhandled_callbacks = None
def connect(self): def connect(self):
"""Connects to the server""" """Connects to the server"""
self.connection.connect() self.connection.connect()
@ -62,23 +54,17 @@ class MtProtoSender:
"""Sends the specified MTProtoRequest, previously sending any message """Sends the specified MTProtoRequest, previously sending any message
which needed confirmation.""" which needed confirmation."""
# Now only us can be using this method # If any message needs confirmation send an AckRequest first
with self._lock: self._send_acknowledges()
self._logger.debug('send() acquired the lock')
# If any message needs confirmation send an AckRequest first # Finally send our packed request
self._send_acknowledges() with BinaryWriter() as writer:
request.on_send(writer)
self._send_packet(writer.get_bytes(), request)
self._pending_receive.append(request)
# Finally send our packed request # And update the saved session
with BinaryWriter() as writer: self.session.save()
request.on_send(writer)
self._send_packet(writer.get_bytes(), request)
self._pending_receive.append(request)
# And update the saved session
self.session.save()
self._logger.debug('send() released the lock')
def _send_acknowledges(self): def _send_acknowledges(self):
"""Sends a messages acknowledge for all those who _need_confirmation""" """Sends a messages acknowledge for all those who _need_confirmation"""
@ -90,23 +76,22 @@ class MtProtoSender:
del self._need_confirmation[:] del self._need_confirmation[:]
def receive(self): def receive(self, update_state):
"""Receives a single message from the connected endpoint. """Receives a single message from the connected endpoint.
This method returns nothing, and will only affect other parts This method returns nothing, and will only affect other parts
of the MtProtoSender such as the updates callback being fired of the MtProtoSender such as the updates callback being fired
or a pending request being confirmed. or a pending request being confirmed.
Any unhandled object (likely updates) will be passed to
update_state.process(TLObject).
""" """
# TODO Don't ignore updates with self._recv_lock:
self._logger.debug('Receiving a message...') body = self.connection.recv()
body = self.connection.recv()
message, remote_msg_id, remote_seq = self._decode_msg(body) message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader: with BinaryReader(message) as reader:
self._process_msg( self._process_msg(remote_msg_id, remote_seq, reader, update_state)
remote_msg_id, remote_seq, reader, updates=None)
self._logger.debug('Received message.')
# endregion # endregion
@ -115,8 +100,6 @@ class MtProtoSender:
def _send_packet(self, packet, request): def _send_packet(self, packet, request):
"""Sends the given packet bytes with the additional """Sends the given packet bytes with the additional
information of the original request. information of the original request.
This does NOT lock the threads!
""" """
request.request_msg_id = self.session.get_new_msg_id() request.request_msg_id = self.session.get_new_msg_id()
@ -142,7 +125,8 @@ class MtProtoSender:
self.session.auth_key.key_id, signed=False) self.session.auth_key.key_id, signed=False)
cipher_writer.write(msg_key) cipher_writer.write(msg_key)
cipher_writer.write(cipher_text) cipher_writer.write(cipher_text)
self.connection.send(cipher_writer.get_bytes()) with self._send_lock:
self.connection.send(cipher_writer.get_bytes())
def _decode_msg(self, body): def _decode_msg(self, body):
"""Decodes an received encrypted message body bytes""" """Decodes an received encrypted message body bytes"""
@ -172,7 +156,7 @@ class MtProtoSender:
return message, remote_msg_id, remote_sequence return message, remote_msg_id, remote_sequence
def _process_msg(self, msg_id, sequence, reader, updates): def _process_msg(self, msg_id, sequence, reader, state):
"""Processes and handles a Telegram message. """Processes and handles a Telegram message.
Returns True if the message was handled correctly and doesn't Returns True if the message was handled correctly and doesn't
@ -193,10 +177,10 @@ class MtProtoSender:
return self._handle_pong(msg_id, sequence, reader) return self._handle_pong(msg_id, sequence, reader)
if code == 0x73f1f8dc: # msg_container if code == 0x73f1f8dc: # msg_container
return self._handle_container(msg_id, sequence, reader, updates) return self._handle_container(msg_id, sequence, reader, state)
if code == 0x3072cfa1: # gzip_packed if code == 0x3072cfa1: # gzip_packed
return self._handle_gzip_packed(msg_id, sequence, reader, updates) return self._handle_gzip_packed(msg_id, sequence, reader, state)
if code == 0xedab447b: # bad_server_salt if code == 0xedab447b: # bad_server_salt
return self._handle_bad_server_salt(msg_id, sequence, reader) return self._handle_bad_server_salt(msg_id, sequence, reader)
@ -221,16 +205,15 @@ class MtProtoSender:
# If the code is not parsed manually then it should be a TLObject. # If the code is not parsed manually then it should be a TLObject.
if code in tlobjects: if code in tlobjects:
result = reader.tgread_object() result = reader.tgread_object()
if self.unhandled_callbacks: if state is None:
self._logger.debug(
'Passing TLObject to callbacks %s', repr(result)
)
for callback in self.unhandled_callbacks:
callback(result)
else:
self._logger.debug( self._logger.debug(
'Ignoring unhandled TLObject %s', repr(result) 'Ignoring unhandled TLObject %s', repr(result)
) )
else:
self._logger.debug(
'Processing TLObject %s', repr(result)
)
state.process(result)
return True return True
@ -261,7 +244,7 @@ class MtProtoSender:
return True return True
def _handle_container(self, msg_id, sequence, reader, updates): def _handle_container(self, msg_id, sequence, reader, state):
self._logger.debug('Handling container') self._logger.debug('Handling container')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
size = reader.read_int() size = reader.read_int()
@ -274,8 +257,7 @@ class MtProtoSender:
# Note that this code is IMPORTANT for skipping RPC results of # Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session) # lost requests (i.e., ones from the previous connection session)
try: try:
if not self._process_msg( if not self._process_msg(inner_msg_id, sequence, reader, state):
inner_msg_id, sequence, reader, updates):
reader.set_position(begin_position + inner_length) reader.set_position(begin_position + inner_length)
except: except:
# If any error is raised, something went wrong; skip the packet # If any error is raised, something went wrong; skip the packet
@ -366,14 +348,13 @@ class MtProtoSender:
self._logger.debug('Lost request will be skipped.') self._logger.debug('Lost request will be skipped.')
return False return False
def _handle_gzip_packed(self, msg_id, sequence, reader, updates): def _handle_gzip_packed(self, msg_id, sequence, reader, state):
self._logger.debug('Handling gzip packed data') self._logger.debug('Handling gzip packed data')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
packed_data = reader.tgread_bytes() packed_data = reader.tgread_bytes()
unpacked_data = gzip.decompress(packed_data) unpacked_data = gzip.decompress(packed_data)
with BinaryReader(unpacked_data) as compressed_reader: with BinaryReader(unpacked_data) as compressed_reader:
return self._process_msg( return self._process_msg(msg_id, sequence, compressed_reader, state)
msg_id, sequence, compressed_reader, updates)
# endregion # endregion

View File

@ -26,6 +26,7 @@ from .tl.functions.upload import (
) )
from .tl.types import InputFile, InputFileBig from .tl.types import InputFile, InputFileBig
from .tl.types.upload import FileCdnRedirect from .tl.types.upload import FileCdnRedirect
from .update_state import UpdateState
from .utils import get_appropriated_part_size from .utils import get_appropriated_part_size
@ -56,7 +57,9 @@ class TelegramBareClient:
def __init__(self, session, api_id, api_hash, def __init__(self, session, api_id, api_hash,
connection_mode=ConnectionMode.TCP_FULL, connection_mode=ConnectionMode.TCP_FULL,
proxy=None, timeout=timedelta(seconds=5)): proxy=None,
process_updates=False,
timeout=timedelta(seconds=5)):
"""Initializes the Telegram client with the specified API ID and Hash. """Initializes the Telegram client with the specified API ID and Hash.
Session must always be a Session instance, and an optional proxy Session must always be a Session instance, and an optional proxy
can also be specified to be used on the connection. can also be specified to be used on the connection.
@ -74,11 +77,9 @@ class TelegramBareClient:
# the time since it's a (somewhat expensive) process. # the time since it's a (somewhat expensive) process.
self._cached_clients = {} self._cached_clients = {}
# Update callbacks (functions accepting a single TLObject) go here # This member will process updates if enabled.
# # One may change self.updates.enabled at any later point.
# Note that changing the list to which this variable points to self.updates = UpdateState(process_updates)
# will not reflect the changes on the existing senders.
self._update_callbacks = []
# These will be set later # These will be set later
self.dc_options = None self.dc_options = None
@ -127,7 +128,6 @@ class TelegramBareClient:
self.session.save() self.session.save()
self._sender = MtProtoSender(connection, self.session) self._sender = MtProtoSender(connection, self.session)
self._sender.unhandled_callbacks = self._update_callbacks
self._sender.connect() self._sender.connect()
# Now it's time to send an InitConnectionRequest # Now it's time to send an InitConnectionRequest
@ -312,7 +312,7 @@ class TelegramBareClient:
request.confirm_received.wait() # TODO Socket's timeout here? request.confirm_received.wait() # TODO Socket's timeout here?
else: else:
while not request.confirm_received.is_set(): while not request.confirm_received.is_set():
self._sender.receive() self._sender.receive(update_state=self.updates)
if request.rpc_error: if request.rpc_error:
raise request.rpc_error raise request.rpc_error

View File

@ -1,8 +1,8 @@
import os import os
import threading
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mimetypes import guess_type from mimetypes import guess_type
from threading import RLock, Thread from threading import Thread
import threading
from . import TelegramBareClient from . import TelegramBareClient
from . import helpers as utils from . import helpers as utils
@ -13,6 +13,7 @@ from .errors import (
) )
from .network import ConnectionMode from .network import ConnectionMode
from .tl import Session, TLObject from .tl import Session, TLObject
from .tl.functions import PingRequest
from .tl.functions.account import ( from .tl.functions.account import (
GetPasswordRequest GetPasswordRequest
) )
@ -27,6 +28,9 @@ from .tl.functions.messages import (
GetDialogsRequest, GetHistoryRequest, ReadHistoryRequest, SendMediaRequest, GetDialogsRequest, GetHistoryRequest, ReadHistoryRequest, SendMediaRequest,
SendMessageRequest SendMessageRequest
) )
from .tl.functions.updates import (
GetStateRequest
)
from .tl.functions.users import ( from .tl.functions.users import (
GetUsersRequest GetUsersRequest
) )
@ -46,9 +50,6 @@ class TelegramClient(TelegramBareClient):
As opposed to the TelegramBareClient, this one features downloading As opposed to the TelegramBareClient, this one features downloading
media from different data centers, starting a second thread to media from different data centers, starting a second thread to
handle updates, and some very common functionality. handle updates, and some very common functionality.
This should be used when the (slight) overhead of having locks,
threads, and possibly multiple connections is not an issue.
""" """
# region Initialization # region Initialization
@ -56,6 +57,7 @@ class TelegramClient(TelegramBareClient):
def __init__(self, session, api_id, api_hash, def __init__(self, session, api_id, api_hash,
connection_mode=ConnectionMode.TCP_FULL, connection_mode=ConnectionMode.TCP_FULL,
proxy=None, proxy=None,
process_updates=False,
timeout=timedelta(seconds=5), timeout=timedelta(seconds=5),
**kwargs): **kwargs):
"""Initializes the Telegram client with the specified API ID and Hash. """Initializes the Telegram client with the specified API ID and Hash.
@ -69,6 +71,16 @@ class TelegramClient(TelegramBareClient):
This will only affect how messages are sent over the network This will only affect how messages are sent over the network
and how much processing is required before sending them. and how much processing is required before sending them.
If 'process_updates' is set to True, incoming updates will be
processed and you must manually call 'self.updates.poll()' from
another thread to retrieve the saved update objects, or your
memory will fill with these. You may modify the value of
'self.updates.polling' at any later point.
Despite the value of 'process_updates', if you later call
'.add_update_handler(...)', updates will also be processed
and the update objects will be passed to the handlers you added.
If more named arguments are provided as **kwargs, they will be If more named arguments are provided as **kwargs, they will be
used to update the Session instance. Most common settings are: used to update the Session instance. Most common settings are:
device_model = platform.node() device_model = platform.node()
@ -92,12 +104,12 @@ class TelegramClient(TelegramBareClient):
super().__init__( super().__init__(
session, api_id, api_hash, session, api_id, api_hash,
connection_mode=connection_mode, proxy=proxy, timeout=timeout connection_mode=connection_mode,
proxy=proxy,
process_updates=process_updates,
timeout=timeout
) )
# Safety across multiple threads (for the updates thread)
self._lock = RLock()
# Used on connection - the user may modify these and reconnect # Used on connection - the user may modify these and reconnect
kwargs['app_version'] = kwargs.get('app_version', self.__version__) kwargs['app_version'] = kwargs.get('app_version', self.__version__)
for name, value in kwargs.items(): for name, value in kwargs.items():
@ -114,6 +126,10 @@ class TelegramClient(TelegramBareClient):
# Constantly read for results and updates from within the main client # Constantly read for results and updates from within the main client
self._recv_thread = None self._recv_thread = None
# Default PingRequest delay
self._last_ping = datetime.now()
self._ping_delay = timedelta(minutes=1)
# endregion # endregion
# region Connecting # region Connecting
@ -145,6 +161,8 @@ class TelegramClient(TelegramBareClient):
target=self._recv_thread_impl target=self._recv_thread_impl
) )
self._recv_thread.start() self._recv_thread.start()
if self.updates.enabled:
self.sync_updates()
return ok return ok
@ -205,19 +223,20 @@ class TelegramClient(TelegramBareClient):
*args will be ignored. *args will be ignored.
""" """
try: if self._recv_thread is not None and \
self._lock.acquire() threading.get_ident() == self._recv_thread.ident:
raise AssertionError('Cannot invoke requests from the ReadThread')
try:
# Users may call this method from within some update handler. # Users may call this method from within some update handler.
# If this is the case, then the thread invoking the request # If this is the case, then the thread invoking the request
# will be the one which should be reading (but is invoking the # will be the one which should be reading (but is invoking the
# request) thus not being available to read it "in the background" # request) thus not being available to read it "in the background"
# and it's needed to call receive. # and it's needed to call receive.
call_receive = self._recv_thread is None or \
threading.get_ident() == self._recv_thread.ident
# TODO Retry if 'result' is None? # TODO Retry if 'result' is None?
return super().invoke(request, call_receive=call_receive) return super().invoke(
request, call_receive=self._recv_thread is None
)
except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e: except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e:
self._logger.debug('DC error when invoking request, ' self._logger.debug('DC error when invoking request, '
@ -227,9 +246,6 @@ class TelegramClient(TelegramBareClient):
self.reconnect(new_dc=e.new_dc) self.reconnect(new_dc=e.new_dc)
return self.invoke(request) return self.invoke(request)
finally:
self._lock.release()
# Let people use client(SomeRequest()) instead client.invoke(...) # Let people use client(SomeRequest()) instead client.invoke(...)
__call__ = invoke __call__ = invoke
@ -404,8 +420,6 @@ class TelegramClient(TelegramBareClient):
no_webpage=not link_preview no_webpage=not link_preview
) )
result = self(request) result = self(request)
for callback in self._update_callbacks:
callback(result)
return request.random_id return request.random_id
def get_message_history(self, def get_message_history(self,
@ -893,16 +907,26 @@ class TelegramClient(TelegramBareClient):
# region Updates handling # region Updates handling
def sync_updates(self):
"""Synchronizes self.updates to their initial state. Will be
called automatically on connection if self.updates.enabled = True,
otherwise it should be called manually after enabling updates.
"""
self.updates.process(self(GetStateRequest()))
def add_update_handler(self, handler): def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject, """Adds an update handler (a function which takes a TLObject,
an update, as its parameter) and listens for updates""" an update, as its parameter) and listens for updates"""
self._update_callbacks.append(handler) sync = not self.updates.handlers
self.updates.handlers.append(handler)
if sync:
self.sync_updates()
def remove_update_handler(self, handler): def remove_update_handler(self, handler):
self._update_callbacks.remove(handler) self.updates.handlers.remove(handler)
def list_update_handlers(self): def list_update_handlers(self):
return self._update_callbacks[:] return self.updates.handlers[:]
# endregion # endregion
@ -918,7 +942,13 @@ class TelegramClient(TelegramBareClient):
def _recv_thread_impl(self): def _recv_thread_impl(self):
while self._sender and self._sender.is_connected(): while self._sender and self._sender.is_connected():
try: try:
self._sender.receive() if datetime.now() > self._last_ping + self._ping_delay:
self._sender.send(PingRequest(
int.from_bytes(os.urandom(8), 'big', signed=True)
))
self._last_ping = datetime.now()
self._sender.receive(update_state=self.updates)
except TimeoutError: except TimeoutError:
# No problem. # No problem.
pass pass

67
telethon/update_state.py Normal file
View File

@ -0,0 +1,67 @@
from collections import deque
from datetime import datetime
from threading import RLock, Event
from .tl import types as tl
class UpdateState:
"""Used to hold the current state of processed updates.
To retrieve an update, .poll() should be called.
"""
def __init__(self, polling):
self._polling = polling
self.handlers = []
self._updates_lock = RLock()
self._updates_available = Event()
self._updates = deque()
# https://core.telegram.org/api/updates
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
def can_poll(self):
"""Returns True if a call to .poll() won't lock"""
return self._updates_available.is_set()
def poll(self):
"""Polls an update or blocks until an update object is available"""
if not self._polling:
raise ValueError('Updates are not being polled hence not saved.')
self._updates_available.wait()
with self._updates_lock:
update = self._updates.popleft()
if not self._updates:
self._updates_available.clear()
return update
def get_polling(self):
return self._polling
def set_polling(self, polling):
self._polling = polling
if not polling:
with self._updates_lock:
self._updates.clear()
polling = property(fget=get_polling, fset=set_polling)
def process(self, update):
"""Processes an update object. This method is normally called by
the library itself.
"""
if not self._polling or not self.handlers:
return
with self._updates_lock:
if isinstance(update, tl.updates.State):
self._state = update
elif not hasattr(update, 'pts') or update.pts > self._state.pts:
self._state.pts = getattr(update, 'pts', self._state.pts)
for handler in self.handlers:
handler(update)
if self._polling:
self._updates.append(update)
self._updates_available.set()

View File

@ -51,7 +51,8 @@ class InteractiveTelegramClient(TelegramClient):
print('Initializing interactive example...') print('Initializing interactive example...')
super().__init__( super().__init__(
session_user_id, api_id, api_hash, session_user_id, api_id, api_hash,
connection_mode=ConnectionMode.TCP_ABRIDGED, proxy=proxy connection_mode=ConnectionMode.TCP_ABRIDGED,
proxy=proxy
) )
# Store all the found media in memory here, # Store all the found media in memory here,