Avoid explicitly passing the loop to asyncio

This behaviour is deprecated and will be removed in future versions
of Python. Technically, it could be considered a bug (invalid usage
causing different behaviour from the expected one), and in practice
it should not break much code (because .get_event_loop() would likely
be the same event loop anyway).
This commit is contained in:
Lonami Exo 2020-07-25 18:39:35 +02:00
parent de17a19168
commit 1c3e7dda01
15 changed files with 65 additions and 73 deletions

View File

@ -219,19 +219,21 @@ Can I use threads?
================== ==================
Yes, you can, but you must understand that the loops themselves are Yes, you can, but you must understand that the loops themselves are
not thread safe. and you must be sure to know what is happening. You not thread safe. and you must be sure to know what is happening. The
may want to create a loop in a new thread and make sure to pass it to easiest and cleanest option is to use `asyncio.run` to create and manage
the client: the new event loop for you:
.. code-block:: python .. code-block:: python
import asyncio import asyncio
import threading import threading
def go(): async def actual_work():
loop = asyncio.new_event_loop()
client = TelegramClient(..., loop=loop) client = TelegramClient(..., loop=loop)
... ... # can use `await` here
def go():
asyncio.run(actual_work())
threading.Thread(target=go).start() threading.Thread(target=go).start()

View File

@ -378,7 +378,7 @@ class DialogMethods:
entities = [await self.get_input_entity(entity)] entities = [await self.get_input_entity(entity)]
else: else:
entities = await asyncio.gather( entities = await asyncio.gather(
*(self.get_input_entity(x) for x in entity), loop=self.loop) *(self.get_input_entity(x) for x in entity))
if folder is None: if folder is None:
raise ValueError('You must specify a folder') raise ValueError('You must specify a folder')

View File

@ -180,7 +180,8 @@ class TelegramBaseClient(abc.ABC):
Defaults to `lang_code`. Defaults to `lang_code`.
loop (`asyncio.AbstractEventLoop`, optional): loop (`asyncio.AbstractEventLoop`, optional):
Asyncio event loop to use. Defaults to `asyncio.get_event_loop()` Asyncio event loop to use. Defaults to `asyncio.get_event_loop()`.
This argument is ignored.
base_logger (`str` | `logging.Logger`, optional): base_logger (`str` | `logging.Logger`, optional):
Base logger name or instance to use. Base logger name or instance to use.
@ -227,7 +228,7 @@ class TelegramBaseClient(abc.ABC):
"Refer to telethon.rtfd.io for more information.") "Refer to telethon.rtfd.io for more information.")
self._use_ipv6 = use_ipv6 self._use_ipv6 = use_ipv6
self._loop = loop or asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
if isinstance(base_logger, str): if isinstance(base_logger, str):
base_logger = logging.getLogger(base_logger) base_logger = logging.getLogger(base_logger)
@ -334,7 +335,7 @@ class TelegramBaseClient(abc.ABC):
) )
self._sender = MTProtoSender( self._sender = MTProtoSender(
self.session.auth_key, self._loop, self.session.auth_key,
loggers=self._log, loggers=self._log,
retries=self._connection_retries, retries=self._connection_retries,
delay=self._retry_delay, delay=self._retry_delay,
@ -350,15 +351,15 @@ class TelegramBaseClient(abc.ABC):
# Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders # Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders
self._borrowed_senders = {} self._borrowed_senders = {}
self._borrow_sender_lock = asyncio.Lock(loop=self._loop) self._borrow_sender_lock = asyncio.Lock()
self._updates_handle = None self._updates_handle = None
self._last_request = time.time() self._last_request = time.time()
self._channel_pts = {} self._channel_pts = {}
if sequential_updates: if sequential_updates:
self._updates_queue = asyncio.Queue(loop=self._loop) self._updates_queue = asyncio.Queue()
self._dispatching_updates_queue = asyncio.Event(loop=self._loop) self._dispatching_updates_queue = asyncio.Event()
else: else:
# Use a set of pending instead of a queue so we can properly # Use a set of pending instead of a queue so we can properly
# terminate all pending updates on disconnect. # terminate all pending updates on disconnect.
@ -481,7 +482,6 @@ class TelegramBaseClient(abc.ABC):
self.session.server_address, self.session.server_address,
self.session.port, self.session.port,
self.session.dc_id, self.session.dc_id,
loop=self._loop,
loggers=self._log, loggers=self._log,
proxy=self._proxy proxy=self._proxy
)): )):
@ -556,7 +556,7 @@ class TelegramBaseClient(abc.ABC):
for task in self._updates_queue: for task in self._updates_queue:
task.cancel() task.cancel()
await asyncio.wait(self._updates_queue, loop=self._loop) await asyncio.wait(self._updates_queue)
self._updates_queue.clear() self._updates_queue.clear()
pts, date = self._state_cache[None] pts, date = self._state_cache[None]
@ -639,12 +639,11 @@ class TelegramBaseClient(abc.ABC):
# #
# If one were to do that, Telegram would reset the connection # If one were to do that, Telegram would reset the connection
# with no further clues. # with no further clues.
sender = MTProtoSender(None, self._loop, loggers=self._log) sender = MTProtoSender(None, loggers=self._log)
await sender.connect(self._connection( await sender.connect(self._connection(
dc.ip_address, dc.ip_address,
dc.port, dc.port,
dc.id, dc.id,
loop=self._loop,
loggers=self._log, loggers=self._log,
proxy=self._proxy proxy=self._proxy
)) ))
@ -680,7 +679,6 @@ class TelegramBaseClient(abc.ABC):
dc.ip_address, dc.ip_address,
dc.port, dc.port,
dc.id, dc.id,
loop=self._loop,
loggers=self._log, loggers=self._log,
proxy=self._proxy proxy=self._proxy
)) ))

View File

@ -326,7 +326,7 @@ class UpdateMethods:
while self.is_connected(): while self.is_connected():
try: try:
await asyncio.wait_for( await asyncio.wait_for(
self.disconnected, timeout=60, loop=self._loop self.disconnected, timeout=60
) )
continue # We actually just want to act upon timeout continue # We actually just want to act upon timeout
except asyncio.TimeoutError: except asyncio.TimeoutError:

View File

@ -44,7 +44,7 @@ class UserMethods:
self._flood_waited_requests.pop(r.CONSTRUCTOR_ID, None) self._flood_waited_requests.pop(r.CONSTRUCTOR_ID, None)
elif diff <= self.flood_sleep_threshold: elif diff <= self.flood_sleep_threshold:
self._log[__name__].info(*_fmt_flood(diff, r, early=True)) self._log[__name__].info(*_fmt_flood(diff, r, early=True))
await asyncio.sleep(diff, loop=self._loop) await asyncio.sleep(diff)
self._flood_waited_requests.pop(r.CONSTRUCTOR_ID, None) self._flood_waited_requests.pop(r.CONSTRUCTOR_ID, None)
else: else:
raise errors.FloodWaitError(request=r, capture=diff) raise errors.FloodWaitError(request=r, capture=diff)
@ -99,7 +99,7 @@ class UserMethods:
if e.seconds <= self.flood_sleep_threshold: if e.seconds <= self.flood_sleep_threshold:
self._log[__name__].info(*_fmt_flood(e.seconds, request)) self._log[__name__].info(*_fmt_flood(e.seconds, request))
await asyncio.sleep(e.seconds, loop=self._loop) await asyncio.sleep(e.seconds)
else: else:
raise raise
except (errors.PhoneMigrateError, errors.NetworkMigrateError, except (errors.PhoneMigrateError, errors.NetworkMigrateError,

View File

@ -92,7 +92,7 @@ class EventBuilder(abc.ABC):
return return
if not self._resolve_lock: if not self._resolve_lock:
self._resolve_lock = asyncio.Lock(loop=client.loop) self._resolve_lock = asyncio.Lock()
async with self._resolve_lock: async with self._resolve_lock:
if not self.resolved: if not self.resolved:

View File

@ -206,10 +206,9 @@ class InlineQuery(EventBuilder):
return return
if results: if results:
futures = [self._as_future(x, self._client.loop) futures = [self._as_future(x) for x in results]
for x in results]
await asyncio.wait(futures, loop=self._client.loop) await asyncio.wait(futures)
# All futures will be in the `done` *set* that `wait` returns. # All futures will be in the `done` *set* that `wait` returns.
# #
@ -236,10 +235,10 @@ class InlineQuery(EventBuilder):
) )
@staticmethod @staticmethod
def _as_future(obj, loop): def _as_future(obj):
if inspect.isawaitable(obj): if inspect.isawaitable(obj):
return asyncio.ensure_future(obj, loop=loop) return asyncio.ensure_future(obj)
f = loop.create_future() f = asyncio.get_event_loop().create_future()
f.set_result(obj) f.set_result(obj)
return f return f

View File

@ -22,11 +22,10 @@ class MessagePacker:
point where outgoing requests are put, and where ready-messages are get. point where outgoing requests are put, and where ready-messages are get.
""" """
def __init__(self, state, loop, loggers): def __init__(self, state, loggers):
self._state = state self._state = state
self._loop = loop
self._deque = collections.deque() self._deque = collections.deque()
self._ready = asyncio.Event(loop=loop) self._ready = asyncio.Event()
self._log = loggers[__name__] self._log = loggers[__name__]
def append(self, state): def append(self, state):

View File

@ -28,11 +28,10 @@ class Connection(abc.ABC):
# should be one of `PacketCodec` implementations # should be one of `PacketCodec` implementations
packet_codec = None packet_codec = None
def __init__(self, ip, port, dc_id, *, loop, loggers, proxy=None): def __init__(self, ip, port, dc_id, *, loggers, proxy=None):
self._ip = ip self._ip = ip
self._port = port self._port = port
self._dc_id = dc_id # only for MTProxy, it's an abstraction leak self._dc_id = dc_id # only for MTProxy, it's an abstraction leak
self._loop = loop
self._log = loggers[__name__] self._log = loggers[__name__]
self._proxy = proxy self._proxy = proxy
self._reader = None self._reader = None
@ -48,9 +47,8 @@ class Connection(abc.ABC):
async def _connect(self, timeout=None, ssl=None): async def _connect(self, timeout=None, ssl=None):
if not self._proxy: if not self._proxy:
self._reader, self._writer = await asyncio.wait_for( self._reader, self._writer = await asyncio.wait_for(
asyncio.open_connection( asyncio.open_connection(self._ip, self._port, ssl=ssl),
self._ip, self._port, loop=self._loop, ssl=ssl), timeout=timeout
loop=self._loop, timeout=timeout
) )
else: else:
import socks import socks
@ -67,9 +65,8 @@ class Connection(abc.ABC):
s.settimeout(timeout) s.settimeout(timeout)
await asyncio.wait_for( await asyncio.wait_for(
self._loop.sock_connect(s, address), asyncio.get_event_loop().sock_connect(s, address),
timeout=timeout, timeout=timeout
loop=self._loop
) )
if ssl: if ssl:
if ssl_mod is None: if ssl_mod is None:
@ -87,8 +84,7 @@ class Connection(abc.ABC):
s.setblocking(False) s.setblocking(False)
self._reader, self._writer = \ self._reader, self._writer = await asyncio.open_connection(sock=s)
await asyncio.open_connection(sock=s, loop=self._loop)
self._codec = self.packet_codec(self) self._codec = self.packet_codec(self)
self._init_conn() self._init_conn()
@ -101,8 +97,9 @@ class Connection(abc.ABC):
await self._connect(timeout=timeout, ssl=ssl) await self._connect(timeout=timeout, ssl=ssl)
self._connected = True self._connected = True
self._send_task = self._loop.create_task(self._send_loop()) loop = asyncio.get_event_loop()
self._recv_task = self._loop.create_task(self._recv_loop()) self._send_task = loop.create_task(self._send_loop())
self._recv_task = loop.create_task(self._recv_loop())
async def disconnect(self): async def disconnect(self):
""" """

View File

@ -95,12 +95,12 @@ class TcpMTProxy(ObfuscatedConnection):
obfuscated_io = MTProxyIO obfuscated_io = MTProxyIO
# noinspection PyUnusedLocal # noinspection PyUnusedLocal
def __init__(self, ip, port, dc_id, *, loop, loggers, proxy=None): def __init__(self, ip, port, dc_id, *, loggers, proxy=None):
# connect to proxy's host and port instead of telegram's ones # connect to proxy's host and port instead of telegram's ones
proxy_host, proxy_port = self.address_info(proxy) proxy_host, proxy_port = self.address_info(proxy)
self._secret = bytes.fromhex(proxy[2]) self._secret = bytes.fromhex(proxy[2])
super().__init__( super().__init__(
proxy_host, proxy_port, dc_id, loop=loop, loggers=loggers) proxy_host, proxy_port, dc_id, loggers=loggers)
async def _connect(self, timeout=None, ssl=None): async def _connect(self, timeout=None, ssl=None):
await super()._connect(timeout=timeout, ssl=ssl) await super()._connect(timeout=timeout, ssl=ssl)

View File

@ -40,12 +40,11 @@ class MTProtoSender:
A new authorization key will be generated on connection if no other A new authorization key will be generated on connection if no other
key exists yet. key exists yet.
""" """
def __init__(self, auth_key, loop, *, loggers, def __init__(self, auth_key, *, loggers,
retries=5, delay=1, auto_reconnect=True, connect_timeout=None, retries=5, delay=1, auto_reconnect=True, connect_timeout=None,
auth_key_callback=None, auth_key_callback=None,
update_callback=None, auto_reconnect_callback=None): update_callback=None, auto_reconnect_callback=None):
self._connection = None self._connection = None
self._loop = loop
self._loggers = loggers self._loggers = loggers
self._log = loggers[__name__] self._log = loggers[__name__]
self._retries = retries self._retries = retries
@ -55,7 +54,7 @@ class MTProtoSender:
self._auth_key_callback = auth_key_callback self._auth_key_callback = auth_key_callback
self._update_callback = update_callback self._update_callback = update_callback
self._auto_reconnect_callback = auto_reconnect_callback self._auto_reconnect_callback = auto_reconnect_callback
self._connect_lock = asyncio.Lock(loop=loop) self._connect_lock = asyncio.Lock()
# Whether the user has explicitly connected or disconnected. # Whether the user has explicitly connected or disconnected.
# #
@ -65,7 +64,7 @@ class MTProtoSender:
# pending futures should be cancelled. # pending futures should be cancelled.
self._user_connected = False self._user_connected = False
self._reconnecting = False self._reconnecting = False
self._disconnected = self._loop.create_future() self._disconnected = asyncio.get_event_loop().create_future()
self._disconnected.set_result(None) self._disconnected.set_result(None)
# We need to join the loops upon disconnection # We need to join the loops upon disconnection
@ -78,8 +77,7 @@ class MTProtoSender:
# Outgoing messages are put in a queue and sent in a batch. # Outgoing messages are put in a queue and sent in a batch.
# Note that here we're also storing their ``_RequestState``. # Note that here we're also storing their ``_RequestState``.
self._send_queue = MessagePacker(self._state, self._loop, self._send_queue = MessagePacker(self._state, loggers=self._loggers)
loggers=self._loggers)
# Sent states are remembered until a response is received. # Sent states are remembered until a response is received.
self._pending_state = {} self._pending_state = {}
@ -171,7 +169,7 @@ class MTProtoSender:
if not utils.is_list_like(request): if not utils.is_list_like(request):
try: try:
state = RequestState(request, self._loop) state = RequestState(request)
except struct.error as e: except struct.error as e:
# "struct.error: required argument is not an integer" is not # "struct.error: required argument is not an integer" is not
# very helpful; log the request to find out what wasn't int. # very helpful; log the request to find out what wasn't int.
@ -186,7 +184,7 @@ class MTProtoSender:
state = None state = None
for req in request: for req in request:
try: try:
state = RequestState(req, self._loop, after=ordered and state) state = RequestState(req, after=ordered and state)
except struct.error as e: except struct.error as e:
self._log.error('Request caused struct.error: %s: %s', e, request) self._log.error('Request caused struct.error: %s: %s', e, request)
raise raise
@ -206,7 +204,7 @@ class MTProtoSender:
Note that it may resolve in either a ``ConnectionError`` Note that it may resolve in either a ``ConnectionError``
or any other unexpected error that could not be handled. or any other unexpected error that could not be handled.
""" """
return asyncio.shield(self._disconnected, loop=self._loop) return asyncio.shield(self._disconnected)
# Private methods # Private methods
@ -241,7 +239,7 @@ class MTProtoSender:
# reconnect cleanly after. # reconnect cleanly after.
await self._connection.disconnect() await self._connection.disconnect()
connected = False connected = False
await asyncio.sleep(self._delay, loop=self._loop) await asyncio.sleep(self._delay)
continue # next iteration we will try to reconnect continue # next iteration we will try to reconnect
break # all steps done, break retry loop break # all steps done, break retry loop
@ -253,17 +251,18 @@ class MTProtoSender:
await self._disconnect(error=e) await self._disconnect(error=e)
raise e raise e
loop = asyncio.get_event_loop()
self._log.debug('Starting send loop') self._log.debug('Starting send loop')
self._send_loop_handle = self._loop.create_task(self._send_loop()) self._send_loop_handle = loop.create_task(self._send_loop())
self._log.debug('Starting receive loop') self._log.debug('Starting receive loop')
self._recv_loop_handle = self._loop.create_task(self._recv_loop()) self._recv_loop_handle = loop.create_task(self._recv_loop())
# _disconnected only completes after manual disconnection # _disconnected only completes after manual disconnection
# or errors after which the sender cannot continue such # or errors after which the sender cannot continue such
# as failing to reconnect or any unexpected error. # as failing to reconnect or any unexpected error.
if self._disconnected.done(): if self._disconnected.done():
self._disconnected = self._loop.create_future() self._disconnected = loop.create_future()
self._log.info('Connection to %s complete!', self._connection) self._log.info('Connection to %s complete!', self._connection)
@ -378,7 +377,7 @@ class MTProtoSender:
self._pending_state.clear() self._pending_state.clear()
if self._auto_reconnect_callback: if self._auto_reconnect_callback:
self._loop.create_task(self._auto_reconnect_callback()) asyncio.get_event_loop().create_task(self._auto_reconnect_callback())
break break
else: else:
@ -398,7 +397,7 @@ class MTProtoSender:
# gets stuck. # gets stuck.
# TODO It still gets stuck? Investigate where and why. # TODO It still gets stuck? Investigate where and why.
self._reconnecting = True self._reconnecting = True
self._loop.create_task(self._reconnect(error)) asyncio.get_event_loop().create_task(self._reconnect(error))
# Loops # Loops
@ -411,7 +410,7 @@ class MTProtoSender:
""" """
while self._user_connected and not self._reconnecting: while self._user_connected and not self._reconnecting:
if self._pending_ack: if self._pending_ack:
ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop) ack = RequestState(MsgsAck(list(self._pending_ack)))
self._send_queue.append(ack) self._send_queue.append(ack)
self._last_acks.append(ack) self._last_acks.append(ack)
self._pending_ack.clear() self._pending_ack.clear()
@ -564,7 +563,7 @@ class MTProtoSender:
if rpc_result.error: if rpc_result.error:
error = rpc_message_to_error(rpc_result.error, state.request) error = rpc_message_to_error(rpc_result.error, state.request)
self._send_queue.append( self._send_queue.append(
RequestState(MsgsAck([state.msg_id]), loop=self._loop)) RequestState(MsgsAck([state.msg_id])))
if not state.future.cancelled(): if not state.future.cancelled():
state.future.set_exception(error) state.future.set_exception(error)
@ -751,8 +750,8 @@ class MTProtoSender:
enqueuing a :tl:`MsgsStateInfo` to be sent at a later point. enqueuing a :tl:`MsgsStateInfo` to be sent at a later point.
""" """
self._send_queue.append(RequestState(MsgsStateInfo( self._send_queue.append(RequestState(MsgsStateInfo(
req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)), req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)
loop=self._loop)) )))
async def _handle_msg_all(self, message): async def _handle_msg_all(self, message):
""" """

View File

@ -10,10 +10,10 @@ class RequestState:
""" """
__slots__ = ('container_id', 'msg_id', 'request', 'data', 'future', 'after') __slots__ = ('container_id', 'msg_id', 'request', 'data', 'future', 'after')
def __init__(self, request, loop, after=None): def __init__(self, request, after=None):
self.container_id = None self.container_id = None
self.msg_id = None self.msg_id = None
self.request = request self.request = request
self.data = bytes(request) self.data = bytes(request)
self.future = asyncio.Future(loop=loop) self.future = asyncio.Future()
self.after = after self.after = after

View File

@ -65,8 +65,7 @@ class RequestIter(abc.ABC):
# asyncio will handle times <= 0 to sleep 0 seconds # asyncio will handle times <= 0 to sleep 0 seconds
if self.wait_time: if self.wait_time:
await asyncio.sleep( await asyncio.sleep(
self.wait_time - (time.time() - self.last_load), self.wait_time - (time.time() - self.last_load)
loop=self.client.loop
) )
self.last_load = time.time() self.last_load = time.time()

View File

@ -445,8 +445,7 @@ class Conversation(ChatGetter):
# cleared when their futures are set to a result. # cleared when their futures are set to a result.
return asyncio.wait_for( return asyncio.wait_for(
future, future,
timeout=None if due == float('inf') else due - time.time(), timeout=None if due == float('inf') else due - time.time()
loop=self._client.loop
) )
def _cancel_all(self, exception=None): def _cancel_all(self, exception=None):

View File

@ -341,8 +341,8 @@ class App(tkinter.Tk):
self.chat.configure(bg='yellow') self.chat.configure(bg='yellow')
async def main(loop, interval=0.05): async def main(interval=0.05):
client = TelegramClient(SESSION, API_ID, API_HASH, loop=loop) client = TelegramClient(SESSION, API_ID, API_HASH)
try: try:
await client.connect() await client.connect()
except Exception as e: except Exception as e:
@ -372,7 +372,7 @@ if __name__ == "__main__":
# Some boilerplate code to set up the main method # Some boilerplate code to set up the main method
aio_loop = asyncio.get_event_loop() aio_loop = asyncio.get_event_loop()
try: try:
aio_loop.run_until_complete(main(aio_loop)) aio_loop.run_until_complete(main())
finally: finally:
if not aio_loop.is_closed(): if not aio_loop.is_closed():
aio_loop.close() aio_loop.close()