Allow sending several requests at once through new MessageContainer

This commit is contained in:
Lonami Exo 2017-09-25 20:52:27 +02:00
parent b40708a8c7
commit f233110732
5 changed files with 86 additions and 34 deletions

View File

@ -9,6 +9,7 @@ from ..errors import (
rpc_message_to_error rpc_message_to_error
) )
from ..extensions import BinaryReader, BinaryWriter from ..extensions import BinaryReader, BinaryWriter
from ..tl import MessageContainer
from ..tl.all_tlobjects import tlobjects from ..tl.all_tlobjects import tlobjects
from ..tl.types import MsgsAck from ..tl.types import MsgsAck
@ -56,14 +57,20 @@ class MtProtoSender:
# region Send and receive # region Send and receive
def send(self, request): def send(self, *requests):
"""Sends the specified MTProtoRequest, previously sending any message """Sends the specified MTProtoRequest, previously sending any message
which needed confirmation.""" which needed confirmation."""
# If any message needs confirmation send an AckRequest first # If any message needs confirmation send an AckRequest first
self._send_acknowledges() self._send_acknowledges()
# Finally send our packed request # Finally send our packed request(s)
self._pending_receive.extend(requests)
if len(requests) == 1:
request = requests[0]
else:
request = MessageContainer(self.session, requests)
with BinaryWriter() as writer: with BinaryWriter() as writer:
request.on_send(writer) request.on_send(writer)
self._send_packet(writer.get_bytes(), request) self._send_packet(writer.get_bytes(), request)
@ -268,22 +275,17 @@ class MtProtoSender:
def _handle_container(self, msg_id, sequence, reader, state): def _handle_container(self, msg_id, sequence, reader, state):
self._logger.debug('Handling container') self._logger.debug('Handling container')
reader.read_int(signed=False) # code for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
size = reader.read_int()
for _ in range(size):
inner_msg_id = reader.read_long()
reader.read_int() # inner_sequence
inner_length = reader.read_int()
begin_position = reader.tell_position() begin_position = reader.tell_position()
# Note that this code is IMPORTANT for skipping RPC results of # Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session) # lost requests (i.e., ones from the previous connection session)
try: try:
if not self._process_msg(inner_msg_id, sequence, reader, state): if not self._process_msg(inner_msg_id, sequence, reader, state):
reader.set_position(begin_position + inner_length) reader.set_position(begin_position + inner_len)
except: except:
# If any error is raised, something went wrong; skip the packet # If any error is raised, something went wrong; skip the packet
reader.set_position(begin_position + inner_length) reader.set_position(begin_position + inner_len)
raise raise
return True return True

View File

@ -290,7 +290,7 @@ class TelegramBareClient:
# region Invoking Telegram requests # region Invoking Telegram requests
def invoke(self, request, call_receive=True, retries=5): def invoke(self, *requests, call_receive=True, retries=5):
"""Invokes (sends) a MTProtoRequest and returns (receives) its result. """Invokes (sends) a MTProtoRequest and returns (receives) its result.
If 'updates' is not None, all read update object will be put If 'updates' is not None, all read update object will be put
@ -300,7 +300,8 @@ class TelegramBareClient:
thread calling to 'self._sender.receive()' running or this method thread calling to 'self._sender.receive()' running or this method
will lock forever. will lock forever.
""" """
if not isinstance(request, TLObject) and not request.content_related: if not all(isinstance(x, TLObject) and
x.content_related for x in requests):
raise ValueError('You can only invoke requests, not types!') raise ValueError('You can only invoke requests, not types!')
if retries <= 0: if retries <= 0:
@ -308,20 +309,22 @@ class TelegramBareClient:
try: try:
# Ensure that we start with no previous errors (i.e. resending) # Ensure that we start with no previous errors (i.e. resending)
request.confirm_received.clear() for x in requests:
request.rpc_error = None x.confirm_received.clear()
x.rpc_error = None
self._sender.send(request) self._sender.send(*requests)
if not call_receive: if not call_receive:
# TODO This will be slightly troublesome if we allow # TODO This will be slightly troublesome if we allow
# switching between constant read or not on the fly. # switching between constant read or not on the fly.
# Must also watch out for calling .read() from two places, # Must also watch out for calling .read() from two places,
# in which case a Lock would be required for .receive(). # in which case a Lock would be required for .receive().
request.confirm_received.wait( for x in requests:
x.confirm_received.wait(
self._sender.connection.get_timeout() self._sender.connection.get_timeout()
) )
else: else:
while not request.confirm_received.is_set(): while not all(x.confirm_received.is_set() for x in requests):
self._sender.receive(update_state=self.updates) self._sender.receive(update_state=self.updates)
except TimeoutError: except TimeoutError:
@ -336,14 +339,19 @@ class TelegramBareClient:
self.disconnect() self.disconnect()
raise raise
if request.rpc_error: try:
raise request.rpc_error raise next(x.rpc_error for x in requests if x.rpc_error)
if request.result is None: except StopIteration:
if any(x.result is None for x in requests):
# "A container may only be accepted or
# rejected by the other party as a whole."
return self.invoke( return self.invoke(
request, call_receive=call_receive, retries=(retries - 1) *requests, call_receive=call_receive, retries=(retries - 1)
) )
elif len(requests) == 1:
return requests[0].result
else: else:
return request.result return [x.result for x in requests]
# Let people use client(SomeRequest()) instead client.invoke(...) # Let people use client(SomeRequest()) instead client.invoke(...)
__call__ = invoke __call__ = invoke

View File

@ -239,11 +239,10 @@ class TelegramClient(TelegramBareClient):
# region Telegram requests functions # region Telegram requests functions
def invoke(self, request, *args, **kwargs): def invoke(self, *requests, **kwargs):
"""Invokes (sends) a MTProtoRequest and returns (receives) its result. """Invokes (sends) one or several MTProtoRequest and returns
An optional 'retries' parameter can be set. (receives) their result. An optional named 'retries' parameter
can be used, indicating how many times it should retry.
*args will be ignored.
""" """
# This is only valid when the read thread is reconnecting, # This is only valid when the read thread is reconnecting,
# that is, the connection lock is locked. # that is, the connection lock is locked.
@ -261,7 +260,8 @@ class TelegramClient(TelegramBareClient):
self._recv_thread is None or self._connect_lock.locked() self._recv_thread is None or self._connect_lock.locked()
return super().invoke( return super().invoke(
request, call_receive=call_receive, *requests,
call_receive=call_receive,
retries=kwargs.get('retries', 5) retries=kwargs.get('retries', 5)
) )
@ -275,7 +275,7 @@ class TelegramClient(TelegramBareClient):
# be on the very first connection (not authorized, not running), # be on the very first connection (not authorized, not running),
# but may be an issue for people who actually travel? # but may be an issue for people who actually travel?
self._reconnect(new_dc=e.new_dc) self._reconnect(new_dc=e.new_dc)
return self.invoke(request) return self.invoke(*requests)
except ConnectionResetError as e: except ConnectionResetError as e:
if self._connect_lock.locked(): if self._connect_lock.locked():

View File

@ -1,2 +1,3 @@
from .tlobject import TLObject from .tlobject import TLObject
from .session import Session from .session import Session
from .message_container import MessageContainer

View File

@ -0,0 +1,41 @@
from . import TLObject
from ..extensions import BinaryWriter
class MessageContainer(TLObject):
constructor_id = 0x8953ad37
# TODO Currently it's a bit of a hack, since the container actually holds
# messages (message id, sequence number, request body), not requests.
# Probably create a proper "Message" class
def __init__(self, session, requests):
super().__init__()
self.content_related = False
self.session = session
self.requests = requests
def on_send(self, writer):
writer.write_int(0x73f1f8dc, signed=False)
writer.write_int(len(self.requests))
for x in self.requests:
with BinaryWriter() as aux:
x.on_send(aux)
x.request_msg_id = self.session.get_new_msg_id()
writer.write_long(x.request_msg_id)
writer.write_int(
self.session.generate_sequence(x.content_related)
)
packet = aux.get_bytes()
writer.write_int(len(packet))
writer.write(packet)
@staticmethod
def iter_read(reader):
reader.read_int(signed=False) # code
size = reader.read_int()
for _ in range(size):
inner_msg_id = reader.read_long()
inner_sequence = reader.read_int()
inner_length = reader.read_int()
yield inner_msg_id, inner_sequence, inner_length