Merge pull request #370 from andr-04/asyncio

Made update system for asyncio functional
This commit is contained in:
Lonami 2017-10-28 11:07:41 +02:00 committed by GitHub
commit 6dc0ee9d6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 277 additions and 256 deletions

View File

@ -5,13 +5,18 @@ import socket
from datetime import timedelta from datetime import timedelta
from io import BytesIO, BufferedWriter from io import BytesIO, BufferedWriter
loop = asyncio.get_event_loop() MAX_TIMEOUT = 15 # in seconds
CONN_RESET_ERRNOS = {
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
errno.EINVAL, errno.ENOTCONN
}
class TcpClient: class TcpClient:
def __init__(self, proxy=None, timeout=timedelta(seconds=5)): def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
self.proxy = proxy self.proxy = proxy
self._socket = None self._socket = None
self._loop = loop if loop else asyncio.get_event_loop()
if isinstance(timeout, timedelta): if isinstance(timeout, timedelta):
self.timeout = timeout.seconds self.timeout = timeout.seconds
@ -31,7 +36,7 @@ class TcpClient:
else: # tuple, list, etc. else: # tuple, list, etc.
self._socket.set_proxy(*self.proxy) self._socket.set_proxy(*self.proxy)
self._socket.settimeout(self.timeout) self._socket.setblocking(False)
async def connect(self, ip, port): async def connect(self, ip, port):
"""Connects to the specified IP and port number. """Connects to the specified IP and port number.
@ -42,20 +47,27 @@ class TcpClient:
else: else:
mode, address = socket.AF_INET, (ip, port) mode, address = socket.AF_INET, (ip, port)
timeout = 1
while True: while True:
try: try:
while not self._socket: if not self._socket:
self._recreate_socket(mode) self._recreate_socket(mode)
await loop.sock_connect(self._socket, address) await self._loop.sock_connect(self._socket, address)
break # Successful connection, stop retrying to connect break # Successful connection, stop retrying to connect
except ConnectionError:
self._socket = None
await asyncio.sleep(timeout)
timeout = min(timeout * 2, MAX_TIMEOUT)
except OSError as e: except OSError as e:
# There are some errors that we know how to handle, and # There are some errors that we know how to handle, and
# the loop will allow us to retry # the loop will allow us to retry
if e.errno == errno.EBADF: if e.errno in (errno.EBADF, errno.ENOTSOCK, errno.EINVAL):
# Bad file descriptor, i.e. socket was closed, set it # Bad file descriptor, i.e. socket was closed, set it
# to none to recreate it on the next iteration # to none to recreate it on the next iteration
self._socket = None self._socket = None
await asyncio.sleep(timeout)
timeout = min(timeout * 2, MAX_TIMEOUT)
else: else:
raise raise
@ -81,13 +93,14 @@ class TcpClient:
raise ConnectionResetError() raise ConnectionResetError()
try: try:
await loop.sock_sendall(self._socket, data) await asyncio.wait_for(self._loop.sock_sendall(self._socket, data),
except socket.timeout as e: timeout=self.timeout, loop=self._loop)
except asyncio.TimeoutError as e:
raise TimeoutError() from e raise TimeoutError() from e
except BrokenPipeError: except BrokenPipeError:
self._raise_connection_reset() self._raise_connection_reset()
except OSError as e: except OSError as e:
if e.errno == errno.EBADF: if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset() self._raise_connection_reset()
else: else:
raise raise
@ -104,11 +117,12 @@ class TcpClient:
bytes_left = size bytes_left = size
while bytes_left != 0: while bytes_left != 0:
try: try:
partial = await loop.sock_recv(self._socket, bytes_left) partial = await asyncio.wait_for(self._loop.sock_recv(self._socket, bytes_left),
except socket.timeout as e: timeout=self.timeout, loop=self._loop)
except asyncio.TimeoutError as e:
raise TimeoutError() from e raise TimeoutError() from e
except OSError as e: except OSError as e:
if e.errno == errno.EBADF or e.errno == errno.ENOTSOCK: if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset() self._raise_connection_reset()
else: else:
raise raise

View File

@ -43,13 +43,13 @@ class Connection:
""" """
def __init__(self, mode=ConnectionMode.TCP_FULL, def __init__(self, mode=ConnectionMode.TCP_FULL,
proxy=None, timeout=timedelta(seconds=5)): proxy=None, timeout=timedelta(seconds=5), loop=None):
self._mode = mode self._mode = mode
self._send_counter = 0 self._send_counter = 0
self._aes_encrypt, self._aes_decrypt = None, None self._aes_encrypt, self._aes_decrypt = None, None
# TODO Rename "TcpClient" as some sort of generic socket? # TODO Rename "TcpClient" as some sort of generic socket?
self.conn = TcpClient(proxy=proxy, timeout=timeout) self.conn = TcpClient(proxy=proxy, timeout=timeout, loop=loop)
# Sending messages # Sending messages
if mode == ConnectionMode.TCP_FULL: if mode == ConnectionMode.TCP_FULL:
@ -206,7 +206,7 @@ class Connection:
return await self.conn.read(length) return await self.conn.read(length)
async def _read_obfuscated(self, length): async def _read_obfuscated(self, length):
return await self._aes_decrypt.encrypt(self.conn.read(length)) return self._aes_decrypt.encrypt(await self.conn.read(length))
# endregion # endregion

View File

@ -1,6 +1,8 @@
import gzip import gzip
import logging import logging
import struct import struct
import asyncio
from asyncio import Event
from .. import helpers as utils from .. import helpers as utils
from ..crypto import AES from ..crypto import AES
@ -30,17 +32,15 @@ class MtProtoSender:
in parallel, so thread-safety (hence locking) isn't needed. in parallel, so thread-safety (hence locking) isn't needed.
""" """
def __init__(self, session, connection): def __init__(self, session, connection, loop=None):
"""Creates a new MtProtoSender configured to send messages through """Creates a new MtProtoSender configured to send messages through
'connection' and using the parameters from 'session'. 'connection' and using the parameters from 'session'.
""" """
self.session = session self.session = session
self.connection = connection self.connection = connection
self._loop = loop if loop else asyncio.get_event_loop()
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
# Message IDs that need confirmation
self._need_confirmation = set()
# Requests (as msg_id: Message) sent waiting to be received # Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {} self._pending_receive = {}
@ -54,12 +54,11 @@ class MtProtoSender:
def disconnect(self): def disconnect(self):
"""Disconnects from the server""" """Disconnects from the server"""
self.connection.close() self.connection.close()
self._need_confirmation.clear()
self._clear_all_pending() self._clear_all_pending()
def clone(self): def clone(self):
"""Creates a copy of this MtProtoSender as a new connection""" """Creates a copy of this MtProtoSender as a new connection"""
return MtProtoSender(self.session, self.connection.clone()) return MtProtoSender(self.session, self.connection.clone(), self._loop)
# region Send and receive # region Send and receive
@ -67,21 +66,23 @@ class MtProtoSender:
"""Sends the specified MTProtoRequest, previously sending any message """Sends the specified MTProtoRequest, previously sending any message
which needed confirmation.""" which needed confirmation."""
# Prepare the event of every request
for r in requests:
if r.confirm_received is None:
r.confirm_received = Event(loop=self._loop)
else:
r.confirm_received.clear()
# Finally send our packed request(s) # Finally send our packed request(s)
messages = [TLMessage(self.session, r) for r in requests] messages = [TLMessage(self.session, r) for r in requests]
self._pending_receive.update({m.msg_id: m for m in messages}) self._pending_receive.update({m.msg_id: m for m in messages})
# Pack everything in the same container if we need to send AckRequests
if self._need_confirmation:
messages.append(
TLMessage(self.session, MsgsAck(list(self._need_confirmation)))
)
self._need_confirmation.clear()
if len(messages) == 1: if len(messages) == 1:
message = messages[0] message = messages[0]
else: else:
message = TLMessage(self.session, MessageContainer(messages)) message = TLMessage(self.session, MessageContainer(messages))
for m in messages:
m.container_msg_id = message.msg_id
await self._send_message(message) await self._send_message(message)
@ -115,6 +116,7 @@ class MtProtoSender:
message, remote_msg_id, remote_seq = self._decode_msg(body) message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader: with BinaryReader(message) as reader:
await self._process_msg(remote_msg_id, remote_seq, reader, update_state) await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
await self._send_acknowledge(remote_msg_id)
# endregion # endregion
@ -174,7 +176,6 @@ class MtProtoSender:
""" """
# TODO Check salt, session_id and sequence_number # TODO Check salt, session_id and sequence_number
self._need_confirmation.add(msg_id)
code = reader.read_int(signed=False) code = reader.read_int(signed=False)
reader.seek(-4) reader.seek(-4)
@ -210,7 +211,7 @@ class MtProtoSender:
if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted
ack = reader.tgread_object() ack = reader.tgread_object()
assert isinstance(ack, MsgsAck) assert isinstance(ack, MsgsAck)
# Ignore every ack request *unless* when logging out, when it's # Ignore every ack request *unless* when logging out,
# when it seems to only make sense. We also need to set a non-None # when it seems to only make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these. # result since Telegram doesn't send the response for these.
for msg_id in ack.msg_ids: for msg_id in ack.msg_ids:
@ -259,11 +260,29 @@ class MtProtoSender:
if message and isinstance(message.request, t): if message and isinstance(message.request, t):
return self._pending_receive.pop(msg_id).request return self._pending_receive.pop(msg_id).request
def _pop_requests_of_container(self, container_msg_id):
msgs = [msg for msg in self._pending_receive.values() if msg.container_msg_id == container_msg_id]
requests = [msg.request for msg in msgs]
for msg in msgs:
self._pending_receive.pop(msg.msg_id, None)
return requests
def _clear_all_pending(self): def _clear_all_pending(self):
for r in self._pending_receive.values(): for r in self._pending_receive.values():
r.request.confirm_received.set() r.request.confirm_received.set()
self._pending_receive.clear() self._pending_receive.clear()
async def _resend_request(self, msg_id):
request = self._pop_request(msg_id)
if request:
self._logger.debug('requests is about to resend')
await self.send(request)
return
requests = self._pop_requests_of_container(msg_id)
if requests:
self._logger.debug('container of requests is about to resend')
await self.send(*requests)
async def _handle_pong(self, msg_id, sequence, reader): async def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong') self._logger.debug('Handling pong')
pong = reader.tgread_object() pong = reader.tgread_object()
@ -305,9 +324,7 @@ class MtProtoSender:
)[0] )[0]
self.session.save() self.session.save()
request = self._pop_request(bad_salt.bad_msg_id) await self._resend_request(bad_salt.bad_msg_id)
if request:
await self.send(request)
return True return True
@ -323,15 +340,18 @@ class MtProtoSender:
self.session.update_time_offset(correct_msg_id=msg_id) self.session.update_time_offset(correct_msg_id=msg_id)
self._logger.debug('Read Bad Message error: ' + str(error)) self._logger.debug('Read Bad Message error: ' + str(error))
self._logger.debug('Attempting to use the correct time offset.') self._logger.debug('Attempting to use the correct time offset.')
await self._resend_request(bad_msg.bad_msg_id)
return True return True
elif bad_msg.error_code == 32: elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount # msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID # TODO A better fix would be to start with a new fresh session ID
self.session._sequence += 64 self.session._sequence += 64
await self._resend_request(bad_msg.bad_msg_id)
return True return True
elif bad_msg.error_code == 33: elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case # msg_seqno too high never seems to happen but just in case
self.session._sequence -= 16 self.session._sequence -= 16
await self._resend_request(bad_msg.bad_msg_id)
return True return True
else: else:
raise error raise error
@ -342,7 +362,6 @@ class MtProtoSender:
# TODO For now, simply ack msg_new.answer_msg_id # TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/VvpCC6 # Relevant tdesktop source code: https://goo.gl/VvpCC6
await self._send_acknowledge(msg_new.answer_msg_id)
return True return True
async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader): async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader):
@ -351,7 +370,6 @@ class MtProtoSender:
# TODO For now, simply ack msg_new.answer_msg_id # TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/G7DPsR # Relevant tdesktop source code: https://goo.gl/G7DPsR
await self._send_acknowledge(msg_new.answer_msg_id)
return True return True
async def _handle_new_session_created(self, msg_id, sequence, reader): async def _handle_new_session_created(self, msg_id, sequence, reader):
@ -379,9 +397,6 @@ class MtProtoSender:
reader.read_int(), reader.tgread_string() reader.read_int(), reader.tgread_string()
) )
# Acknowledge that we received the error
await self._send_acknowledge(request_id)
if request: if request:
request.rpc_error = error request.rpc_error = error
request.confirm_received.set() request.confirm_received.set()
@ -412,11 +427,6 @@ class MtProtoSender:
async def _handle_gzip_packed(self, msg_id, sequence, reader, state): async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
self._logger.debug('Handling gzip packed data') self._logger.debug('Handling gzip packed data')
with BinaryReader(GzipPacked.read(reader)) as compressed_reader: with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
# We are reentering process_msg, which seemingly the same msg_id
# to the self._need_confirmation set. Remove it from there first
# to avoid any future conflicts (i.e. if we "ignore" messages
# that we are already aware of, see 1a91c02 and old 63dfb1e)
self._need_confirmation -= {msg_id}
return await self._process_msg(msg_id, sequence, compressed_reader, state) return await self._process_msg(msg_id, sequence, compressed_reader, state)
# endregion # endregion

View File

@ -1,10 +1,10 @@
import logging import logging
import os import os
import warnings import asyncio
from datetime import timedelta, datetime from datetime import timedelta, datetime
from hashlib import md5 from hashlib import md5
from io import BytesIO from io import BytesIO
from time import sleep from asyncio import Lock
from . import helpers as utils from . import helpers as utils
from .crypto import rsa, CdnDecrypter from .crypto import rsa, CdnDecrypter
@ -17,7 +17,7 @@ from .network import authenticator, MtProtoSender, Connection, ConnectionMode
from .tl import TLObject, Session from .tl import TLObject, Session
from .tl.all_tlobjects import LAYER from .tl.all_tlobjects import LAYER
from .tl.functions import ( from .tl.functions import (
InitConnectionRequest, InvokeWithLayerRequest InitConnectionRequest, InvokeWithLayerRequest, PingRequest
) )
from .tl.functions.auth import ( from .tl.functions.auth import (
ImportAuthorizationRequest, ExportAuthorizationRequest ImportAuthorizationRequest, ExportAuthorizationRequest
@ -67,6 +67,7 @@ class TelegramBareClient:
connection_mode=ConnectionMode.TCP_FULL, connection_mode=ConnectionMode.TCP_FULL,
proxy=None, proxy=None,
timeout=timedelta(seconds=5), timeout=timedelta(seconds=5),
loop=None,
**kwargs): **kwargs):
"""Refer to TelegramClient.__init__ for docs on this method""" """Refer to TelegramClient.__init__ for docs on this method"""
if not api_id or not api_hash: if not api_id or not api_hash:
@ -82,6 +83,8 @@ class TelegramBareClient:
'The given session must be a str or a Session instance.' 'The given session must be a str or a Session instance.'
) )
self._loop = loop if loop else asyncio.get_event_loop()
self.session = session self.session = session
self.api_id = int(api_id) self.api_id = int(api_id)
self.api_hash = api_hash self.api_hash = api_hash
@ -92,12 +95,18 @@ class TelegramBareClient:
# that calls .connect(). Every other thread will spawn a new # that calls .connect(). Every other thread will spawn a new
# temporary connection. The connection on this one is always # temporary connection. The connection on this one is always
# kept open so Telegram can send us updates. # kept open so Telegram can send us updates.
self._sender = MtProtoSender(self.session, Connection( self._sender = MtProtoSender(
mode=connection_mode, proxy=proxy, timeout=timeout self.session,
)) Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop),
self._loop
)
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
# Two coroutines may be calling reconnect() when the connection is lost,
# we only want one to actually perform the reconnection.
self._reconnect_lock = Lock(loop=self._loop)
# Cache "exported" sessions as 'dc_id: Session' not to recreate # Cache "exported" sessions as 'dc_id: Session' not to recreate
# them all the time since generating a new key is a relatively # them all the time since generating a new key is a relatively
# expensive operation. # expensive operation.
@ -105,7 +114,7 @@ class TelegramBareClient:
# This member will process updates if enabled. # This member will process updates if enabled.
# One may change self.updates.enabled at any later point. # One may change self.updates.enabled at any later point.
self.updates = UpdateState(workers=None) self.updates = UpdateState(self._loop)
# Used on connection - the user may modify these and reconnect # Used on connection - the user may modify these and reconnect
kwargs['app_version'] = kwargs.get('app_version', self.__version__) kwargs['app_version'] = kwargs.get('app_version', self.__version__)
@ -129,10 +138,11 @@ class TelegramBareClient:
# Uploaded files cache so subsequent calls are instant # Uploaded files cache so subsequent calls are instant
self._upload_cache = {} self._upload_cache = {}
# Default PingRequest delay self._recv_loop = None
self._last_ping = datetime.now() self._ping_loop = None
self._ping_delay = timedelta(minutes=1)
# Default PingRequest delay
self._ping_delay = timedelta(minutes=1)
# endregion # endregion
@ -167,6 +177,7 @@ class TelegramBareClient:
self.session.auth_key, self.session.time_offset = \ self.session.auth_key, self.session.time_offset = \
await authenticator.do_authentication(self._sender.connection) await authenticator.do_authentication(self._sender.connection)
except BrokenAuthKeyError: except BrokenAuthKeyError:
self._user_connected = False
return False return False
self.session.layer = LAYER self.session.layer = LAYER
@ -213,7 +224,7 @@ class TelegramBareClient:
# This is fine, probably layer migration # This is fine, probably layer migration
self._logger.debug('Found invalid item, probably migrating', e) self._logger.debug('Found invalid item, probably migrating', e)
self.disconnect() self.disconnect()
return self.connect( return await self.connect(
_exported_auth=_exported_auth, _exported_auth=_exported_auth,
_sync_updates=_sync_updates, _sync_updates=_sync_updates,
_cdn=_cdn _cdn=_cdn
@ -263,7 +274,17 @@ class TelegramBareClient:
""" """
if new_dc is None: if new_dc is None:
# Assume we are disconnected due to some error, so connect again # Assume we are disconnected due to some error, so connect again
try:
await self._reconnect_lock.acquire()
# Another thread may have connected again, so check that first
if self.is_connected():
return True
return await self.connect() return await self.connect()
except ConnectionResetError:
return False
finally:
self._reconnect_lock.release()
else: else:
self.disconnect() self.disconnect()
self.session.auth_key = None # Force creating new auth_key self.session.auth_key = None # Force creating new auth_key
@ -339,7 +360,8 @@ class TelegramBareClient:
client = TelegramBareClient( client = TelegramBareClient(
session, self.api_id, self.api_hash, session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy, proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout() timeout=self._sender.connection.get_timeout(),
loop=self._loop
) )
await client.connect(_exported_auth=export_auth, _sync_updates=False) await client.connect(_exported_auth=export_auth, _sync_updates=False)
client._authorized = True # We exported the auth, so we got auth client._authorized = True # We exported the auth, so we got auth
@ -358,7 +380,8 @@ class TelegramBareClient:
client = TelegramBareClient( client = TelegramBareClient(
session, self.api_id, self.api_hash, session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy, proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout() timeout=self._sender.connection.get_timeout(),
loop=self._loop
) )
# This will make use of the new RSA keys for this specific CDN. # This will make use of the new RSA keys for this specific CDN.
@ -383,53 +406,51 @@ class TelegramBareClient:
x.content_related for x in requests): x.content_related for x in requests):
raise ValueError('You can only invoke requests, not types!') raise ValueError('You can only invoke requests, not types!')
# TODO Determine the sender to be used (main or a new connection) # We should call receive from this thread if there's no background
sender = self._sender # .clone(), .connect() # thread reading or if the server disconnected us and we're trying
# We're on the same connection so no need to pass update_state=None # to reconnect. This is because the read thread may either be
# to avoid getting messages that we haven't acknowledged yet. # locked also trying to reconnect or we may be said thread already.
call_receive = self._recv_loop is None
try: for retry in range(retries):
for _ in range(retries): result = await self._invoke(call_receive, retry, *requests)
result = await self._invoke(sender, *requests)
if result is not None: if result is not None:
return result return result
raise ValueError('Number of retries reached 0.') raise ValueError('Number of retries reached 0.')
finally:
if sender != self._sender:
sender.disconnect() # Close temporary connections
# Let people use client.invoke(SomeRequest()) instead client(...) # Let people use client.invoke(SomeRequest()) instead client(...)
invoke = __call__ invoke = __call__
async def _invoke(self, sender, *requests): async def _invoke(self, call_receive, retry, *requests):
try: try:
# Ensure that we start with no previous errors (i.e. resending) # Ensure that we start with no previous errors (i.e. resending)
for x in requests: for x in requests:
x.confirm_received.clear()
x.rpc_error = None x.rpc_error = None
await sender.send(*requests) await self._sender.send(*requests)
while not all(x.confirm_received.is_set() for x in requests):
await sender.receive(update_state=self.updates)
except TimeoutError: if not call_receive:
pass # We will just retry await asyncio.wait(
list(map(lambda x: x.confirm_received.wait(), requests)),
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
else:
while not all(x.confirm_received.is_set() for x in requests):
await self._sender.receive(update_state=self.updates)
except ConnectionResetError: except ConnectionResetError:
if not self._user_connected: if not self._user_connected or self._reconnect_lock.locked():
# Only attempt reconnecting if we're authorized # Only attempt reconnecting if the user called connect and not
# reconnecting already.
raise raise
self._logger.debug('Server disconnected us. Reconnecting and ' self._logger.debug('Server disconnected us. Reconnecting and '
'resending request...') 'resending request... (%d)' % retry)
await self._reconnect()
if sender != self._sender: if not self._sender.is_connected():
# TODO Try reconnecting forever too? await asyncio.sleep(retry + 1, loop=self._loop)
await sender.connect()
else:
while self._user_connected and not await self._reconnect():
sleep(0.1) # Retry forever until we can send the request
return None return None
try: try:
@ -453,7 +474,7 @@ class TelegramBareClient:
) )
await self._reconnect(new_dc=e.new_dc) await self._reconnect(new_dc=e.new_dc)
return await self._invoke(sender, *requests) return None
except ServerError as e: except ServerError as e:
# Telegram is having some issues, just retry # Telegram is having some issues, just retry
@ -468,7 +489,8 @@ class TelegramBareClient:
self._logger.debug( self._logger.debug(
'Sleep of %d seconds below threshold, sleeping' % e.seconds 'Sleep of %d seconds below threshold, sleeping' % e.seconds
) )
sleep(e.seconds) await asyncio.sleep(e.seconds, loop=self._loop)
return None
# Some really basic functionality # Some really basic functionality
@ -671,16 +693,9 @@ class TelegramBareClient:
""" """
self.updates.process(await self(GetStateRequest())) self.updates.process(await self(GetStateRequest()))
def add_update_handler(self, handler): async def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject, """Adds an update handler (a function which takes a TLObject,
an update, as its parameter) and listens for updates""" an update, as its parameter) and listens for updates"""
if self.updates.workers is None:
warnings.warn(
"You have not setup any workers, so you won't receive updates."
" Pass update_workers=4 when creating the TelegramClient,"
" or set client.self.updates.workers = 4"
)
self.updates.handlers.append(handler) self.updates.handlers.append(handler)
def remove_update_handler(self, handler): def remove_update_handler(self, handler):
@ -695,6 +710,60 @@ class TelegramBareClient:
def _set_connected_and_authorized(self): def _set_connected_and_authorized(self):
self._authorized = True self._authorized = True
# TODO self.updates.setup_workers() if self._recv_loop is None:
self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop)
if self._ping_loop is None:
self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop)
async def _ping_loop_impl(self):
while self._user_connected:
await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
await asyncio.sleep(self._ping_delay.seconds, loop=self._loop)
self._ping_loop = None
async def _recv_loop_impl(self):
need_reconnect = False
while self._user_connected:
try:
if need_reconnect:
need_reconnect = False
while self._user_connected and not await self._reconnect():
await asyncio.sleep(0.1, loop=self._loop) # Retry forever, this is instant messaging
await self._sender.receive(update_state=self.updates)
except TimeoutError:
# No problem.
pass
except ConnectionError as error:
self._logger.debug(error)
need_reconnect = True
await asyncio.sleep(1, loop=self._loop)
except Exception as error:
# Unknown exception, pass it to the main thread
self._logger.debug(
'[ERROR] Unknown error on the read loop, please report',
error
)
try:
import socks
if isinstance(error, (
socks.GeneralProxyError, socks.ProxyConnectionError
)):
# This is a known error, and it's not related to
# Telegram but rather to the proxy. Disconnect and
# hand it over to the main thread.
self._background_error = error
self.disconnect()
break
except ImportError:
"Not using PySocks, so it can't be a socket error"
# If something strange happens we don't want to enter an
# infinite loop where all we do is raise an exception, so
# add a little sleep to avoid the CPU usage going mad.
await asyncio.sleep(0.1, loop=self._loop)
break
self._recv_loop = None
# endregion # endregion

View File

@ -61,6 +61,7 @@ class TelegramClient(TelegramBareClient):
connection_mode=ConnectionMode.TCP_FULL, connection_mode=ConnectionMode.TCP_FULL,
proxy=None, proxy=None,
timeout=timedelta(seconds=5), timeout=timedelta(seconds=5),
loop=None,
**kwargs): **kwargs):
"""Initializes the Telegram client with the specified API ID and Hash. """Initializes the Telegram client with the specified API ID and Hash.
@ -87,6 +88,7 @@ class TelegramClient(TelegramBareClient):
connection_mode=connection_mode, connection_mode=connection_mode,
proxy=proxy, proxy=proxy,
timeout=timeout, timeout=timeout,
loop=loop,
**kwargs **kwargs
) )
@ -202,7 +204,7 @@ class TelegramClient(TelegramBareClient):
"""Gets "me" (the self user) which is currently authenticated, """Gets "me" (the self user) which is currently authenticated,
or None if the request fails (hence, not authenticated).""" or None if the request fails (hence, not authenticated)."""
try: try:
return await self(GetUsersRequest([InputUserSelf()]))[0] return (await self(GetUsersRequest([InputUserSelf()])))[0]
except UnauthorizedError: except UnauthorizedError:
return None return None
@ -313,6 +315,7 @@ class TelegramClient(TelegramBareClient):
reply_to_msg_id=self._get_reply_to(reply_to) reply_to_msg_id=self._get_reply_to(reply_to)
) )
result = await self(request) result = await self(request)
if isinstance(result, UpdateShortSentMessage): if isinstance(result, UpdateShortSentMessage):
return Message( return Message(
id=result.id, id=result.id,

View File

@ -11,6 +11,15 @@ class MessageContainer(TLObject):
self.content_related = False self.content_related = False
self.messages = messages self.messages = messages
def to_dict(self, recursive=True):
return {
'content_related': self.content_related,
'messages':
([] if self.messages is None else [
None if x is None else x.to_dict() for x in self.messages
]) if recursive else self.messages,
}
def __bytes__(self): def __bytes__(self):
return struct.pack( return struct.pack(
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages) '<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
@ -25,3 +34,9 @@ class MessageContainer(TLObject):
inner_sequence = reader.read_int() inner_sequence = reader.read_int()
inner_length = reader.read_int() inner_length = reader.read_int()
yield inner_msg_id, inner_sequence, inner_length yield inner_msg_id, inner_sequence, inner_length
def __str__(self):
return TLObject.pretty_format(self)
def stringify(self):
return TLObject.pretty_format(self, indent=0)

View File

@ -1,4 +1,5 @@
import struct import struct
import logging
from . import TLObject, GzipPacked from . import TLObject, GzipPacked
@ -11,7 +12,23 @@ class TLMessage(TLObject):
self.msg_id = session.get_new_msg_id() self.msg_id = session.get_new_msg_id()
self.seq_no = session.generate_sequence(request.content_related) self.seq_no = session.generate_sequence(request.content_related)
self.request = request self.request = request
self.container_msg_id = None
logging.getLogger(__name__).debug(self)
def to_dict(self, recursive=True):
return {
'msg_id': self.msg_id,
'seq_no': self.seq_no,
'request': self.request,
'container_msg_id': self.container_msg_id,
}
def __bytes__(self): def __bytes__(self):
body = GzipPacked.gzip_if_smaller(self.request) body = GzipPacked.gzip_if_smaller(self.request)
return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body
def __str__(self):
return TLObject.pretty_format(self)
def stringify(self):
return TLObject.pretty_format(self, indent=0)

View File

@ -1,12 +1,10 @@
from datetime import datetime from datetime import datetime
from threading import Event
class TLObject: class TLObject:
def __init__(self): def __init__(self):
self.request_msg_id = 0 # Long self.request_msg_id = 0 # Long
self.confirm_received = None
self.confirm_received = Event()
self.rpc_error = None self.rpc_error = None
# These should be overrode # These should be overrode

View File

@ -1,8 +1,8 @@
import logging import logging
import pickle import pickle
import asyncio
from collections import deque from collections import deque
from datetime import datetime from datetime import datetime
from threading import RLock, Event, Thread
from .tl import types as tl from .tl import types as tl
@ -13,126 +13,24 @@ class UpdateState:
""" """
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
def __init__(self, workers=None): def __init__(self, loop=None):
"""
:param workers: This integer parameter has three possible cases:
workers is None: Updates will *not* be stored on self.
workers = 0: Another thread is responsible for calling self.poll()
workers > 0: 'workers' background threads will be spawned, any
any of them will invoke all the self.handlers.
"""
self._workers = workers
self._worker_threads = []
self.handlers = [] self.handlers = []
self._updates_lock = RLock()
self._updates_available = Event()
self._updates = deque()
self._latest_updates = deque(maxlen=10) self._latest_updates = deque(maxlen=10)
self._loop = loop if loop else asyncio.get_event_loop()
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
# https://core.telegram.org/api/updates # https://core.telegram.org/api/updates
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0) self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
def can_poll(self): def handle_update(self, update):
"""Returns True if a call to .poll() won't lock"""
return self._updates_available.is_set()
def poll(self, timeout=None):
"""Polls an update or blocks until an update object is available.
If 'timeout is not None', it should be a floating point value,
and the method will 'return None' if waiting times out.
"""
if not self._updates_available.wait(timeout=timeout):
return
with self._updates_lock:
if not self._updates_available.is_set():
return
update = self._updates.popleft()
if not self._updates:
self._updates_available.clear()
if isinstance(update, Exception):
raise update # Some error was set through (surely StopIteration)
return update
def get_workers(self):
return self._workers
def set_workers(self, n):
"""Changes the number of workers running.
If 'n is None', clears all pending updates from memory.
"""
self.stop_workers()
self._workers = n
if n is None:
self._updates.clear()
else:
self.setup_workers()
workers = property(fget=get_workers, fset=set_workers)
def stop_workers(self):
"""Raises "StopIterationException" on the worker threads to stop them,
and also clears all of them off the list
"""
if self._workers:
with self._updates_lock:
# Insert at the beginning so the very next poll causes an error
# on all the worker threads
# TODO Should this reset the pts and such?
for _ in range(self._workers):
self._updates.appendleft(StopIteration())
self._updates_available.set()
for t in self._worker_threads:
t.join()
self._worker_threads.clear()
def setup_workers(self):
if self._worker_threads or not self._workers:
# There already are workers, or workers is None or 0. Do nothing.
return
for i in range(self._workers):
thread = Thread(
target=UpdateState._worker_loop,
name='UpdateWorker{}'.format(i),
daemon=True,
args=(self, i)
)
self._worker_threads.append(thread)
thread.start()
def _worker_loop(self, wid):
while True:
try:
update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
# TODO Maybe people can add different handlers per update type
if update:
for handler in self.handlers: for handler in self.handlers:
handler(update) asyncio.ensure_future(handler(update), loop=self._loop)
except StopIteration:
break
except Exception as e:
# We don't want to crash a worker thread due to any reason
self._logger.debug(
'[ERROR] Unhandled exception on worker {}'.format(wid), e
)
def process(self, update): def process(self, update):
"""Processes an update object. This method is normally called by """Processes an update object. This method is normally called by
the library itself. the library itself.
""" """
if self._workers is None:
return # No processing needs to be done if nobody's working
with self._updates_lock:
if isinstance(update, tl.updates.State): if isinstance(update, tl.updates.State):
self._state = update self._state = update
return # Nothing else to be done return # Nothing else to be done
@ -168,21 +66,18 @@ class UpdateState:
# Since .users and .chats have already been processed, we # Since .users and .chats have already been processed, we
# don't need to care about those either. # don't need to care about those either.
if isinstance(update, tl.UpdateShort): if isinstance(update, tl.UpdateShort):
self._updates.append(update.update) self.handle_update(update.update)
self._updates_available.set()
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)): elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
self._updates.extend(update.updates) for upd in update.updates:
self._updates_available.set() self.handle_update(upd)
elif not isinstance(update, tl.UpdatesTooLong): elif not isinstance(update, tl.UpdatesTooLong):
# TODO Handle "Updates too long" # TODO Handle "Updates too long"
self._updates.append(update) self.handle_update(update)
self._updates_available.set()
elif type(update).SUBCLASS_OF_ID == 0x9f89304e: # crc32(b'Update') elif type(update).SUBCLASS_OF_ID == 0x9f89304e: # crc32(b'Update')
self._updates.append(update) self.handle_update(update)
self._updates_available.set()
else: else:
self._logger.debug('Ignoring "update" of type {}'.format( self._logger.debug('Ignoring "update" of type {}'.format(
type(update).__name__) type(update).__name__)