Make use of the MTProtoLayer in MTProtoSender

This commit is contained in:
Lonami Exo 2018-09-29 12:20:26 +02:00
parent 9402b4a26d
commit 470fb9f5df
3 changed files with 158 additions and 296 deletions

View File

@ -33,17 +33,16 @@ class MTProtoLayer:
""" """
self._connection.disconnect() self._connection.disconnect()
async def send(self, data_list): async def send(self, state_list):
""" """
A list of serialized RPC queries as bytes must be given to be sent. The list of `RequestState` that will be sent. They will
be updated with their new message and container IDs.
Nested lists imply an order is required for the messages in them. Nested lists imply an order is required for the messages in them.
Message containers will be used if there is more than one item. Message containers will be used if there is more than one item.
Returns ``(container_id, msg_ids)``.
""" """
data, container_id, msg_ids = self._pack_data_list(data_list) data = self._pack_state_list(state_list)
await self._connection.send(self._state.encrypt_message_data(data)) await self._connection.send(self._state.encrypt_message_data(data))
return container_id, msg_ids
async def recv(self): async def recv(self):
""" """
@ -52,12 +51,14 @@ class MTProtoLayer:
body = await self._connection.recv() body = await self._connection.recv()
return self._state.decrypt_message_data(body) return self._state.decrypt_message_data(body)
def _pack_data_list(self, data_list): def _pack_state_list(self, state_list):
""" """
A list of serialized RPC queries as bytes must be given to be packed. The list of `RequestState` that will be sent. They will
Nested lists imply an order is required for the messages in them. be updated with their new message and container IDs.
Returns ``(data, container_id, msg_ids)``. Packs all their serialized data into a message (possibly
nested inside another message and message container) and
returns the serialized message data.
""" """
# TODO write_data_as_message raises on invalid messages, handle it # TODO write_data_as_message raises on invalid messages, handle it
# TODO This method could be an iterator yielding messages while small # TODO This method could be an iterator yielding messages while small
@ -72,33 +73,39 @@ class MTProtoLayer:
# to store and serialize the data. However, to keep the context local # to store and serialize the data. However, to keep the context local
# and relevant to the only place where such feature is actually used, # and relevant to the only place where such feature is actually used,
# this is not done. # this is not done.
msg_ids = [] n = 0
buffer = io.BytesIO() buffer = io.BytesIO()
for data in data_list: for state in state_list:
if not isinstance(data, list): if not isinstance(state, list):
msg_ids.append(self._state.write_data_as_message(buffer, data)) n += 1
state.msg_id = \
self._state.write_data_as_message(buffer, state.data)
else: else:
last_id = None last_id = None
for d in data: for s in state:
last_id = self._state.write_data_as_message( n += 1
buffer, d, after_id=last_id) last_id = s.msg_id = self._state.write_data_as_message(
msg_ids.append(last_id) buffer, s.data, after_id=last_id)
if len(msg_ids) == 1: if n > 1:
container_id = None
else:
# Inlined code to pack several messages into a container # Inlined code to pack several messages into a container
# #
# TODO This part and encrypting data prepend a few bytes but # TODO This part and encrypting data prepend a few bytes but
# force a potentially large payload to be appended, which # force a potentially large payload to be appended, which
# may be expensive. Can we do better? # may be expensive. Can we do better?
data = struct.pack( data = struct.pack(
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(msg_ids) '<Ii', MessageContainer.CONSTRUCTOR_ID, n
) + buffer.getvalue() ) + buffer.getvalue()
buffer = io.BytesIO() buffer = io.BytesIO()
container_id = self._state.write_data_as_message(buffer, data) container_id = self._state.write_data_as_message(buffer, data)
for state in state_list:
if not isinstance(state, list):
state.container_id = container_id
else:
for s in state:
s.container_id = container_id
return buffer.getvalue(), container_id, msg_ids return buffer.getvalue()
def __str__(self): def __str__(self):
return str(self._connection) return str(self._connection)

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import collections
import logging import logging
from .mtprotolayer import MTProtoLayer from .mtprotolayer import MTProtoLayer
@ -15,16 +16,11 @@ from ..tl.types import (
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq, MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq,
MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload
) )
from .requeststate import RequestState
__log__ = logging.getLogger(__name__) __log__ = logging.getLogger(__name__)
# Place this object in the send queue when a reconnection is needed
# so there is an item to read and we can early quit the loop, since
# without this it will block until there's something in the queue.
_reconnect_sentinel = object()
class MTProtoSender: class MTProtoSender:
""" """
MTProto Mobile Protocol sender MTProto Mobile Protocol sender
@ -66,27 +62,21 @@ class MTProtoSender:
self._send_loop_handle = None self._send_loop_handle = None
self._recv_loop_handle = None self._recv_loop_handle = None
# Sending something shouldn't block # Outgoing messages are put in a queue and sent in a batch.
self._send_queue = _ContainerQueue() # Note that here we're also storing their ``_RequestState``.
# Note that it may also store lists (implying order must be kept).
self._send_queue = asyncio.Queue()
# Telegram responds to messages out of order. Keep # Sent states are remembered until a response is received.
# {id: Message} to set their Future result upon arrival. self._pending_state = {}
self._pending_messages = {}
# Containers are accepted or rejected as a whole when any of # Responses must be acknowledged, and we can also batch these.
# its inner requests are acknowledged. For this purpose we
# all the sent containers here.
self._pending_containers = []
# We need to acknowledge every response from Telegram
self._pending_ack = set() self._pending_ack = set()
# Similar to pending_messages but only for the last ack. # Similar to pending_messages but only for the last acknowledges.
# Ack can't be put in the messages because Telegram never # These can't go in pending_messages because no acknowledge for them
# responds to acknowledges (they're just that, acknowledges), # is received, but we may still need to resend their state on bad salts.
# so it would grow to infinite otherwise, but on bad salt it's self._last_acks = collections.deque(maxlen=10)
# necessary to resend them just like everything else.
self._last_ack = None
# Jump table from response ID to method that handles it # Jump table from response ID to method that handles it
self._handlers = { self._handlers = {
@ -136,6 +126,7 @@ class MTProtoSender:
await self._disconnect() await self._disconnect()
# TODO Move this out of the "Public API" section
async def _disconnect(self, error=None): async def _disconnect(self, error=None):
__log__.info('Disconnecting from %s...', self._connection) __log__.info('Disconnecting from %s...', self._connection)
self._user_connected = False self._user_connected = False
@ -144,14 +135,14 @@ class MTProtoSender:
self._connection.disconnect() self._connection.disconnect()
finally: finally:
__log__.debug('Cancelling {} pending message(s)...' __log__.debug('Cancelling {} pending message(s)...'
.format(len(self._pending_messages))) .format(len(self._pending_state)))
for message in self._pending_messages.values(): for state in self._pending_state.values():
if error and not message.future.done(): if error and not state.future.done():
message.future.set_exception(error) state.future.set_exception(error)
else: else:
message.future.cancel() state.future.cancel()
self._pending_messages.clear() self._pending_state.clear()
self._pending_ack.clear() self._pending_ack.clear()
self._last_ack = None self._last_ack = None
@ -172,11 +163,9 @@ class MTProtoSender:
def send(self, request, ordered=False): def send(self, request, ordered=False):
""" """
This method enqueues the given request to be sent. This method enqueues the given request to be sent. Its send
state will be saved until a response arrives, and a ``Future``
The request will be wrapped inside a `TLMessage` until its that will be resolved when the response arrives will be returned:
response arrives, and the `Future` response of the `TLMessage`
is immediately returned so that one can further ``await`` it:
.. code-block:: python .. code-block:: python
@ -198,23 +187,23 @@ class MTProtoSender:
if not self._user_connected: if not self._user_connected:
raise ConnectionError('Cannot send requests while disconnected') raise ConnectionError('Cannot send requests while disconnected')
if utils.is_list_like(request): if not utils.is_list_like(request):
result = [] state = RequestState(request, self._loop)
after = None self._send_queue.put_nowait(state)
for r in request: return state.future
message = self.state.create_message(
r, loop=self._loop, after=after)
self._pending_messages[message.msg_id] = message
self._send_queue.put_nowait(message)
result.append(message.future)
after = ordered and message
return result
else: else:
message = self.state.create_message(request, loop=self._loop) states = []
self._pending_messages[message.msg_id] = message futures = []
self._send_queue.put_nowait(message) for req in request:
return message.future state = RequestState(req, self._loop)
states.append(state)
futures.append(state.future)
if ordered:
self._send_queue.put(states)
else:
for state in states:
self._send_queue.put(state)
return futures
@property @property
def disconnected(self): def disconnected(self):
@ -290,7 +279,6 @@ class MTProtoSender:
Cleanly disconnects and then reconnects. Cleanly disconnects and then reconnects.
""" """
self._reconnecting = True self._reconnecting = True
self._send_queue.put_nowait(_reconnect_sentinel)
__log__.debug('Awaiting for the send loop before reconnecting...') __log__.debug('Awaiting for the send loop before reconnecting...')
await self._send_loop_handle await self._send_loop_handle
@ -307,8 +295,11 @@ class MTProtoSender:
for retry in range(1, retries + 1): for retry in range(1, retries + 1):
try: try:
await self._connect() await self._connect()
# TODO Keep this?
"""
for m in self._pending_messages.values(): for m in self._pending_messages.values():
self._send_queue.put_nowait(m) self._send_queue.put_nowait(m)
"""
if self._auto_reconnect_callback: if self._auto_reconnect_callback:
self._loop.create_task(self._auto_reconnect_callback()) self._loop.create_task(self._auto_reconnect_callback())
@ -325,23 +316,6 @@ class MTProtoSender:
if self._user_connected: if self._user_connected:
self._loop.create_task(self._reconnect()) self._loop.create_task(self._reconnect())
def _clean_containers(self, msg_ids):
"""
Helper method to clean containers from the pending messages
once a wrapped msg_id of them has been acknowledged.
This is the only way we can resend TLMessage(MessageContainer)
on bad notifications and also mark them as received once any
of their inner TLMessage is acknowledged.
"""
for i in reversed(range(len(self._pending_containers))):
message = self._pending_containers[i]
for msg in message.obj.messages:
if msg.msg_id in msg_ids:
del self._pending_containers[i]
del self._pending_messages[message.msg_id]
break
# Loops # Loops
async def _send_loop(self): async def _send_loop(self):
@ -353,67 +327,31 @@ class MTProtoSender:
""" """
while self._user_connected and not self._reconnecting: while self._user_connected and not self._reconnecting:
if self._pending_ack: if self._pending_ack:
self._last_ack = self.state.create_message( ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop)
MsgsAck(list(self._pending_ack)), loop=self._loop self._send_queue.put_nowait(ack)
) self._last_acks.append(ack)
self._send_queue.put_nowait(self._last_ack)
self._pending_ack.clear() self._pending_ack.clear()
messages = await self._send_queue.get() state_list = []
if messages == _reconnect_sentinel:
if self._reconnecting: # TODO wait for the list to have one or for a disconnect to happen
break # and pop while that's the case
state = await self._send_queue.get()
state_list.append(state)
while not self._send_queue.empty():
state_list.append(self._send_queue.get_nowait())
# TODO Debug logs to notify which messages are being sent
# TODO Try sending them while no future was cancelled?
# TODO Handle timeout, cancelled, arbitrary errors
await self._connection.send(state_list)
for state in state_list:
if not isinstance(state, list):
self._pending_state[state.msg_id] = state
else: else:
continue for s in state:
self._pending_state[s.msg_id] = s
if isinstance(messages, list):
message = self.state.create_message(
MessageContainer(messages), loop=self._loop)
self._pending_messages[message.msg_id] = message
self._pending_containers.append(message)
else:
message = messages
messages = [message]
__log__.debug(
'Packing %d outgoing message(s) %s...', len(messages),
', '.join(x.obj.__class__.__name__ for x in messages)
)
body = self.state.pack_message(message)
while not any(m.future.cancelled() for m in messages):
try:
__log__.debug('Sending {} bytes...'.format(len(body)))
await self._connection.send(body)
break
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
return
except Exception as e:
if isinstance(e, ConnectionError):
__log__.info('Connection reset while sending %s', e)
elif isinstance(e, OSError):
__log__.warning('OSError while sending %s', e)
else:
__log__.exception('Unhandled exception while receiving')
await asyncio.sleep(1, loop=self._loop)
self._start_reconnect()
break
else:
# Remove the cancelled messages from pending
__log__.info('Some futures were cancelled, aborted send')
self._clean_containers([m.msg_id for m in messages])
for m in messages:
if m.future.cancelled():
self._pending_messages.pop(m.msg_id, None)
else:
self._send_queue.put_nowait(m)
__log__.debug('Outgoing messages {} sent!'
.format(', '.join(str(m.msg_id) for m in messages)))
async def _recv_loop(self): async def _recv_loop(self):
""" """
@ -423,68 +361,11 @@ class MTProtoSender:
Besides `connect`, only this method ever receives data. Besides `connect`, only this method ever receives data.
""" """
while self._user_connected and not self._reconnecting: while self._user_connected and not self._reconnecting:
try: # TODO Handle timeout, cancelled, arbitrary, broken auth, buffer,
__log__.debug('Receiving items from the network...') # security and type not found.
body = await self._connection.recv() __log__.debug('Receiving items from the network...')
except asyncio.TimeoutError: message = await self._connection.recv()
continue await self._process_message(message)
except asyncio.CancelledError:
return
except Exception as e:
if isinstance(e, ConnectionError):
__log__.info('Connection reset while receiving %s', e)
elif isinstance(e, OSError):
__log__.warning('OSError while receiving %s', e)
else:
__log__.exception('Unhandled exception while receiving')
await asyncio.sleep(1, loop=self._loop)
self._start_reconnect()
break
# TODO Check salt, session_id and sequence_number
__log__.debug('Decoding packet of %d bytes...', len(body))
try:
message = self.state.unpack_message(body)
except (BrokenAuthKeyError, BufferError) as e:
# The authorization key may be broken if a message was
# sent malformed, or if the authkey truly is corrupted.
#
# There may be a buffer error if Telegram's response was too
# short and hence not understood. Reset the authorization key
# and try again in either case.
#
# TODO Is it possible to detect malformed messages vs
# an actually broken authkey?
__log__.warning('Broken authorization key?: {}'.format(e))
self.state.auth_key = None
self._start_reconnect()
break
except SecurityError as e:
# A step while decoding had the incorrect data. This message
# should not be considered safe and it should be ignored.
__log__.warning('Security error while unpacking a '
'received message: {}'.format(e))
continue
except TypeNotFoundError as e:
# The payload inside the message was not a known TLObject.
__log__.info('Server replied with an unknown type {:08x}: {!r}'
.format(e.invalid_constructor_id, e.remaining))
continue
except asyncio.CancelledError:
return
except Exception as e:
__log__.exception('Unhandled exception while unpacking %s',e)
await asyncio.sleep(1, loop=self._loop)
else:
try:
await self._process_message(message)
except asyncio.CancelledError:
return
except Exception as e:
__log__.exception('Unhandled exception while '
'processing %s', message)
await asyncio.sleep(1, loop=self._loop)
# Response Handlers # Response Handlers
@ -508,11 +389,11 @@ class MTProtoSender:
This is where the future results for sent requests are set. This is where the future results for sent requests are set.
""" """
rpc_result = message.obj rpc_result = message.obj
message = self._pending_messages.pop(rpc_result.req_msg_id, None) state = self._pending_state.pop(rpc_result.req_msg_id, None)
__log__.debug('Handling RPC result for message %d', __log__.debug('Handling RPC result for message %d',
rpc_result.req_msg_id) rpc_result.req_msg_id)
if not message: if not state:
# TODO We should not get responses to things we never sent # TODO We should not get responses to things we never sent
# However receiving a File() with empty bytes is "common". # However receiving a File() with empty bytes is "common".
# See #658, #759 and #958. They seem to happen in a container # See #658, #759 and #958. They seem to happen in a container
@ -528,22 +409,21 @@ class MTProtoSender:
if rpc_result.error: if rpc_result.error:
error = rpc_message_to_error(rpc_result.error) error = rpc_message_to_error(rpc_result.error)
self._send_queue.put_nowait(self.state.create_message( self._send_queue.put_nowait(
MsgsAck([message.msg_id]), loop=self._loop RequestState(MsgsAck([state.msg_id]), loop=self._loop))
))
if not message.future.cancelled(): if not state.future.cancelled():
message.future.set_exception(error) state.future.set_exception(error)
else: else:
# TODO Would be nice to avoid accessing a per-obj read_result # TODO Would be nice to avoid accessing a per-obj read_result
# Instead have a variable that indicated how the result should # Instead have a variable that indicated how the result should
# be read (an enum) and dispatch to read the result, mostly # be read (an enum) and dispatch to read the result, mostly
# always it's just a normal TLObject. # always it's just a normal TLObject.
with BinaryReader(rpc_result.body) as reader: with BinaryReader(rpc_result.body) as reader:
result = message.obj.read_result(reader) result = state.request.read_result(reader)
if not message.future.cancelled(): if not state.future.cancelled():
message.future.set_result(result) state.future.set_result(result)
async def _handle_container(self, message): async def _handle_container(self, message):
""" """
@ -581,9 +461,9 @@ class MTProtoSender:
""" """
pong = message.obj pong = message.obj
__log__.debug('Handling pong for message %d', pong.msg_id) __log__.debug('Handling pong for message %d', pong.msg_id)
message = self._pending_messages.pop(pong.msg_id, None) state = self._pending_state.pop(pong.msg_id, None)
if message: if state:
message.future.set_result(pong) state.future.set_result(pong)
async def _handle_bad_server_salt(self, message): async def _handle_bad_server_salt(self, message):
""" """
@ -595,16 +475,16 @@ class MTProtoSender:
""" """
bad_salt = message.obj bad_salt = message.obj
__log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id) __log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id)
self.state.salt = bad_salt.new_server_salt self._connection._state.salt = bad_salt.new_server_salt
if self._last_ack and bad_salt.bad_msg_id == self._last_ack.msg_id:
self._send_queue.put_nowait(self._last_ack)
return
try: try:
self._send_queue.put_nowait( self._send_queue.put_nowait(
self._pending_messages[bad_salt.bad_msg_id]) self._pending_state.pop(bad_salt.bad_msg_id))
except KeyError: except KeyError:
# May be MsgsAck, those are not saved in pending messages for ack in self._pending_ack:
if ack.msg_id == bad_salt.bad_msg_id:
self._send_queue.put_nowait(ack)
return
__log__.info('Message %d not resent due to bad salt', __log__.info('Message %d not resent due to bad salt',
bad_salt.bad_msg_id) bad_salt.bad_msg_id)
@ -617,41 +497,33 @@ class MTProtoSender:
error_code:int = BadMsgNotification; error_code:int = BadMsgNotification;
""" """
bad_msg = message.obj bad_msg = message.obj
msg = self._pending_messages.get(bad_msg.bad_msg_id) # TODO Pending state may need to pop by container ID
state = self._pending_state.pop(bad_msg.bad_msg_id, None)
__log__.debug('Handling bad msg %s', bad_msg) __log__.debug('Handling bad msg %s', bad_msg)
if bad_msg.error_code in (16, 17): if bad_msg.error_code in (16, 17):
# Sent msg_id too low or too high (respectively). # Sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset. # Use the current msg_id to determine the right time offset.
to = self.state.update_time_offset(correct_msg_id=message.msg_id) to = self._connection._state.update_time_offset(
correct_msg_id=message.msg_id)
__log__.info('System clock is wrong, set time offset to %ds', to) __log__.info('System clock is wrong, set time offset to %ds', to)
# Correct the msg_id *of the message to resend*, not all.
#
# If we correct them all, new "bad message" would not find
# the old invalid IDs, causing all awaits to never finish.
if msg:
del self._pending_messages[msg.msg_id]
self.state.update_message_id(msg)
self._pending_messages[msg.msg_id] = msg
elif bad_msg.error_code == 32: elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount # msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID # TODO A better fix would be to start with a new fresh session ID
self.state._sequence += 64 self._connection._state._sequence += 64
elif bad_msg.error_code == 33: elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case # msg_seqno too high never seems to happen but just in case
self.state._sequence -= 16 self._connection._state._sequence -= 16
else: else:
if msg: if state:
del self._pending_messages[msg.msg_id] state.future.set_exception(BadMessageError(bad_msg.error_code))
msg.future.set_exception(BadMessageError(bad_msg.error_code))
return return
# Messages are to be re-sent once we've corrected the issue # Messages are to be re-sent once we've corrected the issue
if msg: if state:
self._send_queue.put_nowait(msg) self._send_queue.put_nowait(state)
else: else:
# TODO Generic method that may return from the acks too
# May be MsgsAck, those are not saved in pending messages # May be MsgsAck, those are not saved in pending messages
__log__.info('Message %d not resent due to bad msg', __log__.info('Message %d not resent due to bad msg',
bad_msg.bad_msg_id) bad_msg.bad_msg_id)
@ -689,7 +561,7 @@ class MTProtoSender:
""" """
# TODO https://goo.gl/LMyN7A # TODO https://goo.gl/LMyN7A
__log__.debug('Handling new session created') __log__.debug('Handling new session created')
self.state.salt = message.obj.server_salt self._connection._state.salt = message.obj.server_salt
async def _handle_ack(self, message): async def _handle_ack(self, message):
""" """
@ -708,14 +580,11 @@ class MTProtoSender:
""" """
ack = message.obj ack = message.obj
__log__.debug('Handling acknowledge for %s', str(ack.msg_ids)) __log__.debug('Handling acknowledge for %s', str(ack.msg_ids))
if self._pending_containers:
self._clean_containers(ack.msg_ids)
for msg_id in ack.msg_ids: for msg_id in ack.msg_ids:
msg = self._pending_messages.get(msg_id, None) state = self._pending_state.get(msg_id)
if msg and isinstance(msg.obj, LogOutRequest): if state and isinstance(state.request, LogOutRequest):
del self._pending_messages[msg_id] del self._pending_state[msg_id]
msg.future.set_result(True) state.future.set_result(True)
async def _handle_future_salts(self, message): async def _handle_future_salts(self, message):
""" """
@ -728,52 +597,20 @@ class MTProtoSender:
# TODO save these salts and automatically adjust to the # TODO save these salts and automatically adjust to the
# correct one whenever the salt in use expires. # correct one whenever the salt in use expires.
__log__.debug('Handling future salts for message %d', message.msg_id) __log__.debug('Handling future salts for message %d', message.msg_id)
msg = self._pending_messages.pop(message.msg_id, None) state = self._pending_state.pop(message.msg_id, None)
if msg: if state:
msg.future.set_result(message.obj) state.future.set_result(message.obj)
async def _handle_state_forgotten(self, message): async def _handle_state_forgotten(self, message):
""" """
Handles both :tl:`MsgsStateReq` and :tl:`MsgResendReq` by Handles both :tl:`MsgsStateReq` and :tl:`MsgResendReq` by
enqueuing a :tl:`MsgsStateInfo` to be sent at a later point. enqueuing a :tl:`MsgsStateInfo` to be sent at a later point.
""" """
self.send(MsgsStateInfo(req_msg_id=message.msg_id, self._send_queue.put_nowait(RequestState(MsgsStateInfo(
info=chr(1) * len(message.obj.msg_ids))) req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)),
loop=self._loop))
async def _handle_msg_all(self, message): async def _handle_msg_all(self, message):
""" """
Handles :tl:`MsgsAllInfo` by doing nothing (yet). Handles :tl:`MsgsAllInfo` by doing nothing (yet).
""" """
class _ContainerQueue(asyncio.Queue):
"""
An asyncio queue that's aware of `MessageContainer` instances.
The `get` method returns either a single `TLMessage` or a list
of them that should be turned into a new `MessageContainer`.
Instances of this class can be replaced with the simpler
``asyncio.Queue`` when needed for testing purposes, and
a list won't be returned in said case.
"""
async def get(self):
result = await super().get()
if self.empty() or result == _reconnect_sentinel or\
isinstance(result.obj, MessageContainer):
return result
size = result.size()
result = [result]
while not self.empty():
item = self.get_nowait()
if (item == _reconnect_sentinel or
isinstance(item.obj, MessageContainer)
or size + item.size() > MessageContainer.MAXIMUM_SIZE):
self.put_nowait(item)
break
else:
size += item.size()
result.append(item)
return result

View File

@ -0,0 +1,18 @@
import asyncio
class RequestState:
"""
This request state holds several information relevant to sent messages,
in particular the message ID assigned to the request, the container ID
it belongs to, the request itself, the request as bytes, and the future
result that will eventually be resolved.
"""
__slots__ = ('container_id', 'msg_id', 'request', 'data', 'future')
def __init__(self, request, loop):
self.container_id = None
self.msg_id = None
self.request = request
self.data = bytes(request)
self.future = asyncio.Future(loop=loop)