Make use of the MTProtoLayer in MTProtoSender

This commit is contained in:
Lonami Exo 2018-09-29 12:20:26 +02:00
parent 9402b4a26d
commit 470fb9f5df
3 changed files with 158 additions and 296 deletions

View File

@ -33,17 +33,16 @@ class MTProtoLayer:
"""
self._connection.disconnect()
async def send(self, data_list):
async def send(self, state_list):
"""
A list of serialized RPC queries as bytes must be given to be sent.
The list of `RequestState` that will be sent. They will
be updated with their new message and container IDs.
Nested lists imply an order is required for the messages in them.
Message containers will be used if there is more than one item.
Returns ``(container_id, msg_ids)``.
"""
data, container_id, msg_ids = self._pack_data_list(data_list)
data = self._pack_state_list(state_list)
await self._connection.send(self._state.encrypt_message_data(data))
return container_id, msg_ids
async def recv(self):
"""
@ -52,12 +51,14 @@ class MTProtoLayer:
body = await self._connection.recv()
return self._state.decrypt_message_data(body)
def _pack_data_list(self, data_list):
def _pack_state_list(self, state_list):
"""
A list of serialized RPC queries as bytes must be given to be packed.
Nested lists imply an order is required for the messages in them.
The list of `RequestState` that will be sent. They will
be updated with their new message and container IDs.
Returns ``(data, container_id, msg_ids)``.
Packs all their serialized data into a message (possibly
nested inside another message and message container) and
returns the serialized message data.
"""
# TODO write_data_as_message raises on invalid messages, handle it
# TODO This method could be an iterator yielding messages while small
@ -72,33 +73,39 @@ class MTProtoLayer:
# to store and serialize the data. However, to keep the context local
# and relevant to the only place where such feature is actually used,
# this is not done.
msg_ids = []
n = 0
buffer = io.BytesIO()
for data in data_list:
if not isinstance(data, list):
msg_ids.append(self._state.write_data_as_message(buffer, data))
for state in state_list:
if not isinstance(state, list):
n += 1
state.msg_id = \
self._state.write_data_as_message(buffer, state.data)
else:
last_id = None
for d in data:
last_id = self._state.write_data_as_message(
buffer, d, after_id=last_id)
msg_ids.append(last_id)
for s in state:
n += 1
last_id = s.msg_id = self._state.write_data_as_message(
buffer, s.data, after_id=last_id)
if len(msg_ids) == 1:
container_id = None
else:
if n > 1:
# Inlined code to pack several messages into a container
#
# TODO This part and encrypting data prepend a few bytes but
# force a potentially large payload to be appended, which
# may be expensive. Can we do better?
data = struct.pack(
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(msg_ids)
'<Ii', MessageContainer.CONSTRUCTOR_ID, n
) + buffer.getvalue()
buffer = io.BytesIO()
container_id = self._state.write_data_as_message(buffer, data)
for state in state_list:
if not isinstance(state, list):
state.container_id = container_id
else:
for s in state:
s.container_id = container_id
return buffer.getvalue(), container_id, msg_ids
return buffer.getvalue()
def __str__(self):
return str(self._connection)

View File

@ -1,4 +1,5 @@
import asyncio
import collections
import logging
from .mtprotolayer import MTProtoLayer
@ -15,16 +16,11 @@ from ..tl.types import (
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq,
MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload
)
from .requeststate import RequestState
__log__ = logging.getLogger(__name__)
# Place this object in the send queue when a reconnection is needed
# so there is an item to read and we can early quit the loop, since
# without this it will block until there's something in the queue.
_reconnect_sentinel = object()
class MTProtoSender:
"""
MTProto Mobile Protocol sender
@ -66,27 +62,21 @@ class MTProtoSender:
self._send_loop_handle = None
self._recv_loop_handle = None
# Sending something shouldn't block
self._send_queue = _ContainerQueue()
# Outgoing messages are put in a queue and sent in a batch.
# Note that here we're also storing their ``_RequestState``.
# Note that it may also store lists (implying order must be kept).
self._send_queue = asyncio.Queue()
# Telegram responds to messages out of order. Keep
# {id: Message} to set their Future result upon arrival.
self._pending_messages = {}
# Sent states are remembered until a response is received.
self._pending_state = {}
# Containers are accepted or rejected as a whole when any of
# its inner requests are acknowledged. For this purpose we
# all the sent containers here.
self._pending_containers = []
# We need to acknowledge every response from Telegram
# Responses must be acknowledged, and we can also batch these.
self._pending_ack = set()
# Similar to pending_messages but only for the last ack.
# Ack can't be put in the messages because Telegram never
# responds to acknowledges (they're just that, acknowledges),
# so it would grow to infinite otherwise, but on bad salt it's
# necessary to resend them just like everything else.
self._last_ack = None
# Similar to pending_messages but only for the last acknowledges.
# These can't go in pending_messages because no acknowledge for them
# is received, but we may still need to resend their state on bad salts.
self._last_acks = collections.deque(maxlen=10)
# Jump table from response ID to method that handles it
self._handlers = {
@ -136,6 +126,7 @@ class MTProtoSender:
await self._disconnect()
# TODO Move this out of the "Public API" section
async def _disconnect(self, error=None):
__log__.info('Disconnecting from %s...', self._connection)
self._user_connected = False
@ -144,14 +135,14 @@ class MTProtoSender:
self._connection.disconnect()
finally:
__log__.debug('Cancelling {} pending message(s)...'
.format(len(self._pending_messages)))
for message in self._pending_messages.values():
if error and not message.future.done():
message.future.set_exception(error)
.format(len(self._pending_state)))
for state in self._pending_state.values():
if error and not state.future.done():
state.future.set_exception(error)
else:
message.future.cancel()
state.future.cancel()
self._pending_messages.clear()
self._pending_state.clear()
self._pending_ack.clear()
self._last_ack = None
@ -172,11 +163,9 @@ class MTProtoSender:
def send(self, request, ordered=False):
"""
This method enqueues the given request to be sent.
The request will be wrapped inside a `TLMessage` until its
response arrives, and the `Future` response of the `TLMessage`
is immediately returned so that one can further ``await`` it:
This method enqueues the given request to be sent. Its send
state will be saved until a response arrives, and a ``Future``
that will be resolved when the response arrives will be returned:
.. code-block:: python
@ -198,23 +187,23 @@ class MTProtoSender:
if not self._user_connected:
raise ConnectionError('Cannot send requests while disconnected')
if utils.is_list_like(request):
result = []
after = None
for r in request:
message = self.state.create_message(
r, loop=self._loop, after=after)
self._pending_messages[message.msg_id] = message
self._send_queue.put_nowait(message)
result.append(message.future)
after = ordered and message
return result
if not utils.is_list_like(request):
state = RequestState(request, self._loop)
self._send_queue.put_nowait(state)
return state.future
else:
message = self.state.create_message(request, loop=self._loop)
self._pending_messages[message.msg_id] = message
self._send_queue.put_nowait(message)
return message.future
states = []
futures = []
for req in request:
state = RequestState(req, self._loop)
states.append(state)
futures.append(state.future)
if ordered:
self._send_queue.put(states)
else:
for state in states:
self._send_queue.put(state)
return futures
@property
def disconnected(self):
@ -290,7 +279,6 @@ class MTProtoSender:
Cleanly disconnects and then reconnects.
"""
self._reconnecting = True
self._send_queue.put_nowait(_reconnect_sentinel)
__log__.debug('Awaiting for the send loop before reconnecting...')
await self._send_loop_handle
@ -307,8 +295,11 @@ class MTProtoSender:
for retry in range(1, retries + 1):
try:
await self._connect()
# TODO Keep this?
"""
for m in self._pending_messages.values():
self._send_queue.put_nowait(m)
"""
if self._auto_reconnect_callback:
self._loop.create_task(self._auto_reconnect_callback())
@ -325,23 +316,6 @@ class MTProtoSender:
if self._user_connected:
self._loop.create_task(self._reconnect())
def _clean_containers(self, msg_ids):
"""
Helper method to clean containers from the pending messages
once a wrapped msg_id of them has been acknowledged.
This is the only way we can resend TLMessage(MessageContainer)
on bad notifications and also mark them as received once any
of their inner TLMessage is acknowledged.
"""
for i in reversed(range(len(self._pending_containers))):
message = self._pending_containers[i]
for msg in message.obj.messages:
if msg.msg_id in msg_ids:
del self._pending_containers[i]
del self._pending_messages[message.msg_id]
break
# Loops
async def _send_loop(self):
@ -353,67 +327,31 @@ class MTProtoSender:
"""
while self._user_connected and not self._reconnecting:
if self._pending_ack:
self._last_ack = self.state.create_message(
MsgsAck(list(self._pending_ack)), loop=self._loop
)
self._send_queue.put_nowait(self._last_ack)
ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop)
self._send_queue.put_nowait(ack)
self._last_acks.append(ack)
self._pending_ack.clear()
messages = await self._send_queue.get()
if messages == _reconnect_sentinel:
if self._reconnecting:
break
state_list = []
# TODO wait for the list to have one or for a disconnect to happen
# and pop while that's the case
state = await self._send_queue.get()
state_list.append(state)
while not self._send_queue.empty():
state_list.append(self._send_queue.get_nowait())
# TODO Debug logs to notify which messages are being sent
# TODO Try sending them while no future was cancelled?
# TODO Handle timeout, cancelled, arbitrary errors
await self._connection.send(state_list)
for state in state_list:
if not isinstance(state, list):
self._pending_state[state.msg_id] = state
else:
continue
if isinstance(messages, list):
message = self.state.create_message(
MessageContainer(messages), loop=self._loop)
self._pending_messages[message.msg_id] = message
self._pending_containers.append(message)
else:
message = messages
messages = [message]
__log__.debug(
'Packing %d outgoing message(s) %s...', len(messages),
', '.join(x.obj.__class__.__name__ for x in messages)
)
body = self.state.pack_message(message)
while not any(m.future.cancelled() for m in messages):
try:
__log__.debug('Sending {} bytes...'.format(len(body)))
await self._connection.send(body)
break
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
return
except Exception as e:
if isinstance(e, ConnectionError):
__log__.info('Connection reset while sending %s', e)
elif isinstance(e, OSError):
__log__.warning('OSError while sending %s', e)
else:
__log__.exception('Unhandled exception while receiving')
await asyncio.sleep(1, loop=self._loop)
self._start_reconnect()
break
else:
# Remove the cancelled messages from pending
__log__.info('Some futures were cancelled, aborted send')
self._clean_containers([m.msg_id for m in messages])
for m in messages:
if m.future.cancelled():
self._pending_messages.pop(m.msg_id, None)
else:
self._send_queue.put_nowait(m)
__log__.debug('Outgoing messages {} sent!'
.format(', '.join(str(m.msg_id) for m in messages)))
for s in state:
self._pending_state[s.msg_id] = s
async def _recv_loop(self):
"""
@ -423,68 +361,11 @@ class MTProtoSender:
Besides `connect`, only this method ever receives data.
"""
while self._user_connected and not self._reconnecting:
try:
__log__.debug('Receiving items from the network...')
body = await self._connection.recv()
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
return
except Exception as e:
if isinstance(e, ConnectionError):
__log__.info('Connection reset while receiving %s', e)
elif isinstance(e, OSError):
__log__.warning('OSError while receiving %s', e)
else:
__log__.exception('Unhandled exception while receiving')
await asyncio.sleep(1, loop=self._loop)
self._start_reconnect()
break
# TODO Check salt, session_id and sequence_number
__log__.debug('Decoding packet of %d bytes...', len(body))
try:
message = self.state.unpack_message(body)
except (BrokenAuthKeyError, BufferError) as e:
# The authorization key may be broken if a message was
# sent malformed, or if the authkey truly is corrupted.
#
# There may be a buffer error if Telegram's response was too
# short and hence not understood. Reset the authorization key
# and try again in either case.
#
# TODO Is it possible to detect malformed messages vs
# an actually broken authkey?
__log__.warning('Broken authorization key?: {}'.format(e))
self.state.auth_key = None
self._start_reconnect()
break
except SecurityError as e:
# A step while decoding had the incorrect data. This message
# should not be considered safe and it should be ignored.
__log__.warning('Security error while unpacking a '
'received message: {}'.format(e))
continue
except TypeNotFoundError as e:
# The payload inside the message was not a known TLObject.
__log__.info('Server replied with an unknown type {:08x}: {!r}'
.format(e.invalid_constructor_id, e.remaining))
continue
except asyncio.CancelledError:
return
except Exception as e:
__log__.exception('Unhandled exception while unpacking %s',e)
await asyncio.sleep(1, loop=self._loop)
else:
try:
await self._process_message(message)
except asyncio.CancelledError:
return
except Exception as e:
__log__.exception('Unhandled exception while '
'processing %s', message)
await asyncio.sleep(1, loop=self._loop)
# TODO Handle timeout, cancelled, arbitrary, broken auth, buffer,
# security and type not found.
__log__.debug('Receiving items from the network...')
message = await self._connection.recv()
await self._process_message(message)
# Response Handlers
@ -508,11 +389,11 @@ class MTProtoSender:
This is where the future results for sent requests are set.
"""
rpc_result = message.obj
message = self._pending_messages.pop(rpc_result.req_msg_id, None)
state = self._pending_state.pop(rpc_result.req_msg_id, None)
__log__.debug('Handling RPC result for message %d',
rpc_result.req_msg_id)
if not message:
if not state:
# TODO We should not get responses to things we never sent
# However receiving a File() with empty bytes is "common".
# See #658, #759 and #958. They seem to happen in a container
@ -528,22 +409,21 @@ class MTProtoSender:
if rpc_result.error:
error = rpc_message_to_error(rpc_result.error)
self._send_queue.put_nowait(self.state.create_message(
MsgsAck([message.msg_id]), loop=self._loop
))
self._send_queue.put_nowait(
RequestState(MsgsAck([state.msg_id]), loop=self._loop))
if not message.future.cancelled():
message.future.set_exception(error)
if not state.future.cancelled():
state.future.set_exception(error)
else:
# TODO Would be nice to avoid accessing a per-obj read_result
# Instead have a variable that indicated how the result should
# be read (an enum) and dispatch to read the result, mostly
# always it's just a normal TLObject.
with BinaryReader(rpc_result.body) as reader:
result = message.obj.read_result(reader)
result = state.request.read_result(reader)
if not message.future.cancelled():
message.future.set_result(result)
if not state.future.cancelled():
state.future.set_result(result)
async def _handle_container(self, message):
"""
@ -581,9 +461,9 @@ class MTProtoSender:
"""
pong = message.obj
__log__.debug('Handling pong for message %d', pong.msg_id)
message = self._pending_messages.pop(pong.msg_id, None)
if message:
message.future.set_result(pong)
state = self._pending_state.pop(pong.msg_id, None)
if state:
state.future.set_result(pong)
async def _handle_bad_server_salt(self, message):
"""
@ -595,16 +475,16 @@ class MTProtoSender:
"""
bad_salt = message.obj
__log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id)
self.state.salt = bad_salt.new_server_salt
if self._last_ack and bad_salt.bad_msg_id == self._last_ack.msg_id:
self._send_queue.put_nowait(self._last_ack)
return
self._connection._state.salt = bad_salt.new_server_salt
try:
self._send_queue.put_nowait(
self._pending_messages[bad_salt.bad_msg_id])
self._pending_state.pop(bad_salt.bad_msg_id))
except KeyError:
# May be MsgsAck, those are not saved in pending messages
for ack in self._pending_ack:
if ack.msg_id == bad_salt.bad_msg_id:
self._send_queue.put_nowait(ack)
return
__log__.info('Message %d not resent due to bad salt',
bad_salt.bad_msg_id)
@ -617,41 +497,33 @@ class MTProtoSender:
error_code:int = BadMsgNotification;
"""
bad_msg = message.obj
msg = self._pending_messages.get(bad_msg.bad_msg_id)
# TODO Pending state may need to pop by container ID
state = self._pending_state.pop(bad_msg.bad_msg_id, None)
__log__.debug('Handling bad msg %s', bad_msg)
if bad_msg.error_code in (16, 17):
# Sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset.
to = self.state.update_time_offset(correct_msg_id=message.msg_id)
to = self._connection._state.update_time_offset(
correct_msg_id=message.msg_id)
__log__.info('System clock is wrong, set time offset to %ds', to)
# Correct the msg_id *of the message to resend*, not all.
#
# If we correct them all, new "bad message" would not find
# the old invalid IDs, causing all awaits to never finish.
if msg:
del self._pending_messages[msg.msg_id]
self.state.update_message_id(msg)
self._pending_messages[msg.msg_id] = msg
elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID
self.state._sequence += 64
self._connection._state._sequence += 64
elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case
self.state._sequence -= 16
self._connection._state._sequence -= 16
else:
if msg:
del self._pending_messages[msg.msg_id]
msg.future.set_exception(BadMessageError(bad_msg.error_code))
if state:
state.future.set_exception(BadMessageError(bad_msg.error_code))
return
# Messages are to be re-sent once we've corrected the issue
if msg:
self._send_queue.put_nowait(msg)
if state:
self._send_queue.put_nowait(state)
else:
# TODO Generic method that may return from the acks too
# May be MsgsAck, those are not saved in pending messages
__log__.info('Message %d not resent due to bad msg',
bad_msg.bad_msg_id)
@ -689,7 +561,7 @@ class MTProtoSender:
"""
# TODO https://goo.gl/LMyN7A
__log__.debug('Handling new session created')
self.state.salt = message.obj.server_salt
self._connection._state.salt = message.obj.server_salt
async def _handle_ack(self, message):
"""
@ -708,14 +580,11 @@ class MTProtoSender:
"""
ack = message.obj
__log__.debug('Handling acknowledge for %s', str(ack.msg_ids))
if self._pending_containers:
self._clean_containers(ack.msg_ids)
for msg_id in ack.msg_ids:
msg = self._pending_messages.get(msg_id, None)
if msg and isinstance(msg.obj, LogOutRequest):
del self._pending_messages[msg_id]
msg.future.set_result(True)
state = self._pending_state.get(msg_id)
if state and isinstance(state.request, LogOutRequest):
del self._pending_state[msg_id]
state.future.set_result(True)
async def _handle_future_salts(self, message):
"""
@ -728,52 +597,20 @@ class MTProtoSender:
# TODO save these salts and automatically adjust to the
# correct one whenever the salt in use expires.
__log__.debug('Handling future salts for message %d', message.msg_id)
msg = self._pending_messages.pop(message.msg_id, None)
if msg:
msg.future.set_result(message.obj)
state = self._pending_state.pop(message.msg_id, None)
if state:
state.future.set_result(message.obj)
async def _handle_state_forgotten(self, message):
"""
Handles both :tl:`MsgsStateReq` and :tl:`MsgResendReq` by
enqueuing a :tl:`MsgsStateInfo` to be sent at a later point.
"""
self.send(MsgsStateInfo(req_msg_id=message.msg_id,
info=chr(1) * len(message.obj.msg_ids)))
self._send_queue.put_nowait(RequestState(MsgsStateInfo(
req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)),
loop=self._loop))
async def _handle_msg_all(self, message):
"""
Handles :tl:`MsgsAllInfo` by doing nothing (yet).
"""
class _ContainerQueue(asyncio.Queue):
"""
An asyncio queue that's aware of `MessageContainer` instances.
The `get` method returns either a single `TLMessage` or a list
of them that should be turned into a new `MessageContainer`.
Instances of this class can be replaced with the simpler
``asyncio.Queue`` when needed for testing purposes, and
a list won't be returned in said case.
"""
async def get(self):
result = await super().get()
if self.empty() or result == _reconnect_sentinel or\
isinstance(result.obj, MessageContainer):
return result
size = result.size()
result = [result]
while not self.empty():
item = self.get_nowait()
if (item == _reconnect_sentinel or
isinstance(item.obj, MessageContainer)
or size + item.size() > MessageContainer.MAXIMUM_SIZE):
self.put_nowait(item)
break
else:
size += item.size()
result.append(item)
return result

View File

@ -0,0 +1,18 @@
import asyncio
class RequestState:
"""
This request state holds several information relevant to sent messages,
in particular the message ID assigned to the request, the container ID
it belongs to, the request itself, the request as bytes, and the future
result that will eventually be resolved.
"""
__slots__ = ('container_id', 'msg_id', 'request', 'data', 'future')
def __init__(self, request, loop):
self.container_id = None
self.msg_id = None
self.request = request
self.data = bytes(request)
self.future = asyncio.Future(loop=loop)