diff --git a/telethon/network/mtproto_sender.py b/telethon/network/mtproto_sender.py index 54e4e626..24c9a4ab 100644 --- a/telethon/network/mtproto_sender.py +++ b/telethon/network/mtproto_sender.py @@ -3,11 +3,10 @@ This module contains the class used to communicate with Telegram's servers encrypting every packet, and relies on a valid AuthKey in the used Session. """ import asyncio -import gzip import logging from asyncio import Event -from .. import helpers as utils +from .. import helpers, utils from ..errors import ( BadMessageError, InvalidChecksumError, BrokenAuthKeyError, rpc_message_to_error @@ -15,6 +14,7 @@ from ..errors import ( from ..extensions import BinaryReader from ..tl import TLMessage, MessageContainer, GzipPacked from ..tl.all_tlobjects import tlobjects +from ..tl.functions import InvokeAfterMsgRequest from ..tl.functions.auth import LogOutRequest from ..tl.types import ( MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts, @@ -81,13 +81,18 @@ class MtProtoSender: # region Send and receive - async def send(self, *requests): + async def send(self, requests, ordered=False): """ Sends the specified TLObject(s) (which must be requests), and acknowledging any message which needed confirmation. :param requests: the requests to be sent. + :param ordered: whether the requests should be invoked in the + order in which they appear or they can be executed + in arbitrary order in the server. """ + if not utils.is_list_like(requests): + requests = (requests,) # Prepare the event of every request for r in requests: @@ -96,8 +101,15 @@ class MtProtoSender: else: r.confirm_received.clear() - # Finally send our packed request(s) - messages = [TLMessage(self.session, r) for r in requests] + if ordered: + requests = iter(requests) + messages = [TLMessage(self.session, next(requests))] + for r in requests: + messages.append(TLMessage(self.session, r, + after_id=messages[-1].msg_id)) + else: + messages = [TLMessage(self.session, r) for r in requests] + self._pending_receive.update({m.msg_id: m for m in messages}) __log__.debug('Sending requests with IDs: %s', ', '.join( @@ -137,7 +149,12 @@ class MtProtoSender: Update and Updates objects. """ if self._recv_lock.locked(): - return + with await self._recv_lock: + # Don't busy wait, acquire it but return because there's + # already a receive running and we don't want another one. + # It would lock until Telegram sent another update even if + # the current receive already received the expected response. + return try: with await self._recv_lock: @@ -187,7 +204,7 @@ class MtProtoSender: raise BufferError("Can't decode packet ({})".format(body)) with BinaryReader(body) as reader: - return utils.unpack_message(self.session, reader) + return helpers.unpack_message(self.session, reader) async def _process_msg(self, msg_id, sequence, reader, state): """ diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index cf19c1d1..9a142b43 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -418,38 +418,62 @@ class TelegramBareClient: # region Invoking Telegram requests - async def __call__(self, *requests, retries=5): - """Invokes (sends) a MTProtoRequest and returns (receives) its result. - - The invoke will be retried up to 'retries' times before raising - RuntimeError(). + async def __call__(self, request, retries=5, ordered=False): """ + Invokes (sends) one or more MTProtoRequests and returns (receives) + their result. + + Args: + request (`TLObject` | `list`): + The request or requests to be invoked. + + retries (`bool`, optional): + How many times the request should be retried automatically + in case it fails with a non-RPC error. + + The invoke will be retried up to 'retries' times before raising + ``RuntimeError``. + + ordered (`bool`, optional): + Whether the requests (if more than one was given) should be + executed sequentially on the server. They run in arbitrary + order by default. + + Returns: + The result of the request (often a `TLObject`) or a list of + results if more than one request was given. + """ + single = not utils.is_list_like(request) + if single: + request = (request,) + if not all(isinstance(x, TLObject) and - x.content_related for x in requests): + x.content_related for x in request): raise TypeError('You can only invoke requests, not types!') - for request in requests: - await request.resolve(self, utils) + for r in request: + await r.resolve(self, utils) # For logging purposes - if len(requests) == 1: - which = type(requests[0]).__name__ + if single: + which = type(request[0]).__name__ else: which = '{} requests ({})'.format( - len(requests), [type(x).__name__ for x in requests]) + len(request), [type(x).__name__ for x in request]) __log__.debug('Invoking %s', which) call_receive = \ not self._idling.is_set() or self._reconnect_lock.locked() for retry in range(retries): - result = await self._invoke(call_receive, retry, *requests) + result = await self._invoke(call_receive, retry, request, + ordered=ordered) if result is not None: - return result + return result[0] if single else result log = __log__.info if retry == 0 else __log__.warning log('Invoking %s failed %d times, connecting again and retrying', - [str(x) for x in requests], retry + 1) + which, retry + 1) await asyncio.sleep(1) if not self._reconnect_lock.locked(): @@ -457,13 +481,13 @@ class TelegramBareClient: await self._reconnect() raise RuntimeError('Number of retries reached 0 for {}.'.format( - [type(x).__name__ for x in requests] + which )) # Let people use client.invoke(SomeRequest()) instead client(...) invoke = __call__ - async def _invoke(self, call_receive, retry, *requests): + async def _invoke(self, call_receive, retry, requests, ordered=False): try: # Ensure that we start with no previous errors (i.e. resending) for x in requests: @@ -487,7 +511,7 @@ class TelegramBareClient: self._wrap_init_connection(GetConfigRequest()) ) - await self._sender.send(*requests) + await self._sender.send(requests, ordered=ordered) if not call_receive: await asyncio.wait( @@ -540,10 +564,7 @@ class TelegramBareClient: # rejected by the other party as a whole." return None - if len(requests) == 1: - return requests[0].result - else: - return [x.result for x in requests] + return [x.result for x in requests] except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e: diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index e6de56ea..f22d1ee7 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -1383,10 +1383,7 @@ class TelegramClient(TelegramBareClient): if requests[0].offset > limit: break - if len(requests) == 1: - results = (await self(requests[0]),) - else: - results = await self(*requests) + results = await self(requests) for i in reversed(range(len(requests))): participants = results[i] if not participants.users: diff --git a/telethon/tl/tl_message.py b/telethon/tl/tl_message.py index bcb48279..f6246de2 100644 --- a/telethon/tl/tl_message.py +++ b/telethon/tl/tl_message.py @@ -1,11 +1,12 @@ import struct from . import TLObject, GzipPacked +from ..tl.functions import InvokeAfterMsgRequest class TLMessage(TLObject): """https://core.telegram.org/mtproto/service_messages#simple-container""" - def __init__(self, session, request): + def __init__(self, session, request, after_id=None): super().__init__() del self.content_related self.msg_id = session.get_new_msg_id() @@ -13,16 +14,27 @@ class TLMessage(TLObject): self.request = request self.container_msg_id = None + # After which message ID this one should run. We do this so + # InvokeAfterMsgRequest is transparent to the user and we can + # easily invoke after while confirming the original request. + self.after_id = after_id + def to_dict(self, recursive=True): return { 'msg_id': self.msg_id, 'seq_no': self.seq_no, 'request': self.request, 'container_msg_id': self.container_msg_id, + 'after_id': self.after_id } def __bytes__(self): - body = GzipPacked.gzip_if_smaller(self.request) + if self.after_id is None: + body = GzipPacked.gzip_if_smaller(self.request) + else: + body = GzipPacked.gzip_if_smaller( + InvokeAfterMsgRequest(self.after_id, self.request)) + return struct.pack('