Properly set future results

This commit is contained in:
Lonami Exo 2018-06-07 10:30:20 +02:00
parent 9477c75fce
commit 56b09c0c9d
4 changed files with 63 additions and 43 deletions

View File

@ -79,11 +79,34 @@ class MTProtoSender:
self._recv_loop_handle.cancel() self._recv_loop_handle.cancel()
async def send(self, request): async def send(self, request):
# TODO Should the asyncio.Future creation belong here? """
request.result = asyncio.Future() This method enqueues the given request to be sent.
The request will be wrapped inside a `TLMessage` until its
response arrives, and the `Future` response of the `TLMessage`
is immediately returned so that one can further ``await`` it:
.. code-block:: python
async def method():
# Sending (enqueued for the send loop)
future = await sender.send(request)
# Receiving (waits for the receive loop to read the result)
result = await future
Designed like this because Telegram may send the response at
any point, and it can send other items while one waits for it.
Once the response for this future arrives, it is set with the
received result, quite similar to how a ``receive()`` call
would otherwise work.
Since the receiving part is "built in" the future, it's
impossible to await receive a result that was never sent.
"""
message = TLMessage(self.session, request) message = TLMessage(self.session, request)
self._pending_messages[message.msg_id] = message self._pending_messages[message.msg_id] = message
await self._send_queue.put(message) await self._send_queue.put(message)
return message.future
# Loops # Loops
@ -129,7 +152,7 @@ class MTProtoSender:
inner_code = reader.read_int(signed=False) inner_code = reader.read_int(signed=False)
reader.seek(-4) reader.seek(-4)
message = self._pending_messages.pop(message_id) message = self._pending_messages.pop(message_id, None)
if inner_code == 0x2144ca19: # RPC Error if inner_code == 0x2144ca19: # RPC Error
reader.seek(4) reader.seek(4)
if self.session.report_errors and message: if self.session.report_errors and message:
@ -142,17 +165,23 @@ class MTProtoSender:
reader.read_int(), reader.tgread_string() reader.read_int(), reader.tgread_string()
) )
# TODO Acknowledge that we received the error request_id await self._send_queue.put(
# TODO Set message.request exception TLMessage(self.session, MsgsAck([msg_id])))
if not message.future.cancelled():
message.future.set_exception(error)
return
elif message: elif message:
# TODO Make on_response result.set_result() instead replacing it
if inner_code == GzipPacked.CONSTRUCTOR_ID: if inner_code == GzipPacked.CONSTRUCTOR_ID:
with BinaryReader(GzipPacked.read(reader)) as compressed_reader: with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
message.on_response(compressed_reader) result = message.request.read_result(compressed_reader)
else: else:
message.on_response(reader) result = message.request.read_result(reader)
# TODO Process possible entities # TODO Process possible entities
if not message.future.cancelled():
message.future.set_result(result)
return
# TODO Try reading an object # TODO Try reading an object

View File

@ -1,3 +1,4 @@
import asyncio
import struct import struct
from . import TLObject, GzipPacked from . import TLObject, GzipPacked
@ -5,7 +6,20 @@ from ..tl.functions import InvokeAfterMsgRequest
class TLMessage(TLObject): class TLMessage(TLObject):
"""https://core.telegram.org/mtproto/service_messages#simple-container""" """
https://core.telegram.org/mtproto/service_messages#simple-container.
Messages are what's ultimately sent to Telegram:
message msg_id:long seqno:int bytes:int body:bytes = Message;
Each message has its own unique identifier, and the body is simply
the serialized request that should be executed on the server. Then
Telegram will, at some point, respond with the result for this msg.
Thus it makes sense that requests and their result are bound to a
sent `TLMessage`, and this result can be represented as a `Future`
that will eventually be set with either a result, error or cancelled.
"""
def __init__(self, session, request, after_id=None): def __init__(self, session, request, after_id=None):
super().__init__() super().__init__()
del self.content_related del self.content_related
@ -13,6 +27,7 @@ class TLMessage(TLObject):
self.seq_no = session.generate_sequence(request.content_related) self.seq_no = session.generate_sequence(request.content_related)
self.request = request self.request = request
self.container_msg_id = None self.container_msg_id = None
self.future = asyncio.Future()
# After which message ID this one should run. We do this so # After which message ID this one should run. We do this so
# InvokeAfterMsgRequest is transparent to the user and we can # InvokeAfterMsgRequest is transparent to the user and we can

View File

@ -5,36 +5,10 @@ from threading import Event
class TLObject: class TLObject:
def __init__(self): def __init__(self):
self.rpc_error = None # TODO Perhaps content_related makes more sense as another type?
self.result = None # An asyncio.Future set later # Something like class TLRequest(TLObject), request inherit this
# These should be overrode
self.content_related = False # Only requests/functions/queries are self.content_related = False # Only requests/functions/queries are
# Internal parameter to tell pickler in which state Event object was
self._event_is_set = False
self._set_event()
def _set_event(self):
self.confirm_received = Event()
# Set Event state to 'set' if needed
if self._event_is_set:
self.confirm_received.set()
def __getstate__(self):
# Save state of the Event object
self._event_is_set = self.confirm_received.is_set()
# Exclude Event object from dict and return new state
new_dct = dict(self.__dict__)
del new_dct["confirm_received"]
return new_dct
def __setstate__(self, state):
self.__dict__ = state
self._set_event()
# These should not be overrode # These should not be overrode
@staticmethod @staticmethod
def pretty_format(obj, indent=None): def pretty_format(obj, indent=None):
@ -164,8 +138,9 @@ class TLObject:
raise TypeError('Cannot interpret "{}" as a date.'.format(dt)) raise TypeError('Cannot interpret "{}" as a date.'.format(dt))
# These are nearly always the same for all subclasses # These are nearly always the same for all subclasses
def on_response(self, reader): @staticmethod
self.result = reader.tgread_object() def read_result(reader):
return reader.tgread_object()
def __eq__(self, o): def __eq__(self, o):
return isinstance(o, type(self)) and self.to_dict() == o.to_dict() return isinstance(o, type(self)) and self.to_dict() == o.to_dict()

View File

@ -142,7 +142,7 @@ def _write_source_code(tlobject, builder, type_constructors):
_write_to_dict(tlobject, builder) _write_to_dict(tlobject, builder)
_write_to_bytes(tlobject, builder) _write_to_bytes(tlobject, builder)
_write_from_reader(tlobject, builder) _write_from_reader(tlobject, builder)
_write_on_response(tlobject, builder) _write_read_result(tlobject, builder)
def _write_class_init(tlobject, type_constructors, builder): def _write_class_init(tlobject, type_constructors, builder):
@ -333,7 +333,7 @@ def _write_from_reader(tlobject, builder):
'{0}=_{0}'.format(a.name) for a in tlobject.real_args)) '{0}=_{0}'.format(a.name) for a in tlobject.real_args))
def _write_on_response(tlobject, builder): def _write_read_result(tlobject, builder):
# Only requests can have a different response that's not their # Only requests can have a different response that's not their
# serialized body, that is, we'll be setting their .result. # serialized body, that is, we'll be setting their .result.
# #
@ -354,9 +354,10 @@ def _write_on_response(tlobject, builder):
return return
builder.end_block() builder.end_block()
builder.writeln('def on_response(self, reader):') builder.writeln('@staticmethod')
builder.writeln('def read_result(reader):')
builder.writeln('reader.read_int() # Vector ID') builder.writeln('reader.read_int() # Vector ID')
builder.writeln('self.result = [reader.read_{}() ' builder.writeln('return [reader.read_{}() '
'for _ in range(reader.read_int())]', m.group(1)) 'for _ in range(reader.read_int())]', m.group(1))