Fix event loop not being passed into many asyncio calls

This commit is contained in:
Lonami Exo 2018-08-21 11:31:14 +02:00
parent d474458136
commit 47190d7d55
7 changed files with 37 additions and 27 deletions

View File

@ -244,7 +244,7 @@ class TelegramBaseClient(abc.ABC):
# being ``n`` the amount of borrows a given sender has; once ``n`` # being ``n`` the amount of borrows a given sender has; once ``n``
# reaches ``0`` it should be disconnected and removed. # reaches ``0`` it should be disconnected and removed.
self._borrowed_senders = {} self._borrowed_senders = {}
self._borrow_sender_lock = asyncio.Lock() self._borrow_sender_lock = asyncio.Lock(loop=self._loop)
# Save whether the user is authorized here (a.k.a. logged in) # Save whether the user is authorized here (a.k.a. logged in)
self._authorized = None # None = We don't know yet self._authorized = None # None = We don't know yet
@ -258,8 +258,8 @@ class TelegramBaseClient(abc.ABC):
self._channel_pts = {} self._channel_pts = {}
if sequential_updates: if sequential_updates:
self._updates_queue = asyncio.Queue() self._updates_queue = asyncio.Queue(loop=self._loop)
self._dispatching_updates_queue = asyncio.Event() self._dispatching_updates_queue = asyncio.Event(loop=self._loop)
else: else:
self._updates_queue = None self._updates_queue = None
self._dispatching_updates_queue = None self._dispatching_updates_queue = None

View File

@ -162,8 +162,10 @@ class InlineQuery(EventBuilder):
if self._answered: if self._answered:
return return
results = [self._as_awaitable(x) for x in results] results = [self._as_awaitable(x, self._client.loop)
done, _ = await asyncio.wait(results) for x in results]
done, _ = await asyncio.wait(results, loop=self._client.loop)
results = [x.result() for x in done] results = [x.result() for x in done]
if switch_pm: if switch_pm:
@ -181,10 +183,10 @@ class InlineQuery(EventBuilder):
) )
@staticmethod @staticmethod
def _as_awaitable(obj): def _as_awaitable(obj, loop):
if inspect.isawaitable(obj): if inspect.isawaitable(obj):
return obj return obj
f = asyncio.Future() f = asyncio.Future(loop=loop)
f.set_result(obj) f.set_result(obj)
return f return f

View File

@ -205,14 +205,16 @@ class MTProtoSender:
result = [] result = []
after = None after = None
for r in request: for r in request:
message = self.state.create_message(r, after=after) message = self.state.create_message(
r, loop=self._loop, after=after)
self._pending_messages[message.msg_id] = message self._pending_messages[message.msg_id] = message
self._send_queue.put_nowait(message) self._send_queue.put_nowait(message)
result.append(message.future) result.append(message.future)
after = ordered and message after = ordered and message
return result return result
else: else:
message = self.state.create_message(request) message = self.state.create_message(request, loop=self._loop)
self._pending_messages[message.msg_id] = message self._pending_messages[message.msg_id] = message
self._send_queue.put_nowait(message) self._send_queue.put_nowait(message)
return message.future return message.future
@ -280,7 +282,7 @@ class MTProtoSender:
# First connection or manual reconnection after a failure # First connection or manual reconnection after a failure
if self._disconnected is None or self._disconnected.done(): if self._disconnected is None or self._disconnected.done():
self._disconnected = asyncio.Future() self._disconnected = asyncio.Future(loop=self._loop)
__log__.info('Connection to {} complete!'.format(self._ip)) __log__.info('Connection to {} complete!'.format(self._ip))
async def _reconnect(self): async def _reconnect(self):
@ -352,7 +354,7 @@ class MTProtoSender:
while self._user_connected and not self._reconnecting: while self._user_connected and not self._reconnecting:
if self._pending_ack: if self._pending_ack:
self._last_ack = self.state.create_message( self._last_ack = self.state.create_message(
MsgsAck(list(self._pending_ack)) MsgsAck(list(self._pending_ack)), loop=self._loop
) )
self._send_queue.put_nowait(self._last_ack) self._send_queue.put_nowait(self._last_ack)
self._pending_ack.clear() self._pending_ack.clear()
@ -365,7 +367,9 @@ class MTProtoSender:
continue continue
if isinstance(messages, list): if isinstance(messages, list):
message = self.state.create_message(MessageContainer(messages)) message = self.state.create_message(
MessageContainer(messages), loop=self._loop)
self._pending_messages[message.msg_id] = message self._pending_messages[message.msg_id] = message
self._pending_containers.append(message) self._pending_containers.append(message)
else: else:
@ -394,7 +398,7 @@ class MTProtoSender:
__log__.warning('OSError while sending %s', e) __log__.warning('OSError while sending %s', e)
else: else:
__log__.exception('Unhandled exception while receiving') __log__.exception('Unhandled exception while receiving')
await asyncio.sleep(1) await asyncio.sleep(1, loop=self._loop)
self._start_reconnect() self._start_reconnect()
break break
@ -433,7 +437,7 @@ class MTProtoSender:
__log__.warning('OSError while receiving %s', e) __log__.warning('OSError while receiving %s', e)
else: else:
__log__.exception('Unhandled exception while receiving') __log__.exception('Unhandled exception while receiving')
await asyncio.sleep(1) await asyncio.sleep(1, loop=self._loop)
self._start_reconnect() self._start_reconnect()
break break
@ -471,7 +475,7 @@ class MTProtoSender:
return return
except Exception as e: except Exception as e:
__log__.exception('Unhandled exception while unpacking %s',e) __log__.exception('Unhandled exception while unpacking %s',e)
await asyncio.sleep(1) await asyncio.sleep(1, loop=self._loop)
else: else:
try: try:
await self._process_message(message) await self._process_message(message)
@ -480,7 +484,7 @@ class MTProtoSender:
except Exception as e: except Exception as e:
__log__.exception('Unhandled exception while ' __log__.exception('Unhandled exception while '
'processing %s', message) 'processing %s', message)
await asyncio.sleep(1) await asyncio.sleep(1, loop=self._loop)
# Response Handlers # Response Handlers
@ -525,7 +529,7 @@ class MTProtoSender:
if rpc_result.error: if rpc_result.error:
error = rpc_message_to_error(rpc_result.error) error = rpc_message_to_error(rpc_result.error)
self._send_queue.put_nowait(self.state.create_message( self._send_queue.put_nowait(self.state.create_message(
MsgsAck([message.msg_id]) MsgsAck([message.msg_id]), loop=self._loop
)) ))
if not message.future.cancelled(): if not message.future.cancelled():

View File

@ -37,7 +37,7 @@ class MTProtoState:
self._sequence = 0 self._sequence = 0
self._last_msg_id = 0 self._last_msg_id = 0
def create_message(self, obj, after=None): def create_message(self, obj, *, loop, after=None):
""" """
Creates a new `telethon.tl.tl_message.TLMessage` from Creates a new `telethon.tl.tl_message.TLMessage` from
the given `telethon.tl.tlobject.TLObject` instance. the given `telethon.tl.tlobject.TLObject` instance.
@ -47,7 +47,8 @@ class MTProtoState:
seq_no=self._get_seq_no(isinstance(obj, TLRequest)), seq_no=self._get_seq_no(isinstance(obj, TLRequest)),
obj=obj, obj=obj,
after_id=after.msg_id if after else None, after_id=after.msg_id if after else None,
out=True # Pre-convert the request into bytes out=True, # Pre-convert the request into bytes
loop=loop
) )
def update_message_id(self, message): def update_message_id(self, message):
@ -135,7 +136,7 @@ class MTProtoState:
# reader isn't used for anything else after this, it's unnecessary. # reader isn't used for anything else after this, it's unnecessary.
obj = reader.tgread_object() obj = reader.tgread_object()
return TLMessage(remote_msg_id, remote_sequence, obj) return TLMessage(remote_msg_id, remote_sequence, obj, loop=None)
def _get_new_msg_id(self): def _get_new_msg_id(self):
""" """

View File

@ -43,5 +43,5 @@ class MessageContainer(TLObject):
before = reader.tell_position() before = reader.tell_position()
obj = reader.tgread_object() # May over-read e.g. RpcResult obj = reader.tgread_object() # May over-read e.g. RpcResult
reader.set_position(before + length) reader.set_position(before + length)
messages.append(TLMessage(msg_id, seq_no, obj)) messages.append(TLMessage(msg_id, seq_no, obj, loop=None))
return MessageContainer(messages) return MessageContainer(messages)

View File

@ -24,10 +24,13 @@ class TLMessage(TLObject):
sent `TLMessage`, and this result can be represented as a `Future` sent `TLMessage`, and this result can be represented as a `Future`
that will eventually be set with either a result, error or cancelled. that will eventually be set with either a result, error or cancelled.
""" """
def __init__(self, msg_id, seq_no, obj, out=False, after_id=0): def __init__(self, msg_id, seq_no, obj, *, loop, out=False, after_id=0):
self.obj = obj self.obj = obj
self.container_msg_id = None self.container_msg_id = None
self.future = asyncio.Future()
# If no loop is given then it is an incoming message.
# Only outgoing messages need the future to await them.
self.future = asyncio.Future(loop=loop) if loop else None
# After which message ID this one should run. We do this so # After which message ID this one should run. We do this so
# InvokeAfterMsgRequest is transparent to the user and we can # InvokeAfterMsgRequest is transparent to the user and we can

View File

@ -181,7 +181,7 @@ class Conversation(ChatGetter):
return incoming return incoming
# Otherwise the next incoming response will be the one to use # Otherwise the next incoming response will be the one to use
future = asyncio.Future() future = asyncio.Future(loop=self._client.loop)
pending[target_id] = future pending[target_id] = future
return self._get_result(future, start_time, timeout) return self._get_result(future, start_time, timeout)
@ -209,7 +209,7 @@ class Conversation(ChatGetter):
return earliest_edit return earliest_edit
# Otherwise the next incoming response will be the one to use # Otherwise the next incoming response will be the one to use
future = asyncio.Future() future = asyncio.Future(loop=self._client.loop)
self._pending_edits[target_id] = future self._pending_edits[target_id] = future
return await self._get_result(future, start_time, timeout) return await self._get_result(future, start_time, timeout)
@ -220,7 +220,7 @@ class Conversation(ChatGetter):
will also trigger even without a response. will also trigger even without a response.
""" """
start_time = time.time() start_time = time.time()
future = asyncio.Future() future = asyncio.Future(loop=self._client.loop)
target_id = self._get_message_id(message) target_id = self._get_message_id(message)
if self._last_read is None: if self._last_read is None:
@ -265,7 +265,7 @@ class Conversation(ChatGetter):
counter = Conversation._custom_counter counter = Conversation._custom_counter
Conversation._custom_counter += 1 Conversation._custom_counter += 1
future = asyncio.Future() future = asyncio.Future(loop=self._client.loop)
async def result(): async def result():
try: try:
return await self._get_result(future, start_time, timeout) return await self._get_result(future, start_time, timeout)