mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-16 19:41:07 +03:00
Support custom event loops
This commit is contained in:
parent
908dfa148b
commit
0f14f3b16a
|
@ -217,7 +217,8 @@ class MessageMethods(UploadMethods, MessageParseMethods):
|
||||||
else:
|
else:
|
||||||
request.max_date = r.messages[-1].date
|
request.max_date = r.messages[-1].date
|
||||||
|
|
||||||
await asyncio.sleep(max(wait_time - (time.time() - start), 0))
|
await asyncio.sleep(
|
||||||
|
max(wait_time - (time.time() - start), 0), loop=self._loop)
|
||||||
|
|
||||||
async def get_messages(self, *args, **kwargs):
|
async def get_messages(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import abc
|
import abc
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import platform
|
import platform
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -112,7 +113,8 @@ class TelegramBaseClient(abc.ABC):
|
||||||
system_version=None,
|
system_version=None,
|
||||||
app_version=None,
|
app_version=None,
|
||||||
lang_code='en',
|
lang_code='en',
|
||||||
system_lang_code='en'):
|
system_lang_code='en',
|
||||||
|
loop=None):
|
||||||
"""Refer to TelegramClient.__init__ for docs on this method"""
|
"""Refer to TelegramClient.__init__ for docs on this method"""
|
||||||
if not api_id or not api_hash:
|
if not api_id or not api_hash:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -120,6 +122,7 @@ class TelegramBaseClient(abc.ABC):
|
||||||
"Refer to telethon.rtfd.io for more information.")
|
"Refer to telethon.rtfd.io for more information.")
|
||||||
|
|
||||||
self._use_ipv6 = use_ipv6
|
self._use_ipv6 = use_ipv6
|
||||||
|
self._loop = loop or asyncio.get_event_loop()
|
||||||
|
|
||||||
# Determine what session object we have
|
# Determine what session object we have
|
||||||
if isinstance(session, str) or session is None:
|
if isinstance(session, str) or session is None:
|
||||||
|
@ -143,12 +146,9 @@ class TelegramBaseClient(abc.ABC):
|
||||||
self.api_id = int(api_id)
|
self.api_id = int(api_id)
|
||||||
self.api_hash = api_hash
|
self.api_hash = api_hash
|
||||||
|
|
||||||
# This is the main sender, which will be used from the thread
|
|
||||||
# that calls .connect(). Every other thread will spawn a new
|
|
||||||
# temporary connection. The connection on this one is always
|
|
||||||
# kept open so Telegram can send us updates.
|
|
||||||
if isinstance(connection, type):
|
if isinstance(connection, type):
|
||||||
connection = connection(proxy=proxy, timeout=timeout)
|
connection = connection(
|
||||||
|
proxy=proxy, timeout=timeout, loop=self._loop)
|
||||||
|
|
||||||
# Used on connection. Capture the variables in a lambda since
|
# Used on connection. Capture the variables in a lambda since
|
||||||
# exporting clients need to create this InvokeWithLayerRequest.
|
# exporting clients need to create this InvokeWithLayerRequest.
|
||||||
|
@ -169,7 +169,7 @@ class TelegramBaseClient(abc.ABC):
|
||||||
state = MTProtoState(self.session.auth_key)
|
state = MTProtoState(self.session.auth_key)
|
||||||
self._connection = connection
|
self._connection = connection
|
||||||
self._sender = MTProtoSender(
|
self._sender = MTProtoSender(
|
||||||
state, connection,
|
state, connection, self._loop,
|
||||||
first_query=self._init_with(functions.help.GetConfigRequest()),
|
first_query=self._init_with(functions.help.GetConfigRequest()),
|
||||||
update_callback=self._handle_update
|
update_callback=self._handle_update
|
||||||
)
|
)
|
||||||
|
@ -211,6 +211,14 @@ class TelegramBaseClient(abc.ABC):
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
# region Properties
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop(self):
|
||||||
|
return self._loop
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
# region Connecting
|
# region Connecting
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
|
@ -287,12 +295,11 @@ class TelegramBaseClient(abc.ABC):
|
||||||
auth = self._exported_auths.get(dc_id)
|
auth = self._exported_auths.get(dc_id)
|
||||||
dc = await self._get_dc(dc_id)
|
dc = await self._get_dc(dc_id)
|
||||||
state = MTProtoState(auth)
|
state = MTProtoState(auth)
|
||||||
# TODO Don't hardcode ConnectionTcpFull()
|
|
||||||
# Can't reuse self._sender._connection as it has its own seqno.
|
# Can't reuse self._sender._connection as it has its own seqno.
|
||||||
#
|
#
|
||||||
# If one were to do that, Telegram would reset the connection
|
# If one were to do that, Telegram would reset the connection
|
||||||
# with no further clues.
|
# with no further clues.
|
||||||
sender = MTProtoSender(state, ConnectionTcpFull())
|
sender = MTProtoSender(state, self._connection.clone(), self._loop)
|
||||||
await sender.connect(dc.ip_address, dc.port)
|
await sender.connect(dc.ip_address, dc.port)
|
||||||
if not auth:
|
if not auth:
|
||||||
__log__.info('Exporting authorization for data center %s', dc)
|
__log__.info('Exporting authorization for data center %s', dc)
|
||||||
|
|
|
@ -144,7 +144,7 @@ class UpdateMethods(UserMethods):
|
||||||
# region Private methods
|
# region Private methods
|
||||||
|
|
||||||
def _handle_update(self, update):
|
def _handle_update(self, update):
|
||||||
asyncio.ensure_future(self._dispatch_update(update))
|
asyncio.ensure_future(self._dispatch_update(update), loop=self._loop)
|
||||||
|
|
||||||
async def _dispatch_update(self, update):
|
async def _dispatch_update(self, update):
|
||||||
if self._events_pending_resolve:
|
if self._events_pending_resolve:
|
||||||
|
|
|
@ -30,7 +30,7 @@ class UserMethods(TelegramBaseClient):
|
||||||
pass
|
pass
|
||||||
except (errors.FloodWaitError, errors.FloodTestPhoneWaitError) as e:
|
except (errors.FloodWaitError, errors.FloodTestPhoneWaitError) as e:
|
||||||
if e.seconds <= self.session.flood_sleep_threshold:
|
if e.seconds <= self.session.flood_sleep_threshold:
|
||||||
await asyncio.sleep(e.seconds)
|
await asyncio.sleep(e.seconds, loop=self._loop)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
except (errors.PhoneMigrateError, errors.NetworkMigrateError,
|
except (errors.PhoneMigrateError, errors.NetworkMigrateError,
|
||||||
|
|
|
@ -38,16 +38,16 @@ class TcpClient:
|
||||||
class SocketClosed(ConnectionError):
|
class SocketClosed(ConnectionError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
|
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
|
||||||
"""
|
"""
|
||||||
Initializes the TCP client.
|
Initializes the TCP client.
|
||||||
|
|
||||||
:param proxy: the proxy to be used, if any.
|
:param proxy: the proxy to be used, if any.
|
||||||
:param timeout: the timeout for connect, read and write operations.
|
:param timeout: the timeout for connect, read and write operations.
|
||||||
"""
|
"""
|
||||||
|
self._loop = loop
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
self._socket = None
|
self._socket = None
|
||||||
self._loop = loop or asyncio.get_event_loop()
|
|
||||||
self._closed = asyncio.Event(loop=self._loop)
|
self._closed = asyncio.Event(loop=self._loop)
|
||||||
self._closed.set()
|
self._closed.set()
|
||||||
|
|
||||||
|
|
|
@ -21,13 +21,15 @@ class Connection(abc.ABC):
|
||||||
Subclasses should implement the actual protocol
|
Subclasses should implement the actual protocol
|
||||||
being used when encoding/decoding messages.
|
being used when encoding/decoding messages.
|
||||||
"""
|
"""
|
||||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
|
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
|
||||||
"""
|
"""
|
||||||
Initializes a new connection.
|
Initializes a new connection.
|
||||||
|
|
||||||
|
:param loop: the event loop to be used.
|
||||||
:param proxy: whether to use a proxy or not.
|
:param proxy: whether to use a proxy or not.
|
||||||
:param timeout: timeout to be used for all operations.
|
:param timeout: timeout to be used for all operations.
|
||||||
"""
|
"""
|
||||||
|
self._loop = loop
|
||||||
self._proxy = proxy
|
self._proxy = proxy
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
|
|
||||||
|
@ -54,10 +56,13 @@ class Connection(abc.ABC):
|
||||||
"""Closes the connection."""
|
"""Closes the connection."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
"""Creates a copy of this Connection."""
|
"""Creates a copy of this Connection."""
|
||||||
raise NotImplementedError
|
return self.__class__(
|
||||||
|
loop=self._loop,
|
||||||
|
proxy=self._proxy,
|
||||||
|
timeout=self._timeout
|
||||||
|
)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
|
|
|
@ -14,9 +14,6 @@ class ConnectionTcpAbridged(ConnectionTcpFull):
|
||||||
await self.conn.write(b'\xef')
|
await self.conn.write(b'\xef')
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
return ConnectionTcpAbridged(self._proxy, self._timeout)
|
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
length = struct.unpack('<B', await self.read(1))[0]
|
length = struct.unpack('<B', await self.read(1))[0]
|
||||||
if length >= 127:
|
if length >= 127:
|
||||||
|
|
|
@ -13,10 +13,12 @@ class ConnectionTcpFull(Connection):
|
||||||
Default Telegram mode. Sends 12 additional bytes and
|
Default Telegram mode. Sends 12 additional bytes and
|
||||||
needs to calculate the CRC value of the packet itself.
|
needs to calculate the CRC value of the packet itself.
|
||||||
"""
|
"""
|
||||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
|
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
|
||||||
super().__init__(proxy, timeout)
|
super().__init__(loop=loop, proxy=proxy, timeout=timeout)
|
||||||
self._send_counter = 0
|
self._send_counter = 0
|
||||||
self.conn = TcpClient(proxy=self._proxy, timeout=self._timeout)
|
self.conn = TcpClient(
|
||||||
|
proxy=self._proxy, timeout=self._timeout, loop=self._loop
|
||||||
|
)
|
||||||
self.read = self.conn.read
|
self.read = self.conn.read
|
||||||
self.write = self.conn.write
|
self.write = self.conn.write
|
||||||
|
|
||||||
|
@ -40,9 +42,6 @@ class ConnectionTcpFull(Connection):
|
||||||
async def close(self):
|
async def close(self):
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
return ConnectionTcpFull(self._proxy, self._timeout)
|
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
packet_len_seq = await self.read(8) # 4 and 4
|
packet_len_seq = await self.read(8) # 4 and 4
|
||||||
packet_len, seq = struct.unpack('<ii', packet_len_seq)
|
packet_len, seq = struct.unpack('<ii', packet_len_seq)
|
||||||
|
|
|
@ -13,9 +13,6 @@ class ConnectionTcpIntermediate(ConnectionTcpFull):
|
||||||
await self.conn.write(b'\xee\xee\xee\xee')
|
await self.conn.write(b'\xee\xee\xee\xee')
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
return ConnectionTcpIntermediate(self._proxy, self._timeout)
|
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
return await self.read(struct.unpack('<i', await self.read(4))[0])
|
return await self.read(struct.unpack('<i', await self.read(4))[0])
|
||||||
|
|
||||||
|
|
|
@ -12,8 +12,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
|
||||||
every message with a randomly generated key using the
|
every message with a randomly generated key using the
|
||||||
AES-CTR mode so the packets are harder to discern.
|
AES-CTR mode so the packets are harder to discern.
|
||||||
"""
|
"""
|
||||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
|
def __init__(self, *, loop, proxy=None, timeout=timedelta(seconds=5)):
|
||||||
super().__init__(proxy, timeout)
|
super().__init__(loop=loop, proxy=proxy, timeout=timeout)
|
||||||
self._aes_encrypt, self._aes_decrypt = None, None
|
self._aes_encrypt, self._aes_decrypt = None, None
|
||||||
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
|
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
|
||||||
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))
|
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))
|
||||||
|
@ -45,6 +45,3 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
|
||||||
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
|
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
|
||||||
await self.conn.write(bytes(random))
|
await self.conn.write(bytes(random))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
return ConnectionTcpObfuscated(self._proxy, self._timeout)
|
|
||||||
|
|
|
@ -39,10 +39,11 @@ class MTProtoSender:
|
||||||
A new authorization key will be generated on connection if no other
|
A new authorization key will be generated on connection if no other
|
||||||
key exists yet.
|
key exists yet.
|
||||||
"""
|
"""
|
||||||
def __init__(self, state, connection, *, retries=5,
|
def __init__(self, state, connection, loop, *, retries=5,
|
||||||
first_query=None, update_callback=None):
|
first_query=None, update_callback=None):
|
||||||
self.state = state
|
self.state = state
|
||||||
self._connection = connection
|
self._connection = connection
|
||||||
|
self._loop = loop
|
||||||
self._ip = None
|
self._ip = None
|
||||||
self._port = None
|
self._port = None
|
||||||
self._retries = retries
|
self._retries = retries
|
||||||
|
@ -231,9 +232,13 @@ class MTProtoSender:
|
||||||
raise _last_error
|
raise _last_error
|
||||||
|
|
||||||
__log__.debug('Starting send loop')
|
__log__.debug('Starting send loop')
|
||||||
self._send_loop_handle = asyncio.ensure_future(self._send_loop())
|
self._send_loop_handle = asyncio.ensure_future(
|
||||||
|
self._send_loop(), loop=self._loop)
|
||||||
|
|
||||||
__log__.debug('Starting receive loop')
|
__log__.debug('Starting receive loop')
|
||||||
self._recv_loop_handle = asyncio.ensure_future(self._recv_loop())
|
self._recv_loop_handle = asyncio.ensure_future(
|
||||||
|
self._recv_loop(), loop=self._loop)
|
||||||
|
|
||||||
if self._is_first_query:
|
if self._is_first_query:
|
||||||
__log__.debug('Running first query')
|
__log__.debug('Running first query')
|
||||||
self._is_first_query = False
|
self._is_first_query = False
|
||||||
|
@ -347,11 +352,11 @@ class MTProtoSender:
|
||||||
continue
|
continue
|
||||||
except ConnectionError as e:
|
except ConnectionError as e:
|
||||||
__log__.info('Connection reset while receiving %s', e)
|
__log__.info('Connection reset while receiving %s', e)
|
||||||
asyncio.ensure_future(self._reconnect())
|
asyncio.ensure_future(self._reconnect(), loop=self._loop)
|
||||||
break
|
break
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
__log__.warning('OSError while receiving %s', e)
|
__log__.warning('OSError while receiving %s', e)
|
||||||
asyncio.ensure_future(self._reconnect())
|
asyncio.ensure_future(self._reconnect(), loop=self._loop)
|
||||||
break
|
break
|
||||||
|
|
||||||
# TODO Check salt, session_id and sequence_number
|
# TODO Check salt, session_id and sequence_number
|
||||||
|
@ -370,7 +375,7 @@ class MTProtoSender:
|
||||||
# an actually broken authkey?
|
# an actually broken authkey?
|
||||||
__log__.warning('Broken authorization key?: {}'.format(e))
|
__log__.warning('Broken authorization key?: {}'.format(e))
|
||||||
self.state.auth_key = None
|
self.state.auth_key = None
|
||||||
asyncio.ensure_future(self._reconnect())
|
asyncio.ensure_future(self._reconnect(), loop=self._loop)
|
||||||
break
|
break
|
||||||
except SecurityError as e:
|
except SecurityError as e:
|
||||||
# A step while decoding had the incorrect data. This message
|
# A step while decoding had the incorrect data. This message
|
||||||
|
|
Loading…
Reference in New Issue
Block a user