Merge branch 'master' into asyncio

This commit is contained in:
Lonami Exo 2018-05-10 09:55:05 +02:00
commit 95eac6c151
4 changed files with 81 additions and 34 deletions

View File

@ -3,11 +3,10 @@ This module contains the class used to communicate with Telegram's servers
encrypting every packet, and relies on a valid AuthKey in the used Session. encrypting every packet, and relies on a valid AuthKey in the used Session.
""" """
import asyncio import asyncio
import gzip
import logging import logging
from asyncio import Event from asyncio import Event
from .. import helpers as utils from .. import helpers, utils
from ..errors import ( from ..errors import (
BadMessageError, InvalidChecksumError, BrokenAuthKeyError, BadMessageError, InvalidChecksumError, BrokenAuthKeyError,
rpc_message_to_error rpc_message_to_error
@ -15,6 +14,7 @@ from ..errors import (
from ..extensions import BinaryReader from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl.all_tlobjects import tlobjects from ..tl.all_tlobjects import tlobjects
from ..tl.functions import InvokeAfterMsgRequest
from ..tl.functions.auth import LogOutRequest from ..tl.functions.auth import LogOutRequest
from ..tl.types import ( from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts, MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
@ -81,13 +81,18 @@ class MtProtoSender:
# region Send and receive # region Send and receive
async def send(self, *requests): async def send(self, requests, ordered=False):
""" """
Sends the specified TLObject(s) (which must be requests), Sends the specified TLObject(s) (which must be requests),
and acknowledging any message which needed confirmation. and acknowledging any message which needed confirmation.
:param requests: the requests to be sent. :param requests: the requests to be sent.
:param ordered: whether the requests should be invoked in the
order in which they appear or they can be executed
in arbitrary order in the server.
""" """
if not utils.is_list_like(requests):
requests = (requests,)
# Prepare the event of every request # Prepare the event of every request
for r in requests: for r in requests:
@ -96,8 +101,15 @@ class MtProtoSender:
else: else:
r.confirm_received.clear() r.confirm_received.clear()
# Finally send our packed request(s) if ordered:
messages = [TLMessage(self.session, r) for r in requests] requests = iter(requests)
messages = [TLMessage(self.session, next(requests))]
for r in requests:
messages.append(TLMessage(self.session, r,
after_id=messages[-1].msg_id))
else:
messages = [TLMessage(self.session, r) for r in requests]
self._pending_receive.update({m.msg_id: m for m in messages}) self._pending_receive.update({m.msg_id: m for m in messages})
__log__.debug('Sending requests with IDs: %s', ', '.join( __log__.debug('Sending requests with IDs: %s', ', '.join(
@ -137,7 +149,12 @@ class MtProtoSender:
Update and Updates objects. Update and Updates objects.
""" """
if self._recv_lock.locked(): if self._recv_lock.locked():
return with await self._recv_lock:
# Don't busy wait, acquire it but return because there's
# already a receive running and we don't want another one.
# It would lock until Telegram sent another update even if
# the current receive already received the expected response.
return
try: try:
with await self._recv_lock: with await self._recv_lock:
@ -187,7 +204,7 @@ class MtProtoSender:
raise BufferError("Can't decode packet ({})".format(body)) raise BufferError("Can't decode packet ({})".format(body))
with BinaryReader(body) as reader: with BinaryReader(body) as reader:
return utils.unpack_message(self.session, reader) return helpers.unpack_message(self.session, reader)
async def _process_msg(self, msg_id, sequence, reader, state): async def _process_msg(self, msg_id, sequence, reader, state):
""" """

View File

@ -418,38 +418,62 @@ class TelegramBareClient:
# region Invoking Telegram requests # region Invoking Telegram requests
async def __call__(self, *requests, retries=5): async def __call__(self, request, retries=5, ordered=False):
"""Invokes (sends) a MTProtoRequest and returns (receives) its result.
The invoke will be retried up to 'retries' times before raising
RuntimeError().
""" """
Invokes (sends) one or more MTProtoRequests and returns (receives)
their result.
Args:
request (`TLObject` | `list`):
The request or requests to be invoked.
retries (`bool`, optional):
How many times the request should be retried automatically
in case it fails with a non-RPC error.
The invoke will be retried up to 'retries' times before raising
``RuntimeError``.
ordered (`bool`, optional):
Whether the requests (if more than one was given) should be
executed sequentially on the server. They run in arbitrary
order by default.
Returns:
The result of the request (often a `TLObject`) or a list of
results if more than one request was given.
"""
single = not utils.is_list_like(request)
if single:
request = (request,)
if not all(isinstance(x, TLObject) and if not all(isinstance(x, TLObject) and
x.content_related for x in requests): x.content_related for x in request):
raise TypeError('You can only invoke requests, not types!') raise TypeError('You can only invoke requests, not types!')
for request in requests: for r in request:
await request.resolve(self, utils) await r.resolve(self, utils)
# For logging purposes # For logging purposes
if len(requests) == 1: if single:
which = type(requests[0]).__name__ which = type(request[0]).__name__
else: else:
which = '{} requests ({})'.format( which = '{} requests ({})'.format(
len(requests), [type(x).__name__ for x in requests]) len(request), [type(x).__name__ for x in request])
__log__.debug('Invoking %s', which) __log__.debug('Invoking %s', which)
call_receive = \ call_receive = \
not self._idling.is_set() or self._reconnect_lock.locked() not self._idling.is_set() or self._reconnect_lock.locked()
for retry in range(retries): for retry in range(retries):
result = await self._invoke(call_receive, retry, *requests) result = await self._invoke(call_receive, retry, request,
ordered=ordered)
if result is not None: if result is not None:
return result return result[0] if single else result
log = __log__.info if retry == 0 else __log__.warning log = __log__.info if retry == 0 else __log__.warning
log('Invoking %s failed %d times, connecting again and retrying', log('Invoking %s failed %d times, connecting again and retrying',
[str(x) for x in requests], retry + 1) which, retry + 1)
await asyncio.sleep(1) await asyncio.sleep(1)
if not self._reconnect_lock.locked(): if not self._reconnect_lock.locked():
@ -457,13 +481,13 @@ class TelegramBareClient:
await self._reconnect() await self._reconnect()
raise RuntimeError('Number of retries reached 0 for {}.'.format( raise RuntimeError('Number of retries reached 0 for {}.'.format(
[type(x).__name__ for x in requests] which
)) ))
# Let people use client.invoke(SomeRequest()) instead client(...) # Let people use client.invoke(SomeRequest()) instead client(...)
invoke = __call__ invoke = __call__
async def _invoke(self, call_receive, retry, *requests): async def _invoke(self, call_receive, retry, requests, ordered=False):
try: try:
# Ensure that we start with no previous errors (i.e. resending) # Ensure that we start with no previous errors (i.e. resending)
for x in requests: for x in requests:
@ -487,7 +511,7 @@ class TelegramBareClient:
self._wrap_init_connection(GetConfigRequest()) self._wrap_init_connection(GetConfigRequest())
) )
await self._sender.send(*requests) await self._sender.send(requests, ordered=ordered)
if not call_receive: if not call_receive:
await asyncio.wait( await asyncio.wait(
@ -540,10 +564,7 @@ class TelegramBareClient:
# rejected by the other party as a whole." # rejected by the other party as a whole."
return None return None
if len(requests) == 1: return [x.result for x in requests]
return requests[0].result
else:
return [x.result for x in requests]
except (PhoneMigrateError, NetworkMigrateError, except (PhoneMigrateError, NetworkMigrateError,
UserMigrateError) as e: UserMigrateError) as e:

View File

@ -1383,10 +1383,7 @@ class TelegramClient(TelegramBareClient):
if requests[0].offset > limit: if requests[0].offset > limit:
break break
if len(requests) == 1: results = await self(requests)
results = (await self(requests[0]),)
else:
results = await self(*requests)
for i in reversed(range(len(requests))): for i in reversed(range(len(requests))):
participants = results[i] participants = results[i]
if not participants.users: if not participants.users:

View File

@ -1,11 +1,12 @@
import struct import struct
from . import TLObject, GzipPacked from . import TLObject, GzipPacked
from ..tl.functions import InvokeAfterMsgRequest
class TLMessage(TLObject): class TLMessage(TLObject):
"""https://core.telegram.org/mtproto/service_messages#simple-container""" """https://core.telegram.org/mtproto/service_messages#simple-container"""
def __init__(self, session, request): def __init__(self, session, request, after_id=None):
super().__init__() super().__init__()
del self.content_related del self.content_related
self.msg_id = session.get_new_msg_id() self.msg_id = session.get_new_msg_id()
@ -13,16 +14,27 @@ class TLMessage(TLObject):
self.request = request self.request = request
self.container_msg_id = None self.container_msg_id = None
# After which message ID this one should run. We do this so
# InvokeAfterMsgRequest is transparent to the user and we can
# easily invoke after while confirming the original request.
self.after_id = after_id
def to_dict(self, recursive=True): def to_dict(self, recursive=True):
return { return {
'msg_id': self.msg_id, 'msg_id': self.msg_id,
'seq_no': self.seq_no, 'seq_no': self.seq_no,
'request': self.request, 'request': self.request,
'container_msg_id': self.container_msg_id, 'container_msg_id': self.container_msg_id,
'after_id': self.after_id
} }
def __bytes__(self): def __bytes__(self):
body = GzipPacked.gzip_if_smaller(self.request) if self.after_id is None:
body = GzipPacked.gzip_if_smaller(self.request)
else:
body = GzipPacked.gzip_if_smaller(
InvokeAfterMsgRequest(self.after_id, self.request))
return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body
def __str__(self): def __str__(self):