Abstract the send queue off MTProtoSender

This commit is contained in:
Lonami Exo 2018-09-29 13:29:44 +02:00
parent b02ebcb69b
commit 105bd52eee
3 changed files with 53 additions and 28 deletions

View File

@ -1,4 +1,5 @@
"""Various helpers not related to the Telegram API itself""" """Various helpers not related to the Telegram API itself"""
import asyncio
import collections import collections
import os import os
import struct import struct
@ -87,4 +88,42 @@ class TotalList(list):
return '[{}, total={}]'.format( return '[{}, total={}]'.format(
', '.join(repr(x) for x in self), self.total) ', '.join(repr(x) for x in self), self.total)
class _ReadyQueue:
"""
A queue list that supports an arbitrary cancellation token for `get`.
"""
def __init__(self, loop):
self._list = []
self._loop = loop
self._ready = asyncio.Event(loop=loop)
def append(self, item):
self._list.append(item)
self._ready.set()
def extend(self, items):
self._list.extend(items)
self._ready.set()
async def get(self, cancellation):
"""
Returns a list of all the items added to the queue until now and
clears the list from the queue itself. Returns ``None`` if cancelled.
"""
ready = asyncio.ensure_future(self._ready.wait(), loop=self._loop)
done, pending = await asyncio.wait(
[ready, cancellation],
return_when=asyncio.FIRST_COMPLETED,
loop=self._loop
)
if cancellation in done:
ready.cancel()
return None
result = self._list
self._list = []
self._ready.clear()
return result
# endregion # endregion

View File

@ -21,6 +21,7 @@ class Connection(abc.ABC):
self._writer = None self._writer = None
self._disconnected = asyncio.Event(loop=loop) self._disconnected = asyncio.Event(loop=loop)
self._disconnected.set() self._disconnected.set()
self._disconnected_future = None
self._send_task = None self._send_task = None
self._recv_task = None self._recv_task = None
self._send_queue = asyncio.Queue(1) self._send_queue = asyncio.Queue(1)
@ -34,6 +35,7 @@ class Connection(abc.ABC):
self._ip, self._port, loop=self._loop) self._ip, self._port, loop=self._loop)
self._disconnected.clear() self._disconnected.clear()
self._disconnected_future = None
self._send_task = self._loop.create_task(self._send_loop()) self._send_task = self._loop.create_task(self._send_loop())
self._recv_task = self._loop.create_task(self._recv_loop()) self._recv_task = self._loop.create_task(self._recv_loop())
@ -46,6 +48,13 @@ class Connection(abc.ABC):
self._recv_task.cancel() self._recv_task.cancel()
self._writer.close() self._writer.close()
@property
def disconnected(self):
if not self._disconnected_future:
self._disconnected_future = asyncio.ensure_future(
self._disconnected.wait(), loop=self._loop)
return self._disconnected_future
def clone(self): def clone(self):
""" """
Creates a clone of the connection. Creates a clone of the connection.

View File

@ -9,6 +9,7 @@ from ..errors import (
BadMessageError, TypeNotFoundError, rpc_message_to_error BadMessageError, TypeNotFoundError, rpc_message_to_error
) )
from ..extensions import BinaryReader from ..extensions import BinaryReader
from ..helpers import _ReadyQueue
from ..tl.core import RpcResult, MessageContainer, GzipPacked from ..tl.core import RpcResult, MessageContainer, GzipPacked
from ..tl.functions.auth import LogOutRequest from ..tl.functions.auth import LogOutRequest
from ..tl.types import ( from ..tl.types import (
@ -64,10 +65,7 @@ class MTProtoSender:
# Outgoing messages are put in a queue and sent in a batch. # Outgoing messages are put in a queue and sent in a batch.
# Note that here we're also storing their ``_RequestState``. # Note that here we're also storing their ``_RequestState``.
# Note that it may also store lists (implying order must be kept). # Note that it may also store lists (implying order must be kept).
# self._send_queue = _ReadyQueue(self._loop)
# TODO Abstract this queue away?
self._send_queue = []
self._send_ready = asyncio.Event(loop=self._loop)
# Sent states are remembered until a response is received. # Sent states are remembered until a response is received.
self._pending_state = {} self._pending_state = {}
@ -192,7 +190,6 @@ class MTProtoSender:
if not utils.is_list_like(request): if not utils.is_list_like(request):
state = RequestState(request, self._loop) state = RequestState(request, self._loop)
self._send_queue.append(state) self._send_queue.append(state)
self._send_ready.set()
return state.future return state.future
else: else:
states = [] states = []
@ -206,7 +203,6 @@ class MTProtoSender:
else: else:
self._send_queue.extend(states) self._send_queue.extend(states)
self._send_ready.set()
return futures return futures
@property @property
@ -333,29 +329,15 @@ class MTProtoSender:
if self._pending_ack: if self._pending_ack:
ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop) ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop)
self._send_queue.append(ack) self._send_queue.append(ack)
self._send_ready.set()
self._last_acks.append(ack) self._last_acks.append(ack)
self._pending_ack.clear() self._pending_ack.clear()
queue = asyncio.ensure_future( state_list = await self._send_queue.get(
self._send_ready.wait(), loop=self._loop) self._connection._connection.disconnected)
disconnected = asyncio.ensure_future( if state_list is None:
self._connection._connection._disconnected.wait())
# Basically using the disconnected as a cancellation token
done, pending = await asyncio.wait(
[queue, disconnected],
return_when=asyncio.FIRST_COMPLETED,
loop=self._loop
)
if disconnected in done:
break break
state_list = self._send_queue
self._send_queue = []
self._send_ready.clear()
# TODO Debug logs to notify which messages are being sent # TODO Debug logs to notify which messages are being sent
# TODO Try sending them while no future was cancelled? # TODO Try sending them while no future was cancelled?
# TODO Handle timeout, cancelled, arbitrary errors # TODO Handle timeout, cancelled, arbitrary errors
@ -425,7 +407,6 @@ class MTProtoSender:
error = rpc_message_to_error(rpc_result.error) error = rpc_message_to_error(rpc_result.error)
self._send_queue.append( self._send_queue.append(
RequestState(MsgsAck([state.msg_id]), loop=self._loop)) RequestState(MsgsAck([state.msg_id]), loop=self._loop))
self._send_ready.set()
if not state.future.cancelled(): if not state.future.cancelled():
state.future.set_exception(error) state.future.set_exception(error)
@ -494,12 +475,10 @@ class MTProtoSender:
try: try:
self._send_queue.append( self._send_queue.append(
self._pending_state.pop(bad_salt.bad_msg_id)) self._pending_state.pop(bad_salt.bad_msg_id))
self._send_ready.set()
except KeyError: except KeyError:
for ack in self._pending_ack: for ack in self._pending_ack:
if ack.msg_id == bad_salt.bad_msg_id: if ack.msg_id == bad_salt.bad_msg_id:
self._send_queue.append(ack) self._send_queue.append(ack)
self._send_ready.set()
return return
__log__.info('Message %d not resent due to bad salt', __log__.info('Message %d not resent due to bad salt',
@ -539,7 +518,6 @@ class MTProtoSender:
# Messages are to be re-sent once we've corrected the issue # Messages are to be re-sent once we've corrected the issue
if state: if state:
self._send_queue.append(state) self._send_queue.append(state)
self._send_ready.set()
else: else:
# TODO Generic method that may return from the acks too # TODO Generic method that may return from the acks too
# May be MsgsAck, those are not saved in pending messages # May be MsgsAck, those are not saved in pending messages
@ -627,7 +605,6 @@ class MTProtoSender:
self._send_queue.append(RequestState(MsgsStateInfo( self._send_queue.append(RequestState(MsgsStateInfo(
req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)), req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)),
loop=self._loop)) loop=self._loop))
self._send_ready.set()
async def _handle_msg_all(self, message): async def _handle_msg_all(self, message):
""" """