Support custom event loops

This commit is contained in:
Lonami Exo 2018-06-14 19:35:12 +02:00
parent 908dfa148b
commit 0f14f3b16a
11 changed files with 48 additions and 40 deletions

View File

@ -217,7 +217,8 @@ class MessageMethods(UploadMethods, MessageParseMethods):
else:
request.max_date = r.messages[-1].date
await asyncio.sleep(max(wait_time - (time.time() - start), 0))
await asyncio.sleep(
max(wait_time - (time.time() - start), 0), loop=self._loop)
async def get_messages(self, *args, **kwargs):
"""

View File

@ -1,4 +1,5 @@
import abc
import asyncio
import logging
import platform
import warnings
@ -112,7 +113,8 @@ class TelegramBaseClient(abc.ABC):
system_version=None,
app_version=None,
lang_code='en',
system_lang_code='en'):
system_lang_code='en',
loop=None):
"""Refer to TelegramClient.__init__ for docs on this method"""
if not api_id or not api_hash:
raise ValueError(
@ -120,6 +122,7 @@ class TelegramBaseClient(abc.ABC):
"Refer to telethon.rtfd.io for more information.")
self._use_ipv6 = use_ipv6
self._loop = loop or asyncio.get_event_loop()
# Determine what session object we have
if isinstance(session, str) or session is None:
@ -143,12 +146,9 @@ class TelegramBaseClient(abc.ABC):
self.api_id = int(api_id)
self.api_hash = api_hash
# This is the main sender, which will be used from the thread
# that calls .connect(). Every other thread will spawn a new
# temporary connection. The connection on this one is always
# kept open so Telegram can send us updates.
if isinstance(connection, type):
connection = connection(proxy=proxy, timeout=timeout)
connection = connection(
proxy=proxy, timeout=timeout, loop=self._loop)
# Used on connection. Capture the variables in a lambda since
# exporting clients need to create this InvokeWithLayerRequest.
@ -169,7 +169,7 @@ class TelegramBaseClient(abc.ABC):
state = MTProtoState(self.session.auth_key)
self._connection = connection
self._sender = MTProtoSender(
state, connection,
state, connection, self._loop,
first_query=self._init_with(functions.help.GetConfigRequest()),
update_callback=self._handle_update
)
@ -211,6 +211,14 @@ class TelegramBaseClient(abc.ABC):
# endregion
# region Properties
@property
def loop(self):
return self._loop
# endregion
# region Connecting
async def connect(self):
@ -287,12 +295,11 @@ class TelegramBaseClient(abc.ABC):
auth = self._exported_auths.get(dc_id)
dc = await self._get_dc(dc_id)
state = MTProtoState(auth)
# TODO Don't hardcode ConnectionTcpFull()
# Can't reuse self._sender._connection as it has its own seqno.
#
# If one were to do that, Telegram would reset the connection
# with no further clues.
sender = MTProtoSender(state, ConnectionTcpFull())
sender = MTProtoSender(state, self._connection.clone(), self._loop)
await sender.connect(dc.ip_address, dc.port)
if not auth:
__log__.info('Exporting authorization for data center %s', dc)

View File

@ -144,7 +144,7 @@ class UpdateMethods(UserMethods):
# region Private methods
def _handle_update(self, update):
asyncio.ensure_future(self._dispatch_update(update))
asyncio.ensure_future(self._dispatch_update(update), loop=self._loop)
async def _dispatch_update(self, update):
if self._events_pending_resolve:

View File

@ -30,7 +30,7 @@ class UserMethods(TelegramBaseClient):
pass
except (errors.FloodWaitError, errors.FloodTestPhoneWaitError) as e:
if e.seconds <= self.session.flood_sleep_threshold:
await asyncio.sleep(e.seconds)
await asyncio.sleep(e.seconds, loop=self._loop)
else:
raise
except (errors.PhoneMigrateError, errors.NetworkMigrateError,

View File

@ -38,16 +38,16 @@ class TcpClient:
class SocketClosed(ConnectionError):
pass
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
"""
Initializes the TCP client.
:param proxy: the proxy to be used, if any.
:param timeout: the timeout for connect, read and write operations.
"""
self._loop = loop
self.proxy = proxy
self._socket = None
self._loop = loop or asyncio.get_event_loop()
self._closed = asyncio.Event(loop=self._loop)
self._closed.set()

View File

@ -21,13 +21,15 @@ class Connection(abc.ABC):
Subclasses should implement the actual protocol
being used when encoding/decoding messages.
"""
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
"""
Initializes a new connection.
:param loop: the event loop to be used.
:param proxy: whether to use a proxy or not.
:param timeout: timeout to be used for all operations.
"""
self._loop = loop
self._proxy = proxy
self._timeout = timeout
@ -54,10 +56,13 @@ class Connection(abc.ABC):
"""Closes the connection."""
raise NotImplementedError
@abc.abstractmethod
def clone(self):
"""Creates a copy of this Connection."""
raise NotImplementedError
return self.__class__(
loop=self._loop,
proxy=self._proxy,
timeout=self._timeout
)
@abc.abstractmethod
async def recv(self):

View File

@ -14,9 +14,6 @@ class ConnectionTcpAbridged(ConnectionTcpFull):
await self.conn.write(b'\xef')
return result
def clone(self):
return ConnectionTcpAbridged(self._proxy, self._timeout)
async def recv(self):
length = struct.unpack('<B', await self.read(1))[0]
if length >= 127:

View File

@ -13,10 +13,12 @@ class ConnectionTcpFull(Connection):
Default Telegram mode. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself.
"""
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
super().__init__(proxy, timeout)
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
super().__init__(loop=loop, proxy=proxy, timeout=timeout)
self._send_counter = 0
self.conn = TcpClient(proxy=self._proxy, timeout=self._timeout)
self.conn = TcpClient(
proxy=self._proxy, timeout=self._timeout, loop=self._loop
)
self.read = self.conn.read
self.write = self.conn.write
@ -40,9 +42,6 @@ class ConnectionTcpFull(Connection):
async def close(self):
self.conn.close()
def clone(self):
return ConnectionTcpFull(self._proxy, self._timeout)
async def recv(self):
packet_len_seq = await self.read(8) # 4 and 4
packet_len, seq = struct.unpack('<ii', packet_len_seq)

View File

@ -13,9 +13,6 @@ class ConnectionTcpIntermediate(ConnectionTcpFull):
await self.conn.write(b'\xee\xee\xee\xee')
return result
def clone(self):
return ConnectionTcpIntermediate(self._proxy, self._timeout)
async def recv(self):
return await self.read(struct.unpack('<i', await self.read(4))[0])

View File

@ -12,8 +12,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
every message with a randomly generated key using the
AES-CTR mode so the packets are harder to discern.
"""
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
super().__init__(proxy, timeout)
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
super().__init__(loop=loop, proxy=proxy, timeout=timeout)
self._aes_encrypt, self._aes_decrypt = None, None
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))
@ -45,6 +45,3 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
await self.conn.write(bytes(random))
return result
def clone(self):
return ConnectionTcpObfuscated(self._proxy, self._timeout)

View File

@ -39,10 +39,11 @@ class MTProtoSender:
A new authorization key will be generated on connection if no other
key exists yet.
"""
def __init__(self, state, connection, *, retries=5,
def __init__(self, state, connection, loop, *, retries=5,
first_query=None, update_callback=None):
self.state = state
self._connection = connection
self._loop = loop
self._ip = None
self._port = None
self._retries = retries
@ -231,9 +232,13 @@ class MTProtoSender:
raise _last_error
__log__.debug('Starting send loop')
self._send_loop_handle = asyncio.ensure_future(self._send_loop())
self._send_loop_handle = asyncio.ensure_future(
self._send_loop(), loop=self._loop)
__log__.debug('Starting receive loop')
self._recv_loop_handle = asyncio.ensure_future(self._recv_loop())
self._recv_loop_handle = asyncio.ensure_future(
self._recv_loop(), loop=self._loop)
if self._is_first_query:
__log__.debug('Running first query')
self._is_first_query = False
@ -347,11 +352,11 @@ class MTProtoSender:
continue
except ConnectionError as e:
__log__.info('Connection reset while receiving %s', e)
asyncio.ensure_future(self._reconnect())
asyncio.ensure_future(self._reconnect(), loop=self._loop)
break
except OSError as e:
__log__.warning('OSError while receiving %s', e)
asyncio.ensure_future(self._reconnect())
asyncio.ensure_future(self._reconnect(), loop=self._loop)
break
# TODO Check salt, session_id and sequence_number
@ -370,7 +375,7 @@ class MTProtoSender:
# an actually broken authkey?
__log__.warning('Broken authorization key?: {}'.format(e))
self.state.auth_key = None
asyncio.ensure_future(self._reconnect())
asyncio.ensure_future(self._reconnect(), loop=self._loop)
break
except SecurityError as e:
# A step while decoding had the incorrect data. This message