mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-25 19:03:46 +03:00
Abstract the send queue off MTProtoSender
This commit is contained in:
parent
b02ebcb69b
commit
105bd52eee
|
@ -1,4 +1,5 @@
|
||||||
"""Various helpers not related to the Telegram API itself"""
|
"""Various helpers not related to the Telegram API itself"""
|
||||||
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
|
@ -87,4 +88,42 @@ class TotalList(list):
|
||||||
return '[{}, total={}]'.format(
|
return '[{}, total={}]'.format(
|
||||||
', '.join(repr(x) for x in self), self.total)
|
', '.join(repr(x) for x in self), self.total)
|
||||||
|
|
||||||
|
|
||||||
|
class _ReadyQueue:
|
||||||
|
"""
|
||||||
|
A queue list that supports an arbitrary cancellation token for `get`.
|
||||||
|
"""
|
||||||
|
def __init__(self, loop):
|
||||||
|
self._list = []
|
||||||
|
self._loop = loop
|
||||||
|
self._ready = asyncio.Event(loop=loop)
|
||||||
|
|
||||||
|
def append(self, item):
|
||||||
|
self._list.append(item)
|
||||||
|
self._ready.set()
|
||||||
|
|
||||||
|
def extend(self, items):
|
||||||
|
self._list.extend(items)
|
||||||
|
self._ready.set()
|
||||||
|
|
||||||
|
async def get(self, cancellation):
|
||||||
|
"""
|
||||||
|
Returns a list of all the items added to the queue until now and
|
||||||
|
clears the list from the queue itself. Returns ``None`` if cancelled.
|
||||||
|
"""
|
||||||
|
ready = asyncio.ensure_future(self._ready.wait(), loop=self._loop)
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
[ready, cancellation],
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
loop=self._loop
|
||||||
|
)
|
||||||
|
if cancellation in done:
|
||||||
|
ready.cancel()
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = self._list
|
||||||
|
self._list = []
|
||||||
|
self._ready.clear()
|
||||||
|
return result
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
|
@ -21,6 +21,7 @@ class Connection(abc.ABC):
|
||||||
self._writer = None
|
self._writer = None
|
||||||
self._disconnected = asyncio.Event(loop=loop)
|
self._disconnected = asyncio.Event(loop=loop)
|
||||||
self._disconnected.set()
|
self._disconnected.set()
|
||||||
|
self._disconnected_future = None
|
||||||
self._send_task = None
|
self._send_task = None
|
||||||
self._recv_task = None
|
self._recv_task = None
|
||||||
self._send_queue = asyncio.Queue(1)
|
self._send_queue = asyncio.Queue(1)
|
||||||
|
@ -34,6 +35,7 @@ class Connection(abc.ABC):
|
||||||
self._ip, self._port, loop=self._loop)
|
self._ip, self._port, loop=self._loop)
|
||||||
|
|
||||||
self._disconnected.clear()
|
self._disconnected.clear()
|
||||||
|
self._disconnected_future = None
|
||||||
self._send_task = self._loop.create_task(self._send_loop())
|
self._send_task = self._loop.create_task(self._send_loop())
|
||||||
self._recv_task = self._loop.create_task(self._recv_loop())
|
self._recv_task = self._loop.create_task(self._recv_loop())
|
||||||
|
|
||||||
|
@ -46,6 +48,13 @@ class Connection(abc.ABC):
|
||||||
self._recv_task.cancel()
|
self._recv_task.cancel()
|
||||||
self._writer.close()
|
self._writer.close()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def disconnected(self):
|
||||||
|
if not self._disconnected_future:
|
||||||
|
self._disconnected_future = asyncio.ensure_future(
|
||||||
|
self._disconnected.wait(), loop=self._loop)
|
||||||
|
return self._disconnected_future
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
"""
|
"""
|
||||||
Creates a clone of the connection.
|
Creates a clone of the connection.
|
||||||
|
|
|
@ -9,6 +9,7 @@ from ..errors import (
|
||||||
BadMessageError, TypeNotFoundError, rpc_message_to_error
|
BadMessageError, TypeNotFoundError, rpc_message_to_error
|
||||||
)
|
)
|
||||||
from ..extensions import BinaryReader
|
from ..extensions import BinaryReader
|
||||||
|
from ..helpers import _ReadyQueue
|
||||||
from ..tl.core import RpcResult, MessageContainer, GzipPacked
|
from ..tl.core import RpcResult, MessageContainer, GzipPacked
|
||||||
from ..tl.functions.auth import LogOutRequest
|
from ..tl.functions.auth import LogOutRequest
|
||||||
from ..tl.types import (
|
from ..tl.types import (
|
||||||
|
@ -64,10 +65,7 @@ class MTProtoSender:
|
||||||
# Outgoing messages are put in a queue and sent in a batch.
|
# Outgoing messages are put in a queue and sent in a batch.
|
||||||
# Note that here we're also storing their ``_RequestState``.
|
# Note that here we're also storing their ``_RequestState``.
|
||||||
# Note that it may also store lists (implying order must be kept).
|
# Note that it may also store lists (implying order must be kept).
|
||||||
#
|
self._send_queue = _ReadyQueue(self._loop)
|
||||||
# TODO Abstract this queue away?
|
|
||||||
self._send_queue = []
|
|
||||||
self._send_ready = asyncio.Event(loop=self._loop)
|
|
||||||
|
|
||||||
# Sent states are remembered until a response is received.
|
# Sent states are remembered until a response is received.
|
||||||
self._pending_state = {}
|
self._pending_state = {}
|
||||||
|
@ -192,7 +190,6 @@ class MTProtoSender:
|
||||||
if not utils.is_list_like(request):
|
if not utils.is_list_like(request):
|
||||||
state = RequestState(request, self._loop)
|
state = RequestState(request, self._loop)
|
||||||
self._send_queue.append(state)
|
self._send_queue.append(state)
|
||||||
self._send_ready.set()
|
|
||||||
return state.future
|
return state.future
|
||||||
else:
|
else:
|
||||||
states = []
|
states = []
|
||||||
|
@ -206,7 +203,6 @@ class MTProtoSender:
|
||||||
else:
|
else:
|
||||||
self._send_queue.extend(states)
|
self._send_queue.extend(states)
|
||||||
|
|
||||||
self._send_ready.set()
|
|
||||||
return futures
|
return futures
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -333,29 +329,15 @@ class MTProtoSender:
|
||||||
if self._pending_ack:
|
if self._pending_ack:
|
||||||
ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop)
|
ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop)
|
||||||
self._send_queue.append(ack)
|
self._send_queue.append(ack)
|
||||||
self._send_ready.set()
|
|
||||||
self._last_acks.append(ack)
|
self._last_acks.append(ack)
|
||||||
self._pending_ack.clear()
|
self._pending_ack.clear()
|
||||||
|
|
||||||
queue = asyncio.ensure_future(
|
state_list = await self._send_queue.get(
|
||||||
self._send_ready.wait(), loop=self._loop)
|
self._connection._connection.disconnected)
|
||||||
|
|
||||||
disconnected = asyncio.ensure_future(
|
if state_list is None:
|
||||||
self._connection._connection._disconnected.wait())
|
|
||||||
|
|
||||||
# Basically using the disconnected as a cancellation token
|
|
||||||
done, pending = await asyncio.wait(
|
|
||||||
[queue, disconnected],
|
|
||||||
return_when=asyncio.FIRST_COMPLETED,
|
|
||||||
loop=self._loop
|
|
||||||
)
|
|
||||||
if disconnected in done:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
state_list = self._send_queue
|
|
||||||
self._send_queue = []
|
|
||||||
self._send_ready.clear()
|
|
||||||
|
|
||||||
# TODO Debug logs to notify which messages are being sent
|
# TODO Debug logs to notify which messages are being sent
|
||||||
# TODO Try sending them while no future was cancelled?
|
# TODO Try sending them while no future was cancelled?
|
||||||
# TODO Handle timeout, cancelled, arbitrary errors
|
# TODO Handle timeout, cancelled, arbitrary errors
|
||||||
|
@ -425,7 +407,6 @@ class MTProtoSender:
|
||||||
error = rpc_message_to_error(rpc_result.error)
|
error = rpc_message_to_error(rpc_result.error)
|
||||||
self._send_queue.append(
|
self._send_queue.append(
|
||||||
RequestState(MsgsAck([state.msg_id]), loop=self._loop))
|
RequestState(MsgsAck([state.msg_id]), loop=self._loop))
|
||||||
self._send_ready.set()
|
|
||||||
|
|
||||||
if not state.future.cancelled():
|
if not state.future.cancelled():
|
||||||
state.future.set_exception(error)
|
state.future.set_exception(error)
|
||||||
|
@ -494,12 +475,10 @@ class MTProtoSender:
|
||||||
try:
|
try:
|
||||||
self._send_queue.append(
|
self._send_queue.append(
|
||||||
self._pending_state.pop(bad_salt.bad_msg_id))
|
self._pending_state.pop(bad_salt.bad_msg_id))
|
||||||
self._send_ready.set()
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
for ack in self._pending_ack:
|
for ack in self._pending_ack:
|
||||||
if ack.msg_id == bad_salt.bad_msg_id:
|
if ack.msg_id == bad_salt.bad_msg_id:
|
||||||
self._send_queue.append(ack)
|
self._send_queue.append(ack)
|
||||||
self._send_ready.set()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
__log__.info('Message %d not resent due to bad salt',
|
__log__.info('Message %d not resent due to bad salt',
|
||||||
|
@ -539,7 +518,6 @@ class MTProtoSender:
|
||||||
# Messages are to be re-sent once we've corrected the issue
|
# Messages are to be re-sent once we've corrected the issue
|
||||||
if state:
|
if state:
|
||||||
self._send_queue.append(state)
|
self._send_queue.append(state)
|
||||||
self._send_ready.set()
|
|
||||||
else:
|
else:
|
||||||
# TODO Generic method that may return from the acks too
|
# TODO Generic method that may return from the acks too
|
||||||
# May be MsgsAck, those are not saved in pending messages
|
# May be MsgsAck, those are not saved in pending messages
|
||||||
|
@ -627,7 +605,6 @@ class MTProtoSender:
|
||||||
self._send_queue.append(RequestState(MsgsStateInfo(
|
self._send_queue.append(RequestState(MsgsStateInfo(
|
||||||
req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)),
|
req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)),
|
||||||
loop=self._loop))
|
loop=self._loop))
|
||||||
self._send_ready.set()
|
|
||||||
|
|
||||||
async def _handle_msg_all(self, message):
|
async def _handle_msg_all(self, message):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user