mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-10 19:46:36 +03:00
Support custom event loops
This commit is contained in:
parent
908dfa148b
commit
0f14f3b16a
|
@ -217,7 +217,8 @@ class MessageMethods(UploadMethods, MessageParseMethods):
|
|||
else:
|
||||
request.max_date = r.messages[-1].date
|
||||
|
||||
await asyncio.sleep(max(wait_time - (time.time() - start), 0))
|
||||
await asyncio.sleep(
|
||||
max(wait_time - (time.time() - start), 0), loop=self._loop)
|
||||
|
||||
async def get_messages(self, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import logging
|
||||
import platform
|
||||
import warnings
|
||||
|
@ -112,7 +113,8 @@ class TelegramBaseClient(abc.ABC):
|
|||
system_version=None,
|
||||
app_version=None,
|
||||
lang_code='en',
|
||||
system_lang_code='en'):
|
||||
system_lang_code='en',
|
||||
loop=None):
|
||||
"""Refer to TelegramClient.__init__ for docs on this method"""
|
||||
if not api_id or not api_hash:
|
||||
raise ValueError(
|
||||
|
@ -120,6 +122,7 @@ class TelegramBaseClient(abc.ABC):
|
|||
"Refer to telethon.rtfd.io for more information.")
|
||||
|
||||
self._use_ipv6 = use_ipv6
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
|
||||
# Determine what session object we have
|
||||
if isinstance(session, str) or session is None:
|
||||
|
@ -143,12 +146,9 @@ class TelegramBaseClient(abc.ABC):
|
|||
self.api_id = int(api_id)
|
||||
self.api_hash = api_hash
|
||||
|
||||
# This is the main sender, which will be used from the thread
|
||||
# that calls .connect(). Every other thread will spawn a new
|
||||
# temporary connection. The connection on this one is always
|
||||
# kept open so Telegram can send us updates.
|
||||
if isinstance(connection, type):
|
||||
connection = connection(proxy=proxy, timeout=timeout)
|
||||
connection = connection(
|
||||
proxy=proxy, timeout=timeout, loop=self._loop)
|
||||
|
||||
# Used on connection. Capture the variables in a lambda since
|
||||
# exporting clients need to create this InvokeWithLayerRequest.
|
||||
|
@ -169,7 +169,7 @@ class TelegramBaseClient(abc.ABC):
|
|||
state = MTProtoState(self.session.auth_key)
|
||||
self._connection = connection
|
||||
self._sender = MTProtoSender(
|
||||
state, connection,
|
||||
state, connection, self._loop,
|
||||
first_query=self._init_with(functions.help.GetConfigRequest()),
|
||||
update_callback=self._handle_update
|
||||
)
|
||||
|
@ -211,6 +211,14 @@ class TelegramBaseClient(abc.ABC):
|
|||
|
||||
# endregion
|
||||
|
||||
# region Properties
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
return self._loop
|
||||
|
||||
# endregion
|
||||
|
||||
# region Connecting
|
||||
|
||||
async def connect(self):
|
||||
|
@ -287,12 +295,11 @@ class TelegramBaseClient(abc.ABC):
|
|||
auth = self._exported_auths.get(dc_id)
|
||||
dc = await self._get_dc(dc_id)
|
||||
state = MTProtoState(auth)
|
||||
# TODO Don't hardcode ConnectionTcpFull()
|
||||
# Can't reuse self._sender._connection as it has its own seqno.
|
||||
#
|
||||
# If one were to do that, Telegram would reset the connection
|
||||
# with no further clues.
|
||||
sender = MTProtoSender(state, ConnectionTcpFull())
|
||||
sender = MTProtoSender(state, self._connection.clone(), self._loop)
|
||||
await sender.connect(dc.ip_address, dc.port)
|
||||
if not auth:
|
||||
__log__.info('Exporting authorization for data center %s', dc)
|
||||
|
|
|
@ -144,7 +144,7 @@ class UpdateMethods(UserMethods):
|
|||
# region Private methods
|
||||
|
||||
def _handle_update(self, update):
|
||||
asyncio.ensure_future(self._dispatch_update(update))
|
||||
asyncio.ensure_future(self._dispatch_update(update), loop=self._loop)
|
||||
|
||||
async def _dispatch_update(self, update):
|
||||
if self._events_pending_resolve:
|
||||
|
|
|
@ -30,7 +30,7 @@ class UserMethods(TelegramBaseClient):
|
|||
pass
|
||||
except (errors.FloodWaitError, errors.FloodTestPhoneWaitError) as e:
|
||||
if e.seconds <= self.session.flood_sleep_threshold:
|
||||
await asyncio.sleep(e.seconds)
|
||||
await asyncio.sleep(e.seconds, loop=self._loop)
|
||||
else:
|
||||
raise
|
||||
except (errors.PhoneMigrateError, errors.NetworkMigrateError,
|
||||
|
|
|
@ -38,16 +38,16 @@ class TcpClient:
|
|||
class SocketClosed(ConnectionError):
|
||||
pass
|
||||
|
||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
|
||||
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
|
||||
"""
|
||||
Initializes the TCP client.
|
||||
|
||||
:param proxy: the proxy to be used, if any.
|
||||
:param timeout: the timeout for connect, read and write operations.
|
||||
"""
|
||||
self._loop = loop
|
||||
self.proxy = proxy
|
||||
self._socket = None
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._closed = asyncio.Event(loop=self._loop)
|
||||
self._closed.set()
|
||||
|
||||
|
|
|
@ -21,13 +21,15 @@ class Connection(abc.ABC):
|
|||
Subclasses should implement the actual protocol
|
||||
being used when encoding/decoding messages.
|
||||
"""
|
||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
|
||||
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
|
||||
"""
|
||||
Initializes a new connection.
|
||||
|
||||
:param loop: the event loop to be used.
|
||||
:param proxy: whether to use a proxy or not.
|
||||
:param timeout: timeout to be used for all operations.
|
||||
"""
|
||||
self._loop = loop
|
||||
self._proxy = proxy
|
||||
self._timeout = timeout
|
||||
|
||||
|
@ -54,10 +56,13 @@ class Connection(abc.ABC):
|
|||
"""Closes the connection."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def clone(self):
|
||||
"""Creates a copy of this Connection."""
|
||||
raise NotImplementedError
|
||||
return self.__class__(
|
||||
loop=self._loop,
|
||||
proxy=self._proxy,
|
||||
timeout=self._timeout
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def recv(self):
|
||||
|
|
|
@ -14,9 +14,6 @@ class ConnectionTcpAbridged(ConnectionTcpFull):
|
|||
await self.conn.write(b'\xef')
|
||||
return result
|
||||
|
||||
def clone(self):
|
||||
return ConnectionTcpAbridged(self._proxy, self._timeout)
|
||||
|
||||
async def recv(self):
|
||||
length = struct.unpack('<B', await self.read(1))[0]
|
||||
if length >= 127:
|
||||
|
|
|
@ -13,10 +13,12 @@ class ConnectionTcpFull(Connection):
|
|||
Default Telegram mode. Sends 12 additional bytes and
|
||||
needs to calculate the CRC value of the packet itself.
|
||||
"""
|
||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
|
||||
super().__init__(proxy, timeout)
|
||||
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
|
||||
super().__init__(loop=loop, proxy=proxy, timeout=timeout)
|
||||
self._send_counter = 0
|
||||
self.conn = TcpClient(proxy=self._proxy, timeout=self._timeout)
|
||||
self.conn = TcpClient(
|
||||
proxy=self._proxy, timeout=self._timeout, loop=self._loop
|
||||
)
|
||||
self.read = self.conn.read
|
||||
self.write = self.conn.write
|
||||
|
||||
|
@ -40,9 +42,6 @@ class ConnectionTcpFull(Connection):
|
|||
async def close(self):
|
||||
self.conn.close()
|
||||
|
||||
def clone(self):
|
||||
return ConnectionTcpFull(self._proxy, self._timeout)
|
||||
|
||||
async def recv(self):
|
||||
packet_len_seq = await self.read(8) # 4 and 4
|
||||
packet_len, seq = struct.unpack('<ii', packet_len_seq)
|
||||
|
|
|
@ -13,9 +13,6 @@ class ConnectionTcpIntermediate(ConnectionTcpFull):
|
|||
await self.conn.write(b'\xee\xee\xee\xee')
|
||||
return result
|
||||
|
||||
def clone(self):
|
||||
return ConnectionTcpIntermediate(self._proxy, self._timeout)
|
||||
|
||||
async def recv(self):
|
||||
return await self.read(struct.unpack('<i', await self.read(4))[0])
|
||||
|
||||
|
|
|
@ -12,8 +12,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
|
|||
every message with a randomly generated key using the
|
||||
AES-CTR mode so the packets are harder to discern.
|
||||
"""
|
||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
|
||||
super().__init__(proxy, timeout)
|
||||
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
|
||||
super().__init__(loop=loop, proxy=proxy, timeout=timeout)
|
||||
self._aes_encrypt, self._aes_decrypt = None, None
|
||||
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
|
||||
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))
|
||||
|
@ -45,6 +45,3 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
|
|||
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
|
||||
await self.conn.write(bytes(random))
|
||||
return result
|
||||
|
||||
def clone(self):
|
||||
return ConnectionTcpObfuscated(self._proxy, self._timeout)
|
||||
|
|
|
@ -39,10 +39,11 @@ class MTProtoSender:
|
|||
A new authorization key will be generated on connection if no other
|
||||
key exists yet.
|
||||
"""
|
||||
def __init__(self, state, connection, *, retries=5,
|
||||
def __init__(self, state, connection, loop, *, retries=5,
|
||||
first_query=None, update_callback=None):
|
||||
self.state = state
|
||||
self._connection = connection
|
||||
self._loop = loop
|
||||
self._ip = None
|
||||
self._port = None
|
||||
self._retries = retries
|
||||
|
@ -231,9 +232,13 @@ class MTProtoSender:
|
|||
raise _last_error
|
||||
|
||||
__log__.debug('Starting send loop')
|
||||
self._send_loop_handle = asyncio.ensure_future(self._send_loop())
|
||||
self._send_loop_handle = asyncio.ensure_future(
|
||||
self._send_loop(), loop=self._loop)
|
||||
|
||||
__log__.debug('Starting receive loop')
|
||||
self._recv_loop_handle = asyncio.ensure_future(self._recv_loop())
|
||||
self._recv_loop_handle = asyncio.ensure_future(
|
||||
self._recv_loop(), loop=self._loop)
|
||||
|
||||
if self._is_first_query:
|
||||
__log__.debug('Running first query')
|
||||
self._is_first_query = False
|
||||
|
@ -347,11 +352,11 @@ class MTProtoSender:
|
|||
continue
|
||||
except ConnectionError as e:
|
||||
__log__.info('Connection reset while receiving %s', e)
|
||||
asyncio.ensure_future(self._reconnect())
|
||||
asyncio.ensure_future(self._reconnect(), loop=self._loop)
|
||||
break
|
||||
except OSError as e:
|
||||
__log__.warning('OSError while receiving %s', e)
|
||||
asyncio.ensure_future(self._reconnect())
|
||||
asyncio.ensure_future(self._reconnect(), loop=self._loop)
|
||||
break
|
||||
|
||||
# TODO Check salt, session_id and sequence_number
|
||||
|
@ -370,7 +375,7 @@ class MTProtoSender:
|
|||
# an actually broken authkey?
|
||||
__log__.warning('Broken authorization key?: {}'.format(e))
|
||||
self.state.auth_key = None
|
||||
asyncio.ensure_future(self._reconnect())
|
||||
asyncio.ensure_future(self._reconnect(), loop=self._loop)
|
||||
break
|
||||
except SecurityError as e:
|
||||
# A step while decoding had the incorrect data. This message
|
||||
|
|
Loading…
Reference in New Issue
Block a user