Update handlers works; it also seems stable

This commit is contained in:
Andrey Egorov 2017-10-22 15:06:36 +03:00
parent 917665852d
commit 780e0ceddf
9 changed files with 300 additions and 265 deletions

View File

@ -5,13 +5,12 @@ import socket
from datetime import timedelta from datetime import timedelta
from io import BytesIO, BufferedWriter from io import BytesIO, BufferedWriter
loop = asyncio.get_event_loop()
class TcpClient: class TcpClient:
def __init__(self, proxy=None, timeout=timedelta(seconds=5)): def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
self.proxy = proxy self.proxy = proxy
self._socket = None self._socket = None
self._loop = loop if loop else asyncio.get_event_loop()
if isinstance(timeout, timedelta): if isinstance(timeout, timedelta):
self.timeout = timeout.seconds self.timeout = timeout.seconds
@ -31,7 +30,7 @@ class TcpClient:
else: # tuple, list, etc. else: # tuple, list, etc.
self._socket.set_proxy(*self.proxy) self._socket.set_proxy(*self.proxy)
self._socket.settimeout(self.timeout) self._socket.setblocking(False)
async def connect(self, ip, port): async def connect(self, ip, port):
"""Connects to the specified IP and port number. """Connects to the specified IP and port number.
@ -42,20 +41,27 @@ class TcpClient:
else: else:
mode, address = socket.AF_INET, (ip, port) mode, address = socket.AF_INET, (ip, port)
timeout = 1
while True: while True:
try: try:
while not self._socket: if not self._socket:
self._recreate_socket(mode) self._recreate_socket(mode)
await loop.sock_connect(self._socket, address) await self._loop.sock_connect(self._socket, address)
break # Successful connection, stop retrying to connect break # Successful connection, stop retrying to connect
except ConnectionError:
self._socket = None
await asyncio.sleep(min(timeout, 15))
timeout *= 2
except OSError as e: except OSError as e:
# There are some errors that we know how to handle, and # There are some errors that we know how to handle, and
# the loop will allow us to retry # the loop will allow us to retry
if e.errno == errno.EBADF: if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.EINVAL]:
# Bad file descriptor, i.e. socket was closed, set it # Bad file descriptor, i.e. socket was closed, set it
# to none to recreate it on the next iteration # to none to recreate it on the next iteration
self._socket = None self._socket = None
await asyncio.sleep(min(timeout, 15))
timeout *= 2
else: else:
raise raise
@ -81,13 +87,14 @@ class TcpClient:
raise ConnectionResetError() raise ConnectionResetError()
try: try:
await loop.sock_sendall(self._socket, data) await asyncio.wait_for(self._loop.sock_sendall(self._socket, data),
except socket.timeout as e: timeout=self.timeout, loop=self._loop)
except asyncio.TimeoutError as e:
raise TimeoutError() from e raise TimeoutError() from e
except BrokenPipeError: except BrokenPipeError:
self._raise_connection_reset() self._raise_connection_reset()
except OSError as e: except OSError as e:
if e.errno == errno.EBADF: if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, errno.EINVAL, errno.ENOTCONN]:
self._raise_connection_reset() self._raise_connection_reset()
else: else:
raise raise
@ -104,11 +111,12 @@ class TcpClient:
bytes_left = size bytes_left = size
while bytes_left != 0: while bytes_left != 0:
try: try:
partial = await loop.sock_recv(self._socket, bytes_left) partial = await asyncio.wait_for(self._loop.sock_recv(self._socket, bytes_left),
except socket.timeout as e: timeout=self.timeout, loop=self._loop)
except asyncio.TimeoutError as e:
raise TimeoutError() from e raise TimeoutError() from e
except OSError as e: except OSError as e:
if e.errno == errno.EBADF or e.errno == errno.ENOTSOCK: if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, errno.EINVAL, errno.ENOTCONN]:
self._raise_connection_reset() self._raise_connection_reset()
else: else:
raise raise

View File

@ -43,13 +43,13 @@ class Connection:
""" """
def __init__(self, mode=ConnectionMode.TCP_FULL, def __init__(self, mode=ConnectionMode.TCP_FULL,
proxy=None, timeout=timedelta(seconds=5)): proxy=None, timeout=timedelta(seconds=5), loop=None):
self._mode = mode self._mode = mode
self._send_counter = 0 self._send_counter = 0
self._aes_encrypt, self._aes_decrypt = None, None self._aes_encrypt, self._aes_decrypt = None, None
# TODO Rename "TcpClient" as some sort of generic socket? # TODO Rename "TcpClient" as some sort of generic socket?
self.conn = TcpClient(proxy=proxy, timeout=timeout) self.conn = TcpClient(proxy=proxy, timeout=timeout, loop=loop)
# Sending messages # Sending messages
if mode == ConnectionMode.TCP_FULL: if mode == ConnectionMode.TCP_FULL:
@ -206,7 +206,7 @@ class Connection:
return await self.conn.read(length) return await self.conn.read(length)
async def _read_obfuscated(self, length): async def _read_obfuscated(self, length):
return await self._aes_decrypt.encrypt(self.conn.read(length)) return self._aes_decrypt.encrypt(await self.conn.read(length))
# endregion # endregion

View File

@ -1,6 +1,8 @@
import gzip import gzip
import logging import logging
import struct import struct
import asyncio
from asyncio import Event
from .. import helpers as utils from .. import helpers as utils
from ..crypto import AES from ..crypto import AES
@ -30,17 +32,15 @@ class MtProtoSender:
in parallel, so thread-safety (hence locking) isn't needed. in parallel, so thread-safety (hence locking) isn't needed.
""" """
def __init__(self, session, connection): def __init__(self, session, connection, loop=None):
"""Creates a new MtProtoSender configured to send messages through """Creates a new MtProtoSender configured to send messages through
'connection' and using the parameters from 'session'. 'connection' and using the parameters from 'session'.
""" """
self.session = session self.session = session
self.connection = connection self.connection = connection
self._loop = loop if loop else asyncio.get_event_loop()
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
# Message IDs that need confirmation
self._need_confirmation = []
# Requests (as msg_id: Message) sent waiting to be received # Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {} self._pending_receive = {}
@ -54,12 +54,11 @@ class MtProtoSender:
def disconnect(self): def disconnect(self):
"""Disconnects from the server""" """Disconnects from the server"""
self.connection.close() self.connection.close()
self._need_confirmation.clear()
self._clear_all_pending() self._clear_all_pending()
def clone(self): def clone(self):
"""Creates a copy of this MtProtoSender as a new connection""" """Creates a copy of this MtProtoSender as a new connection"""
return MtProtoSender(self.session, self.connection.clone()) return MtProtoSender(self.session, self.connection.clone(), self._loop)
# region Send and receive # region Send and receive
@ -67,21 +66,23 @@ class MtProtoSender:
"""Sends the specified MTProtoRequest, previously sending any message """Sends the specified MTProtoRequest, previously sending any message
which needed confirmation.""" which needed confirmation."""
# Prepare the event of every request
for r in requests:
if r.confirm_received is None:
r.confirm_received = Event(loop=self._loop)
else:
r.confirm_received.clear()
# Finally send our packed request(s) # Finally send our packed request(s)
messages = [TLMessage(self.session, r) for r in requests] messages = [TLMessage(self.session, r) for r in requests]
self._pending_receive.update({m.msg_id: m for m in messages}) self._pending_receive.update({m.msg_id: m for m in messages})
# Pack everything in the same container if we need to send AckRequests
if self._need_confirmation:
messages.append(
TLMessage(self.session, MsgsAck(self._need_confirmation))
)
self._need_confirmation.clear()
if len(messages) == 1: if len(messages) == 1:
message = messages[0] message = messages[0]
else: else:
message = TLMessage(self.session, MessageContainer(messages)) message = TLMessage(self.session, MessageContainer(messages))
for m in messages:
m.container_msg_id = message.msg_id
await self._send_message(message) await self._send_message(message)
@ -115,6 +116,7 @@ class MtProtoSender:
message, remote_msg_id, remote_seq = self._decode_msg(body) message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader: with BinaryReader(message) as reader:
await self._process_msg(remote_msg_id, remote_seq, reader, update_state) await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
await self._send_acknowledge(remote_msg_id)
# endregion # endregion
@ -174,7 +176,6 @@ class MtProtoSender:
""" """
# TODO Check salt, session_id and sequence_number # TODO Check salt, session_id and sequence_number
self._need_confirmation.append(msg_id)
code = reader.read_int(signed=False) code = reader.read_int(signed=False)
reader.seek(-4) reader.seek(-4)
@ -210,14 +211,14 @@ class MtProtoSender:
if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted
ack = reader.tgread_object() ack = reader.tgread_object()
assert isinstance(ack, MsgsAck) assert isinstance(ack, MsgsAck)
# Ignore every ack request *unless* when logging out, when it's # Ignore every ack request *unless* when logging out,
# when it seems to only make sense. We also need to set a non-None # when it seems to only make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these. # result since Telegram doesn't send the response for these.
for msg_id in ack.msg_ids: for msg_id in ack.msg_ids:
r = self._pop_request_of_type(msg_id, LogOutRequest) r = self._pop_request_of_type(msg_id, LogOutRequest)
if r: if r:
r.result = True # Telegram won't send this value r.result = True # Telegram won't send this value
r.confirm_received() r.confirm_received.set()
self._logger.debug('Message ack confirmed', r) self._logger.debug('Message ack confirmed', r)
return True return True
@ -259,11 +260,29 @@ class MtProtoSender:
if message and isinstance(message.request, t): if message and isinstance(message.request, t):
return self._pending_receive.pop(msg_id).request return self._pending_receive.pop(msg_id).request
def _pop_requests_of_container(self, container_msg_id):
msgs = [msg for msg in self._pending_receive.values() if msg.container_msg_id == container_msg_id]
requests = [msg.request for msg in msgs]
for msg in msgs:
self._pending_receive.pop(msg.msg_id, None)
return requests
def _clear_all_pending(self): def _clear_all_pending(self):
for r in self._pending_receive.values(): for r in self._pending_receive.values():
r.confirm_received.set() r.request.confirm_received.set()
self._pending_receive.clear() self._pending_receive.clear()
async def _resend_request(self, msg_id):
request = self._pop_request(msg_id)
if request:
self._logger.debug('requests is about to resend')
await self.send(request)
return
requests = self._pop_requests_of_container(msg_id)
if requests:
self._logger.debug('container of requests is about to resend')
await self.send(*requests)
async def _handle_pong(self, msg_id, sequence, reader): async def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong') self._logger.debug('Handling pong')
pong = reader.tgread_object() pong = reader.tgread_object()
@ -303,10 +322,9 @@ class MtProtoSender:
self.session.salt = struct.unpack( self.session.salt = struct.unpack(
'<Q', struct.pack('<q', bad_salt.new_server_salt) '<Q', struct.pack('<q', bad_salt.new_server_salt)
)[0] )[0]
self.session.save()
request = self._pop_request(bad_salt.bad_msg_id) await self._resend_request(bad_salt.bad_msg_id)
if request:
await self.send(request)
return True return True
@ -322,15 +340,18 @@ class MtProtoSender:
self.session.update_time_offset(correct_msg_id=msg_id) self.session.update_time_offset(correct_msg_id=msg_id)
self._logger.debug('Read Bad Message error: ' + str(error)) self._logger.debug('Read Bad Message error: ' + str(error))
self._logger.debug('Attempting to use the correct time offset.') self._logger.debug('Attempting to use the correct time offset.')
await self._resend_request(bad_msg.bad_msg_id)
return True return True
elif bad_msg.error_code == 32: elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount # msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID # TODO A better fix would be to start with a new fresh session ID
self.session._sequence += 64 self.session._sequence += 64
await self._resend_request(bad_msg.bad_msg_id)
return True return True
elif bad_msg.error_code == 33: elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case # msg_seqno too high never seems to happen but just in case
self.session._sequence -= 16 self.session._sequence -= 16
await self._resend_request(bad_msg.bad_msg_id)
return True return True
else: else:
raise error raise error
@ -341,7 +362,6 @@ class MtProtoSender:
# TODO For now, simply ack msg_new.answer_msg_id # TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/VvpCC6 # Relevant tdesktop source code: https://goo.gl/VvpCC6
await self._send_acknowledge(msg_new.answer_msg_id)
return True return True
async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader): async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader):
@ -350,7 +370,6 @@ class MtProtoSender:
# TODO For now, simply ack msg_new.answer_msg_id # TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/G7DPsR # Relevant tdesktop source code: https://goo.gl/G7DPsR
await self._send_acknowledge(msg_new.answer_msg_id)
return True return True
async def _handle_new_session_created(self, msg_id, sequence, reader): async def _handle_new_session_created(self, msg_id, sequence, reader):
@ -378,9 +397,6 @@ class MtProtoSender:
reader.read_int(), reader.tgread_string() reader.read_int(), reader.tgread_string()
) )
# Acknowledge that we received the error
await self._send_acknowledge(request_id)
if request: if request:
request.rpc_error = error request.rpc_error = error
request.confirm_received.set() request.confirm_received.set()

View File

@ -1,10 +1,10 @@
import logging import logging
import os import os
import warnings import asyncio
from datetime import timedelta, datetime from datetime import timedelta, datetime
from hashlib import md5 from hashlib import md5
from io import BytesIO from io import BytesIO
from time import sleep from asyncio import Lock
from . import helpers as utils from . import helpers as utils
from .crypto import rsa, CdnDecrypter from .crypto import rsa, CdnDecrypter
@ -17,7 +17,7 @@ from .network import authenticator, MtProtoSender, Connection, ConnectionMode
from .tl import TLObject, Session from .tl import TLObject, Session
from .tl.all_tlobjects import LAYER from .tl.all_tlobjects import LAYER
from .tl.functions import ( from .tl.functions import (
InitConnectionRequest, InvokeWithLayerRequest InitConnectionRequest, InvokeWithLayerRequest, PingRequest
) )
from .tl.functions.auth import ( from .tl.functions.auth import (
ImportAuthorizationRequest, ExportAuthorizationRequest ImportAuthorizationRequest, ExportAuthorizationRequest
@ -67,6 +67,7 @@ class TelegramBareClient:
connection_mode=ConnectionMode.TCP_FULL, connection_mode=ConnectionMode.TCP_FULL,
proxy=None, proxy=None,
timeout=timedelta(seconds=5), timeout=timedelta(seconds=5),
loop=None,
**kwargs): **kwargs):
"""Refer to TelegramClient.__init__ for docs on this method""" """Refer to TelegramClient.__init__ for docs on this method"""
if not api_id or not api_hash: if not api_id or not api_hash:
@ -82,6 +83,8 @@ class TelegramBareClient:
'The given session must be a str or a Session instance.' 'The given session must be a str or a Session instance.'
) )
self._loop = loop if loop else asyncio.get_event_loop()
self.session = session self.session = session
self.api_id = int(api_id) self.api_id = int(api_id)
self.api_hash = api_hash self.api_hash = api_hash
@ -92,12 +95,18 @@ class TelegramBareClient:
# that calls .connect(). Every other thread will spawn a new # that calls .connect(). Every other thread will spawn a new
# temporary connection. The connection on this one is always # temporary connection. The connection on this one is always
# kept open so Telegram can send us updates. # kept open so Telegram can send us updates.
self._sender = MtProtoSender(self.session, Connection( self._sender = MtProtoSender(
mode=connection_mode, proxy=proxy, timeout=timeout self.session,
)) Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop),
self._loop
)
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
# Two coroutines may be calling reconnect() when the connection is lost,
# we only want one to actually perform the reconnection.
self._reconnect_lock = Lock(loop=self._loop)
# Cache "exported" sessions as 'dc_id: Session' not to recreate # Cache "exported" sessions as 'dc_id: Session' not to recreate
# them all the time since generating a new key is a relatively # them all the time since generating a new key is a relatively
# expensive operation. # expensive operation.
@ -105,7 +114,7 @@ class TelegramBareClient:
# This member will process updates if enabled. # This member will process updates if enabled.
# One may change self.updates.enabled at any later point. # One may change self.updates.enabled at any later point.
self.updates = UpdateState(workers=None) self.updates = UpdateState(self._loop)
# Used on connection - the user may modify these and reconnect # Used on connection - the user may modify these and reconnect
kwargs['app_version'] = kwargs.get('app_version', self.__version__) kwargs['app_version'] = kwargs.get('app_version', self.__version__)
@ -129,10 +138,11 @@ class TelegramBareClient:
# Uploaded files cache so subsequent calls are instant # Uploaded files cache so subsequent calls are instant
self._upload_cache = {} self._upload_cache = {}
# Default PingRequest delay self._recv_loop = None
self._last_ping = datetime.now() self._ping_loop = None
self._ping_delay = timedelta(minutes=1)
# Default PingRequest delay
self._ping_delay = timedelta(minutes=1)
# endregion # endregion
@ -167,6 +177,7 @@ class TelegramBareClient:
self.session.auth_key, self.session.time_offset = \ self.session.auth_key, self.session.time_offset = \
await authenticator.do_authentication(self._sender.connection) await authenticator.do_authentication(self._sender.connection)
except BrokenAuthKeyError: except BrokenAuthKeyError:
self._user_connected = False
return False return False
self.session.layer = LAYER self.session.layer = LAYER
@ -198,12 +209,12 @@ class TelegramBareClient:
# another data center and this would raise UserMigrateError) # another data center and this would raise UserMigrateError)
# to also assert whether the user is logged in or not. # to also assert whether the user is logged in or not.
self._user_connected = True self._user_connected = True
if _sync_updates and not _cdn: if _sync_updates and not _cdn and not self._authorized:
try: try:
await self.sync_updates() await self.sync_updates()
self._set_connected_and_authorized() self._set_connected_and_authorized()
except UnauthorizedError: except UnauthorizedError:
self._authorized = False pass
return True return True
@ -211,7 +222,7 @@ class TelegramBareClient:
# This is fine, probably layer migration # This is fine, probably layer migration
self._logger.debug('Found invalid item, probably migrating', e) self._logger.debug('Found invalid item, probably migrating', e)
self.disconnect() self.disconnect()
return self.connect( return await self.connect(
_exported_auth=_exported_auth, _exported_auth=_exported_auth,
_sync_updates=_sync_updates, _sync_updates=_sync_updates,
_cdn=_cdn _cdn=_cdn
@ -261,7 +272,17 @@ class TelegramBareClient:
""" """
if new_dc is None: if new_dc is None:
# Assume we are disconnected due to some error, so connect again # Assume we are disconnected due to some error, so connect again
return await self.connect() try:
await self._reconnect_lock.acquire()
# Another thread may have connected again, so check that first
if self.is_connected():
return True
return await self.connect()
except ConnectionResetError:
return False
finally:
self._reconnect_lock.release()
else: else:
self.disconnect() self.disconnect()
self.session.auth_key = None # Force creating new auth_key self.session.auth_key = None # Force creating new auth_key
@ -337,7 +358,8 @@ class TelegramBareClient:
client = TelegramBareClient( client = TelegramBareClient(
session, self.api_id, self.api_hash, session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy, proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout() timeout=self._sender.connection.get_timeout(),
loop=self._loop
) )
await client.connect(_exported_auth=export_auth, _sync_updates=False) await client.connect(_exported_auth=export_auth, _sync_updates=False)
client._authorized = True # We exported the auth, so we got auth client._authorized = True # We exported the auth, so we got auth
@ -356,7 +378,8 @@ class TelegramBareClient:
client = TelegramBareClient( client = TelegramBareClient(
session, self.api_id, self.api_hash, session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy, proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout() timeout=self._sender.connection.get_timeout(),
loop=self._loop
) )
# This will make use of the new RSA keys for this specific CDN. # This will make use of the new RSA keys for this specific CDN.
@ -381,55 +404,52 @@ class TelegramBareClient:
x.content_related for x in requests): x.content_related for x in requests):
raise ValueError('You can only invoke requests, not types!') raise ValueError('You can only invoke requests, not types!')
# TODO Determine the sender to be used (main or a new connection) # We should call receive from this thread if there's no background
sender = self._sender # .clone(), .connect() # thread reading or if the server disconnected us and we're trying
# to reconnect. This is because the read thread may either be
# locked also trying to reconnect or we may be said thread already.
call_receive = self._recv_loop is None
try: for retry in range(retries):
for _ in range(retries): result = await self._invoke(call_receive, retry, *requests)
result = await self._invoke(sender, *requests) if result is not None:
if result is not None: return result
return result
raise ValueError('Number of retries reached 0.') return None
finally:
if sender != self._sender:
sender.disconnect() # Close temporary connections
# Let people use client.invoke(SomeRequest()) instead client(...) # Let people use client.invoke(SomeRequest()) instead client(...)
invoke = __call__ invoke = __call__
async def _invoke(self, sender, *requests): async def _invoke(self, call_receive, retry, *requests):
try: try:
# Ensure that we start with no previous errors (i.e. resending) # Ensure that we start with no previous errors (i.e. resending)
for x in requests: for x in requests:
x.confirm_received.clear()
x.rpc_error = None x.rpc_error = None
await sender.send(*requests) await self._sender.send(*requests)
while not all(x.confirm_received.is_set() for x in requests):
await sender.receive(update_state=self.updates)
except TimeoutError: if not call_receive:
pass # We will just retry await asyncio.wait(
list(map(lambda x: x.confirm_received.wait(), requests)),
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
else:
while not all(x.confirm_received.is_set() for x in requests):
await self._sender.receive(update_state=self.updates)
except ConnectionResetError: except ConnectionResetError:
if not self._user_connected: if not self._user_connected or self._reconnect_lock.locked():
# Only attempt reconnecting if we're authorized # Only attempt reconnecting if the user called connect and not
# reconnecting already.
raise raise
self._logger.debug('Server disconnected us. Reconnecting and ' self._logger.debug('Server disconnected us. Reconnecting and '
'resending request...') 'resending request... (%d)' % retry)
await self._reconnect()
if sender != self._sender: if not self._sender.is_connected():
# TODO Try reconnecting forever too? await asyncio.sleep(retry + 1, loop=self._loop)
await sender.connect() return None
else:
while self._user_connected and not await self._reconnect():
sleep(0.1) # Retry forever until we can send the request
finally:
if sender != self._sender:
sender.disconnect()
try: try:
raise next(x.rpc_error for x in requests if x.rpc_error) raise next(x.rpc_error for x in requests if x.rpc_error)
@ -452,7 +472,7 @@ class TelegramBareClient:
) )
await self._reconnect(new_dc=e.new_dc) await self._reconnect(new_dc=e.new_dc)
return await self._invoke(sender, *requests) return None
except ServerError as e: except ServerError as e:
# Telegram is having some issues, just retry # Telegram is having some issues, just retry
@ -467,7 +487,8 @@ class TelegramBareClient:
self._logger.debug( self._logger.debug(
'Sleep of %d seconds below threshold, sleeping' % e.seconds 'Sleep of %d seconds below threshold, sleeping' % e.seconds
) )
sleep(e.seconds) await asyncio.sleep(e.seconds, loop=self._loop)
return None
# Some really basic functionality # Some really basic functionality
@ -670,16 +691,13 @@ class TelegramBareClient:
""" """
self.updates.process(await self(GetStateRequest())) self.updates.process(await self(GetStateRequest()))
def add_update_handler(self, handler): async def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject, """Adds an update handler (a function which takes a TLObject,
an update, as its parameter) and listens for updates""" an update, as its parameter) and listens for updates"""
if not self.updates.get_workers:
warnings.warn("There are no update workers running, so adding an update handler will have no effect.")
sync = not self.updates.handlers sync = not self.updates.handlers
self.updates.handlers.append(handler) self.updates.handlers.append(handler)
if sync: if sync:
self.sync_updates() await self.sync_updates()
def remove_update_handler(self, handler): def remove_update_handler(self, handler):
self.updates.handlers.remove(handler) self.updates.handlers.remove(handler)
@ -693,6 +711,63 @@ class TelegramBareClient:
def _set_connected_and_authorized(self): def _set_connected_and_authorized(self):
self._authorized = True self._authorized = True
# TODO self.updates.setup_workers() if self._recv_loop is None:
self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop)
if self._ping_loop is None:
self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop)
async def _ping_loop_impl(self):
while self._user_connected:
await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
await asyncio.sleep(self._ping_delay.seconds, loop=self._loop)
self._ping_loop = None
async def _recv_loop_impl(self):
need_reconnect = False
timeout = 1
while self._user_connected:
try:
if need_reconnect:
need_reconnect = False
while self._user_connected and not await self._reconnect():
await asyncio.sleep(0.1, loop=self._loop) # Retry forever, this is instant messaging
await self._sender.receive(update_state=self.updates)
except TimeoutError:
# No problem.
pass
except ConnectionError as error:
self._logger.debug(error)
need_reconnect = True
await asyncio.sleep(min(timeout, 15), loop=self._loop)
timeout *= 2
except Exception as error:
# Unknown exception, pass it to the main thread
self._logger.debug(
'[ERROR] Unknown error on the read loop, please report',
error
)
try:
import socks
if isinstance(error, (
socks.GeneralProxyError, socks.ProxyConnectionError
)):
# This is a known error, and it's not related to
# Telegram but rather to the proxy. Disconnect and
# hand it over to the main thread.
self._background_error = error
self.disconnect()
break
except ImportError:
"Not using PySocks, so it can't be a socket error"
# If something strange happens we don't want to enter an
# infinite loop where all we do is raise an exception, so
# add a little sleep to avoid the CPU usage going mad.
await asyncio.sleep(0.1, loop=self._loop)
break
timeout = 1
self._recv_loop = None
# endregion # endregion

View File

@ -61,6 +61,7 @@ class TelegramClient(TelegramBareClient):
connection_mode=ConnectionMode.TCP_FULL, connection_mode=ConnectionMode.TCP_FULL,
proxy=None, proxy=None,
timeout=timedelta(seconds=5), timeout=timedelta(seconds=5),
loop=None,
**kwargs): **kwargs):
"""Initializes the Telegram client with the specified API ID and Hash. """Initializes the Telegram client with the specified API ID and Hash.
@ -87,6 +88,7 @@ class TelegramClient(TelegramBareClient):
connection_mode=connection_mode, connection_mode=connection_mode,
proxy=proxy, proxy=proxy,
timeout=timeout, timeout=timeout,
loop=loop,
**kwargs **kwargs
) )
@ -104,8 +106,9 @@ class TelegramClient(TelegramBareClient):
"""Sends a code request to the specified phone number""" """Sends a code request to the specified phone number"""
phone = EntityDatabase.parse_phone(phone) or self._phone phone = EntityDatabase.parse_phone(phone) or self._phone
result = await self(SendCodeRequest(phone, self.api_id, self.api_hash)) result = await self(SendCodeRequest(phone, self.api_id, self.api_hash))
self._phone = phone if result:
self._phone_code_hash = result.phone_code_hash self._phone = phone
self._phone_code_hash = result.phone_code_hash
return result return result
async def sign_in(self, phone=None, code=None, async def sign_in(self, phone=None, code=None,
@ -169,8 +172,10 @@ class TelegramClient(TelegramBareClient):
'and a password only if an RPCError was raised before.' 'and a password only if an RPCError was raised before.'
) )
self._set_connected_and_authorized() if result:
return result.user self._set_connected_and_authorized()
return result.user
return result
async def sign_up(self, code, first_name, last_name=''): async def sign_up(self, code, first_name, last_name=''):
"""Signs up to Telegram. Make sure you sent a code request first!""" """Signs up to Telegram. Make sure you sent a code request first!"""
@ -182,8 +187,10 @@ class TelegramClient(TelegramBareClient):
last_name=last_name last_name=last_name
)) ))
self._set_connected_and_authorized() if result:
return result.user self._set_connected_and_authorized()
return result.user
return result
async def log_out(self): async def log_out(self):
"""Logs out and deletes the current session. """Logs out and deletes the current session.
@ -239,7 +246,7 @@ class TelegramClient(TelegramBareClient):
offset_peer=offset_peer, offset_peer=offset_peer,
limit=need if need < float('inf') else 0 limit=need if need < float('inf') else 0
)) ))
if not r.dialogs: if not r or not r.dialogs:
break break
for d in r.dialogs: for d in r.dialogs:
@ -288,10 +295,12 @@ class TelegramClient(TelegramBareClient):
:return List[telethon.tl.custom.Draft]: A list of open drafts :return List[telethon.tl.custom.Draft]: A list of open drafts
""" """
response = await self(GetAllDraftsRequest()) response = await self(GetAllDraftsRequest())
self.session.process_entities(response) if response:
self.session.generate_sequence(response.seq) self.session.process_entities(response)
drafts = [Draft._from_update(self, u) for u in response.updates] self.session.generate_sequence(response.seq)
return drafts drafts = [Draft._from_update(self, u) for u in response.updates]
return drafts
return response
async def send_message(self, async def send_message(self,
entity, entity,
@ -313,6 +322,9 @@ class TelegramClient(TelegramBareClient):
reply_to_msg_id=self._get_reply_to(reply_to) reply_to_msg_id=self._get_reply_to(reply_to)
) )
result = await self(request) result = await self(request)
if not result:
return result
if isinstance(result, UpdateShortSentMessage): if isinstance(result, UpdateShortSentMessage):
return Message( return Message(
id=result.id, id=result.id,
@ -407,6 +419,8 @@ class TelegramClient(TelegramBareClient):
min_id=min_id, min_id=min_id,
add_offset=add_offset add_offset=add_offset
)) ))
if not result:
return result
# The result may be a messages slice (not all messages were retrieved) # The result may be a messages slice (not all messages were retrieved)
# or simply a messages TLObject. In the later case, no "count" # or simply a messages TLObject. In the later case, no "count"

View File

@ -11,6 +11,12 @@ class MessageContainer(TLObject):
self.content_related = False self.content_related = False
self.messages = messages self.messages = messages
def to_dict(self, recursive=True):
return {
'content_related': self.content_related,
'messages': self.messages,
}
def to_bytes(self): def to_bytes(self):
return struct.pack( return struct.pack(
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages) '<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
@ -25,3 +31,9 @@ class MessageContainer(TLObject):
inner_sequence = reader.read_int() inner_sequence = reader.read_int()
inner_length = reader.read_int() inner_length = reader.read_int()
yield inner_msg_id, inner_sequence, inner_length yield inner_msg_id, inner_sequence, inner_length
def __str__(self):
return TLObject.pretty_format(self)
def stringify(self):
return TLObject.pretty_format(self, indent=0)

View File

@ -1,4 +1,5 @@
import struct import struct
import logging
from . import TLObject, GzipPacked from . import TLObject, GzipPacked
@ -11,7 +12,23 @@ class TLMessage(TLObject):
self.msg_id = session.get_new_msg_id() self.msg_id = session.get_new_msg_id()
self.seq_no = session.generate_sequence(request.content_related) self.seq_no = session.generate_sequence(request.content_related)
self.request = request self.request = request
self.container_msg_id = None
logging.getLogger(__name__).debug(self)
def to_dict(self, recursive=True):
return {
'msg_id': self.msg_id,
'seq_no': self.seq_no,
'request': self.request,
'container_msg_id': self.container_msg_id,
}
def to_bytes(self): def to_bytes(self):
body = GzipPacked.gzip_if_smaller(self.request) body = GzipPacked.gzip_if_smaller(self.request)
return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body
def __str__(self):
return TLObject.pretty_format(self)
def stringify(self):
return TLObject.pretty_format(self, indent=0)

View File

@ -1,12 +1,10 @@
from datetime import datetime from datetime import datetime
from threading import Event
class TLObject: class TLObject:
def __init__(self): def __init__(self):
self.request_msg_id = 0 # Long self.request_msg_id = 0 # Long
self.confirm_received = None
self.confirm_received = Event()
self.rpc_error = None self.rpc_error = None
# These should be overrode # These should be overrode

View File

@ -1,8 +1,8 @@
import logging import logging
import pickle import pickle
import asyncio
from collections import deque from collections import deque
from datetime import datetime from datetime import datetime
from threading import RLock, Event, Thread
from .tl import types as tl from .tl import types as tl
@ -13,177 +13,72 @@ class UpdateState:
""" """
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
def __init__(self, workers=None): def __init__(self, loop=None):
"""
:param workers: This integer parameter has three possible cases:
workers is None: Updates will *not* be stored on self.
workers = 0: Another thread is responsible for calling self.poll()
workers > 0: 'workers' background threads will be spawned, any
any of them will invoke all the self.handlers.
"""
self._workers = workers
self._worker_threads = []
self.handlers = [] self.handlers = []
self._updates_lock = RLock()
self._updates_available = Event()
self._updates = deque()
self._latest_updates = deque(maxlen=10) self._latest_updates = deque(maxlen=10)
self._loop = loop if loop else asyncio.get_event_loop()
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
# https://core.telegram.org/api/updates # https://core.telegram.org/api/updates
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0) self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
def can_poll(self): def handle_update(self, update):
"""Returns True if a call to .poll() won't lock""" for handler in self.handlers:
return self._updates_available.is_set() asyncio.ensure_future(handler(update), loop=self._loop)
def poll(self, timeout=None):
"""Polls an update or blocks until an update object is available.
If 'timeout is not None', it should be a floating point value,
and the method will 'return None' if waiting times out.
"""
if not self._updates_available.wait(timeout=timeout):
return
with self._updates_lock:
if not self._updates_available.is_set():
return
update = self._updates.popleft()
if not self._updates:
self._updates_available.clear()
if isinstance(update, Exception):
raise update # Some error was set through (surely StopIteration)
return update
def get_workers(self):
return self._workers
def set_workers(self, n):
"""Changes the number of workers running.
If 'n is None', clears all pending updates from memory.
"""
self.stop_workers()
self._workers = n
if n is None:
self._updates.clear()
else:
self.setup_workers()
workers = property(fget=get_workers, fset=set_workers)
def stop_workers(self):
"""Raises "StopIterationException" on the worker threads to stop them,
and also clears all of them off the list
"""
if self._workers:
with self._updates_lock:
# Insert at the beginning so the very next poll causes an error
# on all the worker threads
# TODO Should this reset the pts and such?
for _ in range(self._workers):
self._updates.appendleft(StopIteration())
self._updates_available.set()
for t in self._worker_threads:
t.join()
self._worker_threads.clear()
def setup_workers(self):
if self._worker_threads or not self._workers:
# There already are workers, or workers is None or 0. Do nothing.
return
for i in range(self._workers):
thread = Thread(
target=UpdateState._worker_loop,
name='UpdateWorker{}'.format(i),
daemon=True,
args=(self, i)
)
self._worker_threads.append(thread)
thread.start()
def _worker_loop(self, wid):
while True:
try:
update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
# TODO Maybe people can add different handlers per update type
if update:
for handler in self.handlers:
handler(update)
except StopIteration:
break
except Exception as e:
# We don't want to crash a worker thread due to any reason
self._logger.debug(
'[ERROR] Unhandled exception on worker {}'.format(wid), e
)
def process(self, update): def process(self, update):
"""Processes an update object. This method is normally called by """Processes an update object. This method is normally called by
the library itself. the library itself.
""" """
if self._workers is None: if isinstance(update, tl.updates.State):
return # No processing needs to be done if nobody's working self._state = update
return # Nothing else to be done
with self._updates_lock: pts = getattr(update, 'pts', self._state.pts)
if isinstance(update, tl.updates.State): if hasattr(update, 'pts') and pts <= self._state.pts:
self._state = update return # We already handled this update
return # Nothing else to be done
pts = getattr(update, 'pts', self._state.pts) self._state.pts = pts
if hasattr(update, 'pts') and pts <= self._state.pts:
return # We already handled this update
self._state.pts = pts # TODO There must be a better way to handle updates rather than
# keeping a queue with the latest updates only, and handling
# the 'pts' correctly should be enough. However some updates
# like UpdateUserStatus (even inside UpdateShort) will be called
# repeatedly very often if invoking anything inside an update
# handler. TODO Figure out why.
"""
client = TelegramClient('anon', api_id, api_hash, update_workers=1)
client.connect()
def handle(u):
client.get_me()
client.add_update_handler(handle)
input('Enter to exit.')
"""
data = pickle.dumps(update.to_dict())
if data in self._latest_updates:
return # Duplicated too
# TODO There must be a better way to handle updates rather than self._latest_updates.append(data)
# keeping a queue with the latest updates only, and handling
# the 'pts' correctly should be enough. However some updates
# like UpdateUserStatus (even inside UpdateShort) will be called
# repeatedly very often if invoking anything inside an update
# handler. TODO Figure out why.
"""
client = TelegramClient('anon', api_id, api_hash, update_workers=1)
client.connect()
def handle(u):
client.get_me()
client.add_update_handler(handle)
input('Enter to exit.')
"""
data = pickle.dumps(update.to_dict())
if data in self._latest_updates:
return # Duplicated too
self._latest_updates.append(data) if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates')
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
# don't need to care about those either.
if isinstance(update, tl.UpdateShort):
self.handle_update(update.update)
if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates') elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
# Expand "Updates" into "Update", and pass these to callbacks. for upd in update.updates:
# Since .users and .chats have already been processed, we self.handle_update(upd)
# don't need to care about those either.
if isinstance(update, tl.UpdateShort):
self._updates.append(update.update)
self._updates_available.set()
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)): elif not isinstance(update, tl.UpdatesTooLong):
self._updates.extend(update.updates) # TODO Handle "Updates too long"
self._updates_available.set() self.handle_update(update)
elif not isinstance(update, tl.UpdatesTooLong): elif type(update).SUBCLASS_OF_ID == 0x9f89304e: # crc32(b'Update')
# TODO Handle "Updates too long" self.handle_update(update)
self._updates.append(update) else:
self._updates_available.set() self._logger.debug('Ignoring "update" of type {}'.format(
type(update).__name__)
elif type(update).SUBCLASS_OF_ID == 0x9f89304e: # crc32(b'Update') )
self._updates.append(update)
self._updates_available.set()
else:
self._logger.debug('Ignoring "update" of type {}'.format(
type(update).__name__)
)