mirror of https://github.com/LonamiWebs/Telethon.git
synced 2025-02-04 21:50:57 +03:00

Make use of the MTProtoLayer in MTProtoSender

This commit is contained in:
  parent 9402b4a26d
  commit 470fb9f5df
@@ -33,17 +33,16 @@ class MTProtoLayer:
         """
         self._connection.disconnect()

-    async def send(self, data_list):
+    async def send(self, state_list):
         """
-        A list of serialized RPC queries as bytes must be given to be sent.
+        The list of `RequestState` that will be sent. They will
+        be updated with their new message and container IDs.
+
         Nested lists imply an order is required for the messages in them.
         Message containers will be used if there is more than one item.
-
-        Returns ``(container_id, msg_ids)``.
         """
-        data, container_id, msg_ids = self._pack_data_list(data_list)
+        data = self._pack_state_list(state_list)
         await self._connection.send(self._state.encrypt_message_data(data))
-        return container_id, msg_ids

     async def recv(self):
         """
@@ -52,12 +51,14 @@ class MTProtoLayer:
         body = await self._connection.recv()
         return self._state.decrypt_message_data(body)

-    def _pack_data_list(self, data_list):
+    def _pack_state_list(self, state_list):
         """
-        A list of serialized RPC queries as bytes must be given to be packed.
-        Nested lists imply an order is required for the messages in them.
+        The list of `RequestState` that will be sent. They will
+        be updated with their new message and container IDs.

-        Returns ``(data, container_id, msg_ids)``.
+        Packs all their serialized data into a message (possibly
+        nested inside another message and message container) and
+        returns the serialized message data.
         """
         # TODO write_data_as_message raises on invalid messages, handle it
         # TODO This method could be an iterator yielding messages while small
@@ -72,33 +73,39 @@ class MTProtoLayer:
         # to store and serialize the data. However, to keep the context local
         # and relevant to the only place where such feature is actually used,
         # this is not done.
-        msg_ids = []
+        n = 0
         buffer = io.BytesIO()
-        for data in data_list:
-            if not isinstance(data, list):
-                msg_ids.append(self._state.write_data_as_message(buffer, data))
+        for state in state_list:
+            if not isinstance(state, list):
+                n += 1
+                state.msg_id = \
+                    self._state.write_data_as_message(buffer, state.data)
             else:
                 last_id = None
-                for d in data:
-                    last_id = self._state.write_data_as_message(
-                        buffer, d, after_id=last_id)
-                    msg_ids.append(last_id)
+                for s in state:
+                    n += 1
+                    last_id = s.msg_id = self._state.write_data_as_message(
+                        buffer, s.data, after_id=last_id)

-        if len(msg_ids) == 1:
-            container_id = None
-        else:
+        if n > 1:
             # Inlined code to pack several messages into a container
             #
             # TODO This part and encrypting data prepend a few bytes but
             # force a potentially large payload to be appended, which
             # may be expensive. Can we do better?
             data = struct.pack(
-                '<Ii', MessageContainer.CONSTRUCTOR_ID, len(msg_ids)
+                '<Ii', MessageContainer.CONSTRUCTOR_ID, n
             ) + buffer.getvalue()
             buffer = io.BytesIO()
             container_id = self._state.write_data_as_message(buffer, data)
+            for state in state_list:
+                if not isinstance(state, list):
+                    state.container_id = container_id
+                else:
+                    for s in state:
+                        s.container_id = container_id

-        return buffer.getvalue(), container_id, msg_ids
+        return buffer.getvalue()

     def __str__(self):
         return str(self._connection)
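For context on the '<Ii' pack in the hunk above: an MTProto message container is just a constructor ID, a signed 32-bit message count, and the concatenated inner messages, each carrying its own 64-bit message ID, sequence number and length. A minimal standalone sketch of that layout (the constructor value and the dummy IDs below are illustrative assumptions, not taken from this diff):

import io
import struct

MSG_CONTAINER_ID = 0x73f1f8dc  # assumed msg_container constructor ID

def write_message(buffer, msg_id, seq_no, body):
    # Inner message header: msg_id (int64), seq_no (int32), length (int32).
    buffer.write(struct.pack('<qii', msg_id, seq_no, len(body)))
    buffer.write(body)

buffer = io.BytesIO()
write_message(buffer, msg_id=1, seq_no=1, body=b'first-request')
write_message(buffer, msg_id=2, seq_no=3, body=b'second-request')

# Same shape as the '<Ii' pack above: constructor ID, count, then payload.
container = struct.pack('<Ii', MSG_CONTAINER_ID, 2) + buffer.getvalue()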
@@ -1,4 +1,5 @@
 import asyncio
+import collections
 import logging

 from .mtprotolayer import MTProtoLayer
@@ -15,16 +16,11 @@ from ..tl.types import (
     MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq,
     MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload
 )
+from .requeststate import RequestState

 __log__ = logging.getLogger(__name__)


-# Place this object in the send queue when a reconnection is needed
-# so there is an item to read and we can early quit the loop, since
-# without this it will block until there's something in the queue.
-_reconnect_sentinel = object()
-
-
 class MTProtoSender:
     """
     MTProto Mobile Protocol sender
@@ -66,27 +62,21 @@ class MTProtoSender:
         self._send_loop_handle = None
         self._recv_loop_handle = None

-        # Sending something shouldn't block
-        self._send_queue = _ContainerQueue()
+        # Outgoing messages are put in a queue and sent in a batch.
+        # Note that here we're also storing their ``_RequestState``.
+        # Note that it may also store lists (implying order must be kept).
+        self._send_queue = asyncio.Queue()

-        # Telegram responds to messages out of order. Keep
-        # {id: Message} to set their Future result upon arrival.
-        self._pending_messages = {}
+        # Sent states are remembered until a response is received.
+        self._pending_state = {}

-        # Containers are accepted or rejected as a whole when any of
-        # its inner requests are acknowledged. For this purpose we
-        # all the sent containers here.
-        self._pending_containers = []
-
-        # We need to acknowledge every response from Telegram
+        # Responses must be acknowledged, and we can also batch these.
         self._pending_ack = set()

-        # Similar to pending_messages but only for the last ack.
-        # Ack can't be put in the messages because Telegram never
-        # responds to acknowledges (they're just that, acknowledges),
-        # so it would grow to infinite otherwise, but on bad salt it's
-        # necessary to resend them just like everything else.
-        self._last_ack = None
+        # Similar to pending_messages but only for the last acknowledges.
+        # These can't go in pending_messages because no acknowledge for them
+        # is received, but we may still need to resend their state on bad salts.
+        self._last_acks = collections.deque(maxlen=10)

         # Jump table from response ID to method that handles it
         self._handlers = {
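The switch from a single `_last_ack` slot to `collections.deque(maxlen=10)` bounds how many acknowledge states are kept: acknowledges never receive a response, so they cannot live in `_pending_state`, yet a few recent ones must survive in case a bad salt notification forces a resend. A quick illustration of the bounded-deque behaviour this relies on (plain Python, nothing Telethon-specific):

import collections

last_acks = collections.deque(maxlen=10)
for i in range(25):
    last_acks.append(i)  # once full, the oldest entry is dropped silently

assert list(last_acks) == list(range(15, 25))  # only the 10 newest remain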
@@ -136,6 +126,7 @@ class MTProtoSender:

         await self._disconnect()

+    # TODO Move this out of the "Public API" section
     async def _disconnect(self, error=None):
         __log__.info('Disconnecting from %s...', self._connection)
         self._user_connected = False
@@ -144,14 +135,14 @@ class MTProtoSender:
                 self._connection.disconnect()
             finally:
                 __log__.debug('Cancelling {} pending message(s)...'
-                              .format(len(self._pending_messages)))
-                for message in self._pending_messages.values():
-                    if error and not message.future.done():
-                        message.future.set_exception(error)
+                              .format(len(self._pending_state)))
+                for state in self._pending_state.values():
+                    if error and not state.future.done():
+                        state.future.set_exception(error)
                     else:
-                        message.future.cancel()
+                        state.future.cancel()

-                self._pending_messages.clear()
+                self._pending_state.clear()
                 self._pending_ack.clear()
                 self._last_ack = None

@@ -172,11 +163,9 @@ class MTProtoSender:

     def send(self, request, ordered=False):
         """
-        This method enqueues the given request to be sent.
-
-        The request will be wrapped inside a `TLMessage` until its
-        response arrives, and the `Future` response of the `TLMessage`
-        is immediately returned so that one can further ``await`` it:
+        This method enqueues the given request to be sent. Its send
+        state will be saved until a response arrives, and a ``Future``
+        that will be resolved when the response arrives will be returned:

         .. code-block:: python

@@ -198,23 +187,23 @@ class MTProtoSender:
         if not self._user_connected:
             raise ConnectionError('Cannot send requests while disconnected')

-        if utils.is_list_like(request):
-            result = []
-            after = None
-            for r in request:
-                message = self.state.create_message(
-                    r, loop=self._loop, after=after)
-
-                self._pending_messages[message.msg_id] = message
-                self._send_queue.put_nowait(message)
-                result.append(message.future)
-                after = ordered and message
-            return result
+        if not utils.is_list_like(request):
+            state = RequestState(request, self._loop)
+            self._send_queue.put_nowait(state)
+            return state.future
         else:
-            message = self.state.create_message(request, loop=self._loop)
-            self._pending_messages[message.msg_id] = message
-            self._send_queue.put_nowait(message)
-            return message.future
+            states = []
+            futures = []
+            for req in request:
+                state = RequestState(req, self._loop)
+                states.append(state)
+                futures.append(state.future)
+            if ordered:
+                self._send_queue.put(states)
+            else:
+                for state in states:
+                    self._send_queue.put(state)
+            return futures

     @property
     def disconnected(self):
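A hypothetical caller of the reworked send() might look as follows; `sender` and `request` are placeholders rather than objects defined in this diff:

async def example(sender, request):
    # A single request returns one future, resolved by the recv loop.
    future = sender.send(request)
    result = await future

    # A list returns one future per request; ordered=True enqueues the
    # states as a nested list so their relative order is preserved.
    futures = sender.send([request, request], ordered=True)
    results = [await f for f in futures]
    return result, results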
@@ -290,7 +279,6 @@ class MTProtoSender:
         Cleanly disconnects and then reconnects.
         """
         self._reconnecting = True
-        self._send_queue.put_nowait(_reconnect_sentinel)

         __log__.debug('Awaiting for the send loop before reconnecting...')
         await self._send_loop_handle
@@ -307,8 +295,11 @@ class MTProtoSender:
         for retry in range(1, retries + 1):
             try:
                 await self._connect()
+                # TODO Keep this?
+                """
                 for m in self._pending_messages.values():
                     self._send_queue.put_nowait(m)
+                """

                 if self._auto_reconnect_callback:
                     self._loop.create_task(self._auto_reconnect_callback())
@@ -325,23 +316,6 @@ class MTProtoSender:
             if self._user_connected:
                 self._loop.create_task(self._reconnect())

-    def _clean_containers(self, msg_ids):
-        """
-        Helper method to clean containers from the pending messages
-        once a wrapped msg_id of them has been acknowledged.
-
-        This is the only way we can resend TLMessage(MessageContainer)
-        on bad notifications and also mark them as received once any
-        of their inner TLMessage is acknowledged.
-        """
-        for i in reversed(range(len(self._pending_containers))):
-            message = self._pending_containers[i]
-            for msg in message.obj.messages:
-                if msg.msg_id in msg_ids:
-                    del self._pending_containers[i]
-                    del self._pending_messages[message.msg_id]
-                    break
-
     # Loops

     async def _send_loop(self):
@@ -353,67 +327,31 @@ class MTProtoSender:
         """
         while self._user_connected and not self._reconnecting:
             if self._pending_ack:
-                self._last_ack = self.state.create_message(
-                    MsgsAck(list(self._pending_ack)), loop=self._loop
-                )
-                self._send_queue.put_nowait(self._last_ack)
+                ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop)
+                self._send_queue.put_nowait(ack)
+                self._last_acks.append(ack)
                 self._pending_ack.clear()

-            messages = await self._send_queue.get()
-            if messages == _reconnect_sentinel:
-                if self._reconnecting:
-                    break
+            state_list = []
+
+            # TODO wait for the list to have one or for a disconnect to happen
+            # and pop while that's the case
+            state = await self._send_queue.get()
+            state_list.append(state)
+
+            while not self._send_queue.empty():
+                state_list.append(self._send_queue.get_nowait())
+
+            # TODO Debug logs to notify which messages are being sent
+            # TODO Try sending them while no future was cancelled?
+            # TODO Handle timeout, cancelled, arbitrary errors
+            await self._connection.send(state_list)
+            for state in state_list:
+                if not isinstance(state, list):
+                    self._pending_state[state.msg_id] = state
                 else:
-                    continue
-
-            if isinstance(messages, list):
-                message = self.state.create_message(
-                    MessageContainer(messages), loop=self._loop)
-
-                self._pending_messages[message.msg_id] = message
-                self._pending_containers.append(message)
-            else:
-                message = messages
-                messages = [message]
-
-            __log__.debug(
-                'Packing %d outgoing message(s) %s...', len(messages),
-                ', '.join(x.obj.__class__.__name__ for x in messages)
-            )
-            body = self.state.pack_message(message)
-
-            while not any(m.future.cancelled() for m in messages):
-                try:
-                    __log__.debug('Sending {} bytes...'.format(len(body)))
-                    await self._connection.send(body)
-                    break
-                except asyncio.TimeoutError:
-                    continue
-                except asyncio.CancelledError:
-                    return
-                except Exception as e:
-                    if isinstance(e, ConnectionError):
-                        __log__.info('Connection reset while sending %s', e)
-                    elif isinstance(e, OSError):
-                        __log__.warning('OSError while sending %s', e)
-                    else:
-                        __log__.exception('Unhandled exception while receiving')
-                    await asyncio.sleep(1, loop=self._loop)
-
-                self._start_reconnect()
-                break
-            else:
-                # Remove the cancelled messages from pending
-                __log__.info('Some futures were cancelled, aborted send')
-                self._clean_containers([m.msg_id for m in messages])
-                for m in messages:
-                    if m.future.cancelled():
-                        self._pending_messages.pop(m.msg_id, None)
-                    else:
-                        self._send_queue.put_nowait(m)
-
-            __log__.debug('Outgoing messages {} sent!'
-                          .format(', '.join(str(m.msg_id) for m in messages)))
+                    for s in state:
+                        self._pending_state[s.msg_id] = s

     async def _recv_loop(self):
         """
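The new loop replaces the `_ContainerQueue` subclass with a plain queue plus an explicit drain: block for the first state, then take whatever else is already queued so everything can go out together. The generic pattern, as a minimal sketch (plain asyncio; `process_batch` is a made-up consumer, not part of this diff):

import asyncio

async def batching_worker(queue, process_batch):
    while True:
        batch = [await queue.get()]           # block until one item arrives
        while not queue.empty():
            batch.append(queue.get_nowait())  # drain the rest without blocking
        await process_batch(batch)            # e.g. pack into one container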
@@ -423,68 +361,11 @@ class MTProtoSender:
         Besides `connect`, only this method ever receives data.
         """
         while self._user_connected and not self._reconnecting:
-            try:
-                __log__.debug('Receiving items from the network...')
-                body = await self._connection.recv()
-            except asyncio.TimeoutError:
-                continue
-            except asyncio.CancelledError:
-                return
-            except Exception as e:
-                if isinstance(e, ConnectionError):
-                    __log__.info('Connection reset while receiving %s', e)
-                elif isinstance(e, OSError):
-                    __log__.warning('OSError while receiving %s', e)
-                else:
-                    __log__.exception('Unhandled exception while receiving')
-                await asyncio.sleep(1, loop=self._loop)
-
-                self._start_reconnect()
-                break
-
-            # TODO Check salt, session_id and sequence_number
-            __log__.debug('Decoding packet of %d bytes...', len(body))
-            try:
-                message = self.state.unpack_message(body)
-            except (BrokenAuthKeyError, BufferError) as e:
-                # The authorization key may be broken if a message was
-                # sent malformed, or if the authkey truly is corrupted.
-                #
-                # There may be a buffer error if Telegram's response was too
-                # short and hence not understood. Reset the authorization key
-                # and try again in either case.
-                #
-                # TODO Is it possible to detect malformed messages vs
-                # an actually broken authkey?
-                __log__.warning('Broken authorization key?: {}'.format(e))
-                self.state.auth_key = None
-                self._start_reconnect()
-                break
-            except SecurityError as e:
-                # A step while decoding had the incorrect data. This message
-                # should not be considered safe and it should be ignored.
-                __log__.warning('Security error while unpacking a '
-                                'received message: {}'.format(e))
-                continue
-            except TypeNotFoundError as e:
-                # The payload inside the message was not a known TLObject.
-                __log__.info('Server replied with an unknown type {:08x}: {!r}'
-                             .format(e.invalid_constructor_id, e.remaining))
-                continue
-            except asyncio.CancelledError:
-                return
-            except Exception as e:
-                __log__.exception('Unhandled exception while unpacking %s', e)
-                await asyncio.sleep(1, loop=self._loop)
-            else:
-                try:
-                    await self._process_message(message)
-                except asyncio.CancelledError:
-                    return
-                except Exception as e:
-                    __log__.exception('Unhandled exception while '
-                                      'processing %s', message)
-                    await asyncio.sleep(1, loop=self._loop)
+            # TODO Handle timeout, cancelled, arbitrary, broken auth, buffer,
+            # security and type not found.
+            __log__.debug('Receiving items from the network...')
+            message = await self._connection.recv()
+            await self._process_message(message)

     # Response Handlers

@@ -508,11 +389,11 @@ class MTProtoSender:
         This is where the future results for sent requests are set.
         """
         rpc_result = message.obj
-        message = self._pending_messages.pop(rpc_result.req_msg_id, None)
+        state = self._pending_state.pop(rpc_result.req_msg_id, None)
         __log__.debug('Handling RPC result for message %d',
                       rpc_result.req_msg_id)

-        if not message:
+        if not state:
             # TODO We should not get responses to things we never sent
             # However receiving a File() with empty bytes is "common".
             # See #658, #759 and #958. They seem to happen in a container
@@ -528,22 +409,21 @@ class MTProtoSender:

         if rpc_result.error:
             error = rpc_message_to_error(rpc_result.error)
-            self._send_queue.put_nowait(self.state.create_message(
-                MsgsAck([message.msg_id]), loop=self._loop
-            ))
+            self._send_queue.put_nowait(
+                RequestState(MsgsAck([state.msg_id]), loop=self._loop))

-            if not message.future.cancelled():
-                message.future.set_exception(error)
+            if not state.future.cancelled():
+                state.future.set_exception(error)
         else:
             # TODO Would be nice to avoid accessing a per-obj read_result
             # Instead have a variable that indicated how the result should
             # be read (an enum) and dispatch to read the result, mostly
             # always it's just a normal TLObject.
             with BinaryReader(rpc_result.body) as reader:
-                result = message.obj.read_result(reader)
+                result = state.request.read_result(reader)

-            if not message.future.cancelled():
-                message.future.set_result(result)
+            if not state.future.cancelled():
+                state.future.set_result(result)

     async def _handle_container(self, message):
         """
@@ -581,9 +461,9 @@ class MTProtoSender:
         """
         pong = message.obj
         __log__.debug('Handling pong for message %d', pong.msg_id)
-        message = self._pending_messages.pop(pong.msg_id, None)
-        if message:
-            message.future.set_result(pong)
+        state = self._pending_state.pop(pong.msg_id, None)
+        if state:
+            state.future.set_result(pong)

     async def _handle_bad_server_salt(self, message):
         """
@@ -595,16 +475,16 @@ class MTProtoSender:
         """
         bad_salt = message.obj
         __log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id)
-        self.state.salt = bad_salt.new_server_salt
-        if self._last_ack and bad_salt.bad_msg_id == self._last_ack.msg_id:
-            self._send_queue.put_nowait(self._last_ack)
-            return
+        self._connection._state.salt = bad_salt.new_server_salt

         try:
             self._send_queue.put_nowait(
-                self._pending_messages[bad_salt.bad_msg_id])
+                self._pending_state.pop(bad_salt.bad_msg_id))
         except KeyError:
-            # May be MsgsAck, those are not saved in pending messages
+            for ack in self._pending_ack:
+                if ack.msg_id == bad_salt.bad_msg_id:
+                    self._send_queue.put_nowait(ack)
+                    return

             __log__.info('Message %d not resent due to bad salt',
                          bad_salt.bad_msg_id)
@@ -617,41 +497,33 @@ class MTProtoSender:
            error_code:int = BadMsgNotification;
         """
         bad_msg = message.obj
-        msg = self._pending_messages.get(bad_msg.bad_msg_id)
+        # TODO Pending state may need to pop by container ID
+        state = self._pending_state.pop(bad_msg.bad_msg_id, None)

         __log__.debug('Handling bad msg %s', bad_msg)
         if bad_msg.error_code in (16, 17):
             # Sent msg_id too low or too high (respectively).
             # Use the current msg_id to determine the right time offset.
-            to = self.state.update_time_offset(correct_msg_id=message.msg_id)
+            to = self._connection._state.update_time_offset(
+                correct_msg_id=message.msg_id)
             __log__.info('System clock is wrong, set time offset to %ds', to)

-            # Correct the msg_id *of the message to resend*, not all.
-            #
-            # If we correct them all, new "bad message" would not find
-            # the old invalid IDs, causing all awaits to never finish.
-            if msg:
-                del self._pending_messages[msg.msg_id]
-                self.state.update_message_id(msg)
-                self._pending_messages[msg.msg_id] = msg
-
         elif bad_msg.error_code == 32:
             # msg_seqno too low, so just pump it up by some "large" amount
             # TODO A better fix would be to start with a new fresh session ID
-            self.state._sequence += 64
+            self._connection._state._sequence += 64
         elif bad_msg.error_code == 33:
             # msg_seqno too high never seems to happen but just in case
-            self.state._sequence -= 16
+            self._connection._state._sequence -= 16
         else:
-            if msg:
-                del self._pending_messages[msg.msg_id]
-                msg.future.set_exception(BadMessageError(bad_msg.error_code))
+            if state:
+                state.future.set_exception(BadMessageError(bad_msg.error_code))
             return

         # Messages are to be re-sent once we've corrected the issue
-        if msg:
-            self._send_queue.put_nowait(msg)
+        if state:
+            self._send_queue.put_nowait(state)
         else:
+            # TODO Generic method that may return from the acks too
             # May be MsgsAck, those are not saved in pending messages
             __log__.info('Message %d not resent due to bad msg',
                          bad_msg.bad_msg_id)
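On error codes 16 and 17 the sender derives the clock drift from a message ID it knows to be correct. MTProto message IDs encode the sender's unix time in their upper 32 bits, which is what makes this possible; a rough sketch of the idea (this mirrors the concept behind update_time_offset, not its exact implementation):

import time

def time_offset_from(correct_msg_id):
    server_time = correct_msg_id >> 32    # upper 32 bits hold the unix time
    return server_time - int(time.time())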
@@ -689,7 +561,7 @@ class MTProtoSender:
         """
         # TODO https://goo.gl/LMyN7A
         __log__.debug('Handling new session created')
-        self.state.salt = message.obj.server_salt
+        self._connection._state.salt = message.obj.server_salt

     async def _handle_ack(self, message):
         """
@@ -708,14 +580,11 @@ class MTProtoSender:
         """
         ack = message.obj
         __log__.debug('Handling acknowledge for %s', str(ack.msg_ids))
-        if self._pending_containers:
-            self._clean_containers(ack.msg_ids)
-
         for msg_id in ack.msg_ids:
-            msg = self._pending_messages.get(msg_id, None)
-            if msg and isinstance(msg.obj, LogOutRequest):
-                del self._pending_messages[msg_id]
-                msg.future.set_result(True)
+            state = self._pending_state.get(msg_id)
+            if state and isinstance(state.request, LogOutRequest):
+                del self._pending_state[msg_id]
+                state.future.set_result(True)

     async def _handle_future_salts(self, message):
         """
@@ -728,52 +597,20 @@ class MTProtoSender:
         # TODO save these salts and automatically adjust to the
         # correct one whenever the salt in use expires.
         __log__.debug('Handling future salts for message %d', message.msg_id)
-        msg = self._pending_messages.pop(message.msg_id, None)
-        if msg:
-            msg.future.set_result(message.obj)
+        state = self._pending_state.pop(message.msg_id, None)
+        if state:
+            state.future.set_result(message.obj)

     async def _handle_state_forgotten(self, message):
         """
         Handles both :tl:`MsgsStateReq` and :tl:`MsgResendReq` by
         enqueuing a :tl:`MsgsStateInfo` to be sent at a later point.
         """
-        self.send(MsgsStateInfo(req_msg_id=message.msg_id,
-                                info=chr(1) * len(message.obj.msg_ids)))
+        self._send_queue.put_nowait(RequestState(MsgsStateInfo(
+            req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)),
+            loop=self._loop))

     async def _handle_msg_all(self, message):
         """
         Handles :tl:`MsgsAllInfo` by doing nothing (yet).
         """
-
-
-class _ContainerQueue(asyncio.Queue):
-    """
-    An asyncio queue that's aware of `MessageContainer` instances.
-
-    The `get` method returns either a single `TLMessage` or a list
-    of them that should be turned into a new `MessageContainer`.
-
-    Instances of this class can be replaced with the simpler
-    ``asyncio.Queue`` when needed for testing purposes, and
-    a list won't be returned in said case.
-    """
-    async def get(self):
-        result = await super().get()
-        if self.empty() or result == _reconnect_sentinel or\
-                isinstance(result.obj, MessageContainer):
-            return result
-
-        size = result.size()
-        result = [result]
-        while not self.empty():
-            item = self.get_nowait()
-            if (item == _reconnect_sentinel or
-                    isinstance(item.obj, MessageContainer)
-                    or size + item.size() > MessageContainer.MAXIMUM_SIZE):
-                self.put_nowait(item)
-                break
-            else:
-                size += item.size()
-                result.append(item)
-
-        return result
telethon/network/requeststate.py (new file, +18 lines)
@@ -0,0 +1,18 @@
+import asyncio
+
+
+class RequestState:
+    """
+    This request state holds several information relevant to sent messages,
+    in particular the message ID assigned to the request, the container ID
+    it belongs to, the request itself, the request as bytes, and the future
+    result that will eventually be resolved.
+    """
+    __slots__ = ('container_id', 'msg_id', 'request', 'data', 'future')
+
+    def __init__(self, request, loop):
+        self.container_id = None
+        self.msg_id = None
+        self.request = request
+        self.data = bytes(request)
+        self.future = asyncio.Future(loop=loop)
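A hypothetical standalone use of the new class, with a stand-in request object (real requests are TLObjects, which also serialize via bytes()):

import asyncio
from telethon.network.requeststate import RequestState

class FakeRequest:
    def __bytes__(self):
        return b'serialized-request'

loop = asyncio.get_event_loop()
state = RequestState(FakeRequest(), loop)
assert state.data == b'serialized-request'
assert state.msg_id is None                   # assigned later by _pack_state_list

state.future.set_result('response')           # normally done by the recv loop
print(loop.run_until_complete(state.future))  # -> 'response'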