Stop waiting for send items on disconnection

This commit is contained in:
Lonami Exo 2018-09-29 12:48:50 +02:00
parent 470fb9f5df
commit b02ebcb69b

View File

@ -3,10 +3,10 @@ import collections
import logging import logging
from .mtprotolayer import MTProtoLayer from .mtprotolayer import MTProtoLayer
from .requeststate import RequestState
from .. import utils from .. import utils
from ..errors import ( from ..errors import (
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError, BadMessageError, TypeNotFoundError, rpc_message_to_error
rpc_message_to_error
) )
from ..extensions import BinaryReader from ..extensions import BinaryReader
from ..tl.core import RpcResult, MessageContainer, GzipPacked from ..tl.core import RpcResult, MessageContainer, GzipPacked
@ -16,7 +16,6 @@ from ..tl.types import (
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq, MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq,
MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload
) )
from .requeststate import RequestState
__log__ = logging.getLogger(__name__) __log__ = logging.getLogger(__name__)
@ -65,7 +64,10 @@ class MTProtoSender:
# Outgoing messages are put in a queue and sent in a batch. # Outgoing messages are put in a queue and sent in a batch.
# Note that here we're also storing their ``_RequestState``. # Note that here we're also storing their ``_RequestState``.
# Note that it may also store lists (implying order must be kept). # Note that it may also store lists (implying order must be kept).
self._send_queue = asyncio.Queue() #
# TODO Abstract this queue away?
self._send_queue = []
self._send_ready = asyncio.Event(loop=self._loop)
# Sent states are remembered until a response is received. # Sent states are remembered until a response is received.
self._pending_state = {} self._pending_state = {}
@ -189,7 +191,8 @@ class MTProtoSender:
if not utils.is_list_like(request): if not utils.is_list_like(request):
state = RequestState(request, self._loop) state = RequestState(request, self._loop)
self._send_queue.put_nowait(state) self._send_queue.append(state)
self._send_ready.set()
return state.future return state.future
else: else:
states = [] states = []
@ -199,10 +202,11 @@ class MTProtoSender:
states.append(state) states.append(state)
futures.append(state.future) futures.append(state.future)
if ordered: if ordered:
self._send_queue.put(states) self._send_queue.append(states)
else: else:
for state in states: self._send_queue.extend(states)
self._send_queue.put(state)
self._send_ready.set()
return futures return futures
@property @property
@ -328,19 +332,29 @@ class MTProtoSender:
while self._user_connected and not self._reconnecting: while self._user_connected and not self._reconnecting:
if self._pending_ack: if self._pending_ack:
ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop) ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop)
self._send_queue.put_nowait(ack) self._send_queue.append(ack)
self._send_ready.set()
self._last_acks.append(ack) self._last_acks.append(ack)
self._pending_ack.clear() self._pending_ack.clear()
state_list = [] queue = asyncio.ensure_future(
self._send_ready.wait(), loop=self._loop)
# TODO wait for the list to have one or for a disconnect to happen disconnected = asyncio.ensure_future(
# and pop while that's the case self._connection._connection._disconnected.wait())
state = await self._send_queue.get()
state_list.append(state)
while not self._send_queue.empty(): # Basically using the disconnected as a cancellation token
state_list.append(self._send_queue.get_nowait()) done, pending = await asyncio.wait(
[queue, disconnected],
return_when=asyncio.FIRST_COMPLETED,
loop=self._loop
)
if disconnected in done:
break
state_list = self._send_queue
self._send_queue = []
self._send_ready.clear()
# TODO Debug logs to notify which messages are being sent # TODO Debug logs to notify which messages are being sent
# TODO Try sending them while no future was cancelled? # TODO Try sending them while no future was cancelled?
@ -409,8 +423,9 @@ class MTProtoSender:
if rpc_result.error: if rpc_result.error:
error = rpc_message_to_error(rpc_result.error) error = rpc_message_to_error(rpc_result.error)
self._send_queue.put_nowait( self._send_queue.append(
RequestState(MsgsAck([state.msg_id]), loop=self._loop)) RequestState(MsgsAck([state.msg_id]), loop=self._loop))
self._send_ready.set()
if not state.future.cancelled(): if not state.future.cancelled():
state.future.set_exception(error) state.future.set_exception(error)
@ -477,12 +492,14 @@ class MTProtoSender:
__log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id) __log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id)
self._connection._state.salt = bad_salt.new_server_salt self._connection._state.salt = bad_salt.new_server_salt
try: try:
self._send_queue.put_nowait( self._send_queue.append(
self._pending_state.pop(bad_salt.bad_msg_id)) self._pending_state.pop(bad_salt.bad_msg_id))
self._send_ready.set()
except KeyError: except KeyError:
for ack in self._pending_ack: for ack in self._pending_ack:
if ack.msg_id == bad_salt.bad_msg_id: if ack.msg_id == bad_salt.bad_msg_id:
self._send_queue.put_nowait(ack) self._send_queue.append(ack)
self._send_ready.set()
return return
__log__.info('Message %d not resent due to bad salt', __log__.info('Message %d not resent due to bad salt',
@ -521,7 +538,8 @@ class MTProtoSender:
# Messages are to be re-sent once we've corrected the issue # Messages are to be re-sent once we've corrected the issue
if state: if state:
self._send_queue.put_nowait(state) self._send_queue.append(state)
self._send_ready.set()
else: else:
# TODO Generic method that may return from the acks too # TODO Generic method that may return from the acks too
# May be MsgsAck, those are not saved in pending messages # May be MsgsAck, those are not saved in pending messages
@ -606,9 +624,10 @@ class MTProtoSender:
Handles both :tl:`MsgsStateReq` and :tl:`MsgResendReq` by Handles both :tl:`MsgsStateReq` and :tl:`MsgResendReq` by
enqueuing a :tl:`MsgsStateInfo` to be sent at a later point. enqueuing a :tl:`MsgsStateInfo` to be sent at a later point.
""" """
self._send_queue.put_nowait(RequestState(MsgsStateInfo( self._send_queue.append(RequestState(MsgsStateInfo(
req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)), req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)),
loop=self._loop)) loop=self._loop))
self._send_ready.set()
async def _handle_msg_all(self, message): async def _handle_msg_all(self, message):
""" """