diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index b51fa122..604cd0fb 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -3,10 +3,10 @@ import collections import logging from .mtprotolayer import MTProtoLayer +from .requeststate import RequestState from .. import utils from ..errors import ( - BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError, - rpc_message_to_error + BadMessageError, TypeNotFoundError, rpc_message_to_error ) from ..extensions import BinaryReader from ..tl.core import RpcResult, MessageContainer, GzipPacked @@ -16,7 +16,6 @@ from ..tl.types import ( MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq, MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload ) -from .requeststate import RequestState __log__ = logging.getLogger(__name__) @@ -65,7 +64,10 @@ class MTProtoSender: # Outgoing messages are put in a queue and sent in a batch. # Note that here we're also storing their ``_RequestState``. # Note that it may also store lists (implying order must be kept). - self._send_queue = asyncio.Queue() + # + # TODO Abstract this queue away? + self._send_queue = [] + self._send_ready = asyncio.Event(loop=self._loop) # Sent states are remembered until a response is received. self._pending_state = {} @@ -189,7 +191,8 @@ class MTProtoSender: if not utils.is_list_like(request): state = RequestState(request, self._loop) - self._send_queue.put_nowait(state) + self._send_queue.append(state) + self._send_ready.set() return state.future else: states = [] @@ -199,10 +202,11 @@ class MTProtoSender: states.append(state) futures.append(state.future) if ordered: - self._send_queue.put(states) + self._send_queue.append(states) else: - for state in states: - self._send_queue.put(state) + self._send_queue.extend(states) + + self._send_ready.set() return futures @property @@ -328,19 +332,29 @@ class MTProtoSender: while self._user_connected and not self._reconnecting: if self._pending_ack: ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop) - self._send_queue.put_nowait(ack) + self._send_queue.append(ack) + self._send_ready.set() self._last_acks.append(ack) self._pending_ack.clear() - state_list = [] + queue = asyncio.ensure_future( + self._send_ready.wait(), loop=self._loop) - # TODO wait for the list to have one or for a disconnect to happen - # and pop while that's the case - state = await self._send_queue.get() - state_list.append(state) + disconnected = asyncio.ensure_future( + self._connection._connection._disconnected.wait()) - while not self._send_queue.empty(): - state_list.append(self._send_queue.get_nowait()) + # Basically using the disconnected as a cancellation token + done, pending = await asyncio.wait( + [queue, disconnected], + return_when=asyncio.FIRST_COMPLETED, + loop=self._loop + ) + if disconnected in done: + break + + state_list = self._send_queue + self._send_queue = [] + self._send_ready.clear() # TODO Debug logs to notify which messages are being sent # TODO Try sending them while no future was cancelled? @@ -409,8 +423,9 @@ class MTProtoSender: if rpc_result.error: error = rpc_message_to_error(rpc_result.error) - self._send_queue.put_nowait( + self._send_queue.append( RequestState(MsgsAck([state.msg_id]), loop=self._loop)) + self._send_ready.set() if not state.future.cancelled(): state.future.set_exception(error) @@ -477,12 +492,14 @@ class MTProtoSender: __log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id) self._connection._state.salt = bad_salt.new_server_salt try: - self._send_queue.put_nowait( + self._send_queue.append( self._pending_state.pop(bad_salt.bad_msg_id)) + self._send_ready.set() except KeyError: for ack in self._pending_ack: if ack.msg_id == bad_salt.bad_msg_id: - self._send_queue.put_nowait(ack) + self._send_queue.append(ack) + self._send_ready.set() return __log__.info('Message %d not resent due to bad salt', @@ -521,7 +538,8 @@ class MTProtoSender: # Messages are to be re-sent once we've corrected the issue if state: - self._send_queue.put_nowait(state) + self._send_queue.append(state) + self._send_ready.set() else: # TODO Generic method that may return from the acks too # May be MsgsAck, those are not saved in pending messages @@ -606,9 +624,10 @@ class MTProtoSender: Handles both :tl:`MsgsStateReq` and :tl:`MsgResendReq` by enqueuing a :tl:`MsgsStateInfo` to be sent at a later point. """ - self._send_queue.put_nowait(RequestState(MsgsStateInfo( + self._send_queue.append(RequestState(MsgsStateInfo( req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)), loop=self._loop)) + self._send_ready.set() async def _handle_msg_all(self, message): """