Support custom event loops

This commit is contained in:
Lonami Exo 2018-06-14 19:35:12 +02:00
parent 908dfa148b
commit 0f14f3b16a
11 changed files with 48 additions and 40 deletions

View File

@ -217,7 +217,8 @@ class MessageMethods(UploadMethods, MessageParseMethods):
else: else:
request.max_date = r.messages[-1].date request.max_date = r.messages[-1].date
await asyncio.sleep(max(wait_time - (time.time() - start), 0)) await asyncio.sleep(
max(wait_time - (time.time() - start), 0), loop=self._loop)
async def get_messages(self, *args, **kwargs): async def get_messages(self, *args, **kwargs):
""" """

View File

@ -1,4 +1,5 @@
import abc import abc
import asyncio
import logging import logging
import platform import platform
import warnings import warnings
@ -112,7 +113,8 @@ class TelegramBaseClient(abc.ABC):
system_version=None, system_version=None,
app_version=None, app_version=None,
lang_code='en', lang_code='en',
system_lang_code='en'): system_lang_code='en',
loop=None):
"""Refer to TelegramClient.__init__ for docs on this method""" """Refer to TelegramClient.__init__ for docs on this method"""
if not api_id or not api_hash: if not api_id or not api_hash:
raise ValueError( raise ValueError(
@ -120,6 +122,7 @@ class TelegramBaseClient(abc.ABC):
"Refer to telethon.rtfd.io for more information.") "Refer to telethon.rtfd.io for more information.")
self._use_ipv6 = use_ipv6 self._use_ipv6 = use_ipv6
self._loop = loop or asyncio.get_event_loop()
# Determine what session object we have # Determine what session object we have
if isinstance(session, str) or session is None: if isinstance(session, str) or session is None:
@ -143,12 +146,9 @@ class TelegramBaseClient(abc.ABC):
self.api_id = int(api_id) self.api_id = int(api_id)
self.api_hash = api_hash self.api_hash = api_hash
# This is the main sender, which will be used from the thread
# that calls .connect(). Every other thread will spawn a new
# temporary connection. The connection on this one is always
# kept open so Telegram can send us updates.
if isinstance(connection, type): if isinstance(connection, type):
connection = connection(proxy=proxy, timeout=timeout) connection = connection(
proxy=proxy, timeout=timeout, loop=self._loop)
# Used on connection. Capture the variables in a lambda since # Used on connection. Capture the variables in a lambda since
# exporting clients need to create this InvokeWithLayerRequest. # exporting clients need to create this InvokeWithLayerRequest.
@ -169,7 +169,7 @@ class TelegramBaseClient(abc.ABC):
state = MTProtoState(self.session.auth_key) state = MTProtoState(self.session.auth_key)
self._connection = connection self._connection = connection
self._sender = MTProtoSender( self._sender = MTProtoSender(
state, connection, state, connection, self._loop,
first_query=self._init_with(functions.help.GetConfigRequest()), first_query=self._init_with(functions.help.GetConfigRequest()),
update_callback=self._handle_update update_callback=self._handle_update
) )
@ -211,6 +211,14 @@ class TelegramBaseClient(abc.ABC):
# endregion # endregion
# region Properties
@property
def loop(self):
return self._loop
# endregion
# region Connecting # region Connecting
async def connect(self): async def connect(self):
@ -287,12 +295,11 @@ class TelegramBaseClient(abc.ABC):
auth = self._exported_auths.get(dc_id) auth = self._exported_auths.get(dc_id)
dc = await self._get_dc(dc_id) dc = await self._get_dc(dc_id)
state = MTProtoState(auth) state = MTProtoState(auth)
# TODO Don't hardcode ConnectionTcpFull()
# Can't reuse self._sender._connection as it has its own seqno. # Can't reuse self._sender._connection as it has its own seqno.
# #
# If one were to do that, Telegram would reset the connection # If one were to do that, Telegram would reset the connection
# with no further clues. # with no further clues.
sender = MTProtoSender(state, ConnectionTcpFull()) sender = MTProtoSender(state, self._connection.clone(), self._loop)
await sender.connect(dc.ip_address, dc.port) await sender.connect(dc.ip_address, dc.port)
if not auth: if not auth:
__log__.info('Exporting authorization for data center %s', dc) __log__.info('Exporting authorization for data center %s', dc)

View File

@ -144,7 +144,7 @@ class UpdateMethods(UserMethods):
# region Private methods # region Private methods
def _handle_update(self, update): def _handle_update(self, update):
asyncio.ensure_future(self._dispatch_update(update)) asyncio.ensure_future(self._dispatch_update(update), loop=self._loop)
async def _dispatch_update(self, update): async def _dispatch_update(self, update):
if self._events_pending_resolve: if self._events_pending_resolve:

View File

@ -30,7 +30,7 @@ class UserMethods(TelegramBaseClient):
pass pass
except (errors.FloodWaitError, errors.FloodTestPhoneWaitError) as e: except (errors.FloodWaitError, errors.FloodTestPhoneWaitError) as e:
if e.seconds <= self.session.flood_sleep_threshold: if e.seconds <= self.session.flood_sleep_threshold:
await asyncio.sleep(e.seconds) await asyncio.sleep(e.seconds, loop=self._loop)
else: else:
raise raise
except (errors.PhoneMigrateError, errors.NetworkMigrateError, except (errors.PhoneMigrateError, errors.NetworkMigrateError,

View File

@ -38,16 +38,16 @@ class TcpClient:
class SocketClosed(ConnectionError): class SocketClosed(ConnectionError):
pass pass
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None): def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
""" """
Initializes the TCP client. Initializes the TCP client.
:param proxy: the proxy to be used, if any. :param proxy: the proxy to be used, if any.
:param timeout: the timeout for connect, read and write operations. :param timeout: the timeout for connect, read and write operations.
""" """
self._loop = loop
self.proxy = proxy self.proxy = proxy
self._socket = None self._socket = None
self._loop = loop or asyncio.get_event_loop()
self._closed = asyncio.Event(loop=self._loop) self._closed = asyncio.Event(loop=self._loop)
self._closed.set() self._closed.set()

View File

@ -21,13 +21,15 @@ class Connection(abc.ABC):
Subclasses should implement the actual protocol Subclasses should implement the actual protocol
being used when encoding/decoding messages. being used when encoding/decoding messages.
""" """
def __init__(self, proxy=None, timeout=timedelta(seconds=5)): def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
""" """
Initializes a new connection. Initializes a new connection.
:param loop: the event loop to be used.
:param proxy: whether to use a proxy or not. :param proxy: whether to use a proxy or not.
:param timeout: timeout to be used for all operations. :param timeout: timeout to be used for all operations.
""" """
self._loop = loop
self._proxy = proxy self._proxy = proxy
self._timeout = timeout self._timeout = timeout
@ -54,10 +56,13 @@ class Connection(abc.ABC):
"""Closes the connection.""" """Closes the connection."""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def clone(self): def clone(self):
"""Creates a copy of this Connection.""" """Creates a copy of this Connection."""
raise NotImplementedError return self.__class__(
loop=self._loop,
proxy=self._proxy,
timeout=self._timeout
)
@abc.abstractmethod @abc.abstractmethod
async def recv(self): async def recv(self):

View File

@ -14,9 +14,6 @@ class ConnectionTcpAbridged(ConnectionTcpFull):
await self.conn.write(b'\xef') await self.conn.write(b'\xef')
return result return result
def clone(self):
return ConnectionTcpAbridged(self._proxy, self._timeout)
async def recv(self): async def recv(self):
length = struct.unpack('<B', await self.read(1))[0] length = struct.unpack('<B', await self.read(1))[0]
if length >= 127: if length >= 127:

View File

@ -13,10 +13,12 @@ class ConnectionTcpFull(Connection):
Default Telegram mode. Sends 12 additional bytes and Default Telegram mode. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself. needs to calculate the CRC value of the packet itself.
""" """
def __init__(self, proxy=None, timeout=timedelta(seconds=5)): def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
super().__init__(proxy, timeout) super().__init__(loop=loop, proxy=proxy, timeout=timeout)
self._send_counter = 0 self._send_counter = 0
self.conn = TcpClient(proxy=self._proxy, timeout=self._timeout) self.conn = TcpClient(
proxy=self._proxy, timeout=self._timeout, loop=self._loop
)
self.read = self.conn.read self.read = self.conn.read
self.write = self.conn.write self.write = self.conn.write
@ -40,9 +42,6 @@ class ConnectionTcpFull(Connection):
async def close(self): async def close(self):
self.conn.close() self.conn.close()
def clone(self):
return ConnectionTcpFull(self._proxy, self._timeout)
async def recv(self): async def recv(self):
packet_len_seq = await self.read(8) # 4 and 4 packet_len_seq = await self.read(8) # 4 and 4
packet_len, seq = struct.unpack('<ii', packet_len_seq) packet_len, seq = struct.unpack('<ii', packet_len_seq)

View File

@ -13,9 +13,6 @@ class ConnectionTcpIntermediate(ConnectionTcpFull):
await self.conn.write(b'\xee\xee\xee\xee') await self.conn.write(b'\xee\xee\xee\xee')
return result return result
def clone(self):
return ConnectionTcpIntermediate(self._proxy, self._timeout)
async def recv(self): async def recv(self):
return await self.read(struct.unpack('<i', await self.read(4))[0]) return await self.read(struct.unpack('<i', await self.read(4))[0])

View File

@ -12,8 +12,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
every message with a randomly generated key using the every message with a randomly generated key using the
AES-CTR mode so the packets are harder to discern. AES-CTR mode so the packets are harder to discern.
""" """
def __init__(self, proxy=None, timeout=timedelta(seconds=5)): def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
super().__init__(proxy, timeout) super().__init__(loop=loop, proxy=proxy, timeout=timeout)
self._aes_encrypt, self._aes_decrypt = None, None self._aes_encrypt, self._aes_decrypt = None, None
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s)) self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d)) self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))
@ -45,6 +45,3 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64] random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
await self.conn.write(bytes(random)) await self.conn.write(bytes(random))
return result return result
def clone(self):
return ConnectionTcpObfuscated(self._proxy, self._timeout)

View File

@ -39,10 +39,11 @@ class MTProtoSender:
A new authorization key will be generated on connection if no other A new authorization key will be generated on connection if no other
key exists yet. key exists yet.
""" """
def __init__(self, state, connection, *, retries=5, def __init__(self, state, connection, loop, *, retries=5,
first_query=None, update_callback=None): first_query=None, update_callback=None):
self.state = state self.state = state
self._connection = connection self._connection = connection
self._loop = loop
self._ip = None self._ip = None
self._port = None self._port = None
self._retries = retries self._retries = retries
@ -231,9 +232,13 @@ class MTProtoSender:
raise _last_error raise _last_error
__log__.debug('Starting send loop') __log__.debug('Starting send loop')
self._send_loop_handle = asyncio.ensure_future(self._send_loop()) self._send_loop_handle = asyncio.ensure_future(
self._send_loop(), loop=self._loop)
__log__.debug('Starting receive loop') __log__.debug('Starting receive loop')
self._recv_loop_handle = asyncio.ensure_future(self._recv_loop()) self._recv_loop_handle = asyncio.ensure_future(
self._recv_loop(), loop=self._loop)
if self._is_first_query: if self._is_first_query:
__log__.debug('Running first query') __log__.debug('Running first query')
self._is_first_query = False self._is_first_query = False
@ -347,11 +352,11 @@ class MTProtoSender:
continue continue
except ConnectionError as e: except ConnectionError as e:
__log__.info('Connection reset while receiving %s', e) __log__.info('Connection reset while receiving %s', e)
asyncio.ensure_future(self._reconnect()) asyncio.ensure_future(self._reconnect(), loop=self._loop)
break break
except OSError as e: except OSError as e:
__log__.warning('OSError while receiving %s', e) __log__.warning('OSError while receiving %s', e)
asyncio.ensure_future(self._reconnect()) asyncio.ensure_future(self._reconnect(), loop=self._loop)
break break
# TODO Check salt, session_id and sequence_number # TODO Check salt, session_id and sequence_number
@ -370,7 +375,7 @@ class MTProtoSender:
# an actually broken authkey? # an actually broken authkey?
__log__.warning('Broken authorization key?: {}'.format(e)) __log__.warning('Broken authorization key?: {}'.format(e))
self.state.auth_key = None self.state.auth_key = None
asyncio.ensure_future(self._reconnect()) asyncio.ensure_future(self._reconnect(), loop=self._loop)
break break
except SecurityError as e: except SecurityError as e:
# A step while decoding had the incorrect data. This message # A step while decoding had the incorrect data. This message