Create RpcResult class and generalise core special cases

This results in a cleaner MTProtoSender, which now can always
read a TLObject with a guaranteed item, if the message is OK.
This commit is contained in:
Lonami Exo 2018-06-09 13:11:49 +02:00
parent 1e66cea9b7
commit f7e8907c6f
11 changed files with 132 additions and 84 deletions

View File

@ -40,49 +40,49 @@ def report_error(code, message, report_method):
"We really don't want to crash when just reporting an error"
def rpc_message_to_error(code, message, report_method=None):
def rpc_message_to_error(rpc_error, report_method=None):
"""
Converts a Telegram's RPC Error to a Python error.
:param code: the integer code of the error (like 400).
:param message: the message representing the error.
:param rpc_error: the RpcError instance.
:param report_method: if present, the ID of the method that caused it.
:return: the RPCError as a Python exception that represents this error.
"""
if report_method is not None:
Thread(
target=report_error,
args=(code, message, report_method)
args=(rpc_error.error_code, rpc_error.error_message, report_method)
).start()
# Try to get the error by direct look-up, otherwise regex
# TODO Maybe regexes could live in a separate dictionary?
cls = rpc_errors_all.get(message, None)
cls = rpc_errors_all.get(rpc_error.error_message, None)
if cls:
return cls()
for msg_regex, cls in rpc_errors_all.items():
m = re.match(msg_regex, message)
m = re.match(msg_regex, rpc_error.error_message)
if m:
capture = int(m.group(1)) if m.groups() else None
return cls(capture=capture)
if code == 400:
return BadRequestError(message)
if rpc_error.error_code == 400:
return BadRequestError(rpc_error.error_message)
if code == 401:
return UnauthorizedError(message)
if rpc_error.error_code == 401:
return UnauthorizedError(rpc_error.error_message)
if code == 403:
return ForbiddenError(message)
if rpc_error.error_code == 403:
return ForbiddenError(rpc_error.error_message)
if code == 404:
return NotFoundError(message)
if rpc_error.error_code == 404:
return NotFoundError(rpc_error.error_message)
if code == 406:
return AuthKeyError(message)
if rpc_error.error_code == 406:
return AuthKeyError(rpc_error.error_message)
if code == 500:
return ServerError(message)
if rpc_error.error_code == 500:
return ServerError(rpc_error.error_message)
return RPCError('{} (code {})'.format(message, code))
return RPCError('{} (code {})'.format(
rpc_error.error_message, rpc_error.error_code))

View File

@ -8,6 +8,7 @@ from struct import unpack
from ..errors import TypeNotFoundError
from ..tl.all_tlobjects import tlobjects
from ..tl.core import core_objects
class BinaryReader:
@ -136,9 +137,11 @@ class BinaryReader:
elif value == 0x1cb5c415: # Vector
return [self.tgread_object() for _ in range(self.read_int())]
# If there was still no luck, give up
self.seek(-4) # Go back
raise TypeNotFoundError(constructor_id)
clazz = core_objects.get(constructor_id, None)
if clazz is None:
# If there was still no luck, give up
self.seek(-4) # Go back
raise TypeNotFoundError(constructor_id)
return clazz.from_reader(self)

View File

@ -9,7 +9,7 @@ from ..errors import (
rpc_message_to_error
)
from ..extensions import BinaryReader
from ..tl import MessageContainer, GzipPacked
from ..tl.core import RpcResult, MessageContainer, GzipPacked
from ..tl.functions.auth import LogOutRequest
from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
@ -80,7 +80,7 @@ class MTProtoSender:
# Jump table from response ID to method that handles it
self._handlers = {
0xf35c6d01: self._handle_rpc_result,
RpcResult.CONSTRUCTOR_ID: self._handle_rpc_result,
MessageContainer.CONSTRUCTOR_ID: self._handle_container,
GzipPacked.CONSTRUCTOR_ID: self._handle_gzip_packed,
Pong.CONSTRUCTOR_ID: self._handle_pong,
@ -354,26 +354,26 @@ class MTProtoSender:
else:
try:
with BinaryReader(message.body) as reader:
await self._process_message(message, reader)
obj = reader.tgread_object()
except TypeNotFoundError as e:
__log__.warning('Could not decode received message: {}, '
'raw bytes: {!r}'.format(e, message))
else:
await self._process_message(message, obj)
# Response Handlers
async def _process_message(self, message, reader):
async def _process_message(self, message, obj):
"""
Adds the given message to the list of messages that must be
acknowledged and dispatches control to different ``_handle_*``
method based on its type.
"""
self._pending_ack.add(message.msg_id)
code = reader.read_int(signed=False)
reader.seek(-4)
handler = self._handlers.get(code, self._handle_update)
await handler(message, reader)
handler = self._handlers.get(obj.CONSTRUCTOR_ID, self._handle_update)
await handler(message, obj)
async def _handle_rpc_result(self, message, reader):
async def _handle_rpc_result(self, message, rpc_result):
"""
Handles the result for Remote Procedure Calls:
@ -381,20 +381,13 @@ class MTProtoSender:
This is where the future results for sent requests are set.
"""
# TODO Don't make this a special cased object
reader.read_int(signed=False) # code
message_id = reader.read_long()
inner_code = reader.read_int(signed=False)
reader.seek(-4)
message = self._pending_messages.pop(rpc_result.req_msg_id, None)
__log__.debug('Handling RPC result for message {}'
.format(rpc_result.req_msg_id))
__log__.debug('Handling RPC result for message {}'.format(message_id))
message = self._pending_messages.pop(message_id, None)
if inner_code == 0x2144ca19: # RPC Error
if rpc_result.error:
# TODO Report errors if possible/enabled
reader.seek(4)
error = rpc_message_to_error(reader.read_int(),
reader.tgread_string())
error = rpc_message_to_error(rpc_result.error)
await self._send_queue.put(self.state.create_message(
MsgsAck([message.msg_id])
))
@ -403,10 +396,7 @@ class MTProtoSender:
message.future.set_exception(error)
return
elif message:
if inner_code == GzipPacked.CONSTRUCTOR_ID:
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
result = message.request.read_result(compressed_reader)
else:
with BinaryReader(rpc_result.body) as reader:
result = message.request.read_result(reader)
# TODO Process entities
@ -416,37 +406,37 @@ class MTProtoSender:
else:
# TODO We should not get responses to things we never sent
__log__.info('Received response without parent request: {}'
.format(reader.tgread_object()))
.format(rpc_result.body))
async def _handle_container(self, message, reader):
async def _handle_container(self, message, container):
"""
Processes the inner messages of a container with many of them:
msg_container#73f1f8dc messages:vector<%Message> = MessageContainer;
"""
__log__.debug('Handling container')
for inner_message in MessageContainer.iter_read(reader):
with BinaryReader(inner_message.body) as inner_reader:
await self._process_message(inner_message, inner_reader)
for inner_message in container.messages:
with BinaryReader(inner_message.body) as reader:
inner_obj = reader.tgread_object()
await self._process_message(inner_message, inner_obj)
async def _handle_gzip_packed(self, message, reader):
async def _handle_gzip_packed(self, message, gzip_packed):
"""
Unpacks the data from a gzipped object and processes it:
gzip_packed#3072cfa1 packed_data:bytes = Object;
"""
__log__.debug('Handling gzipped data')
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
await self._process_message(message, compressed_reader)
with BinaryReader(gzip_packed.data) as reader:
await self._process_message(message, reader.tgread_object())
async def _handle_update(self, message, reader):
obj = reader.tgread_object()
__log__.debug('Handling update {}'.format(obj.__class__.__name__))
async def _handle_update(self, message, update):
__log__.debug('Handling update {}'.format(update.__class__.__name__))
# TODO Further handling of the update
# TODO Process entities
async def _handle_pong(self, message, reader):
async def _handle_pong(self, message, pong):
"""
Handles pong results, which don't come inside a ``rpc_result``
but are still sent through a request:
@ -454,12 +444,11 @@ class MTProtoSender:
pong#347773c5 msg_id:long ping_id:long = Pong;
"""
__log__.debug('Handling pong')
pong = reader.tgread_object()
message = self._pending_messages.pop(pong.msg_id, None)
if message:
message.future.set_result(pong)
async def _handle_bad_server_salt(self, message, reader):
async def _handle_bad_server_salt(self, message, bad_salt):
"""
Corrects the currently used server salt to use the right value
before enqueuing the rejected message to be re-sent:
@ -468,11 +457,10 @@ class MTProtoSender:
error_code:int new_server_salt:long = BadMsgNotification;
"""
__log__.debug('Handling bad salt')
bad_salt = reader.tgread_object()
self.state.salt = bad_salt.new_server_salt
await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id])
async def _handle_bad_notification(self, message, reader):
async def _handle_bad_notification(self, message, bad_msg):
"""
Adjusts the current state to be correct based on the
received bad message notification whenever possible:
@ -481,7 +469,6 @@ class MTProtoSender:
error_code:int = BadMsgNotification;
"""
__log__.debug('Handling bad message')
bad_msg = reader.tgread_object()
if bad_msg.error_code in (16, 17):
# Sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset.
@ -502,7 +489,7 @@ class MTProtoSender:
# Messages are to be re-sent once we've corrected the issue
await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id])
async def _handle_detailed_info(self, message, reader):
async def _handle_detailed_info(self, message, detailed_info):
"""
Updates the current status with the received detailed information:
@ -511,9 +498,9 @@ class MTProtoSender:
"""
# TODO https://goo.gl/VvpCC6
__log__.debug('Handling detailed info')
self._pending_ack.add(reader.tgread_object().answer_msg_id)
self._pending_ack.add(detailed_info.answer_msg_id)
async def _handle_new_detailed_info(self, message, reader):
async def _handle_new_detailed_info(self, message, new_detailed_info):
"""
Updates the current status with the received detailed information:
@ -522,9 +509,9 @@ class MTProtoSender:
"""
# TODO https://goo.gl/G7DPsR
__log__.debug('Handling new detailed info')
self._pending_ack.add(reader.tgread_object().answer_msg_id)
self._pending_ack.add(new_detailed_info.answer_msg_id)
async def _handle_new_session_created(self, message, reader):
async def _handle_new_session_created(self, message, new_session):
"""
Updates the current status with the received session information:
@ -533,7 +520,7 @@ class MTProtoSender:
"""
# TODO https://goo.gl/LMyN7A
__log__.debug('Handling new session created')
self.state.salt = reader.tgread_object().server_salt
self.state.salt = new_session.server_salt
def _clean_containers(self, msg_ids):
"""
@ -552,7 +539,7 @@ class MTProtoSender:
del self._pending_messages[message.msg_id]
break
async def _handle_ack(self, message, reader):
async def _handle_ack(self, message, ack):
"""
Handles a server acknowledge about our messages. Normally
these can be ignored except in the case of ``auth.logOut``:
@ -568,7 +555,6 @@ class MTProtoSender:
messages are acknowledged.
"""
__log__.debug('Handling acknowledge')
ack = reader.tgread_object()
if self._pending_containers:
self._clean_containers(ack.msg_ids)
@ -578,7 +564,7 @@ class MTProtoSender:
del self._pending_messages[msg_id]
msg.future.set_result(True)
async def _handle_future_salts(self, message, reader):
async def _handle_future_salts(self, message, salts):
"""
Handles future salt results, which don't come inside a
``rpc_result`` but are still sent through a request:
@ -589,7 +575,6 @@ class MTProtoSender:
# TODO save these salts and automatically adjust to the
# correct one whenever the salt in use expires.
__log__.debug('Handling future salts')
salts = reader.tgread_object()
msg = self._pending_messages.pop(message.msg_id, None)
if msg:
msg.future.set_result(salts)

View File

@ -6,7 +6,7 @@ from hashlib import sha256
from ..crypto import AES
from ..errors import SecurityError, BrokenAuthKeyError
from ..extensions import BinaryReader
from ..tl import TLMessage
from ..tl.core import TLMessage
class MTProtoState:

View File

@ -1,4 +1 @@
from .tlobject import TLObject
from .gzip_packed import GzipPacked
from .tl_message import TLMessage
from .message_container import MessageContainer

View File

@ -0,0 +1,26 @@
"""
This module holds core "special" types, which are more convenient ways
to do stuff in a `telethon.network.mtprotosender.MTProtoSender` instance.
Only special cases are gzip-packed data, the response message (not a
client message), the message container which references these messages
and would otherwise conflict with the rest, and finally the RpcResult:
rpc_result#f35c6d01 req_msg_id:long result:bytes = RpcResult;
Three things to note with this definition:
1. The constructor ID is actually ``42d36c2c``.
2. Those bytes are not read like the rest of bytes (length + payload).
They are actually the raw bytes of another object, which can't be
read directly because it depends on per-request information (since
some can return ``Vector<int>`` and ``Vector<long>``).
3. Those bytes may be gzipped data, which needs to be treated early.
"""
from .tlmessage import TLMessage
from .gzippacked import GzipPacked
from .messagecontainer import MessageContainer
from .rpcresult import RpcResult
core_objects = {x.CONSTRUCTOR_ID: x for x in (
GzipPacked, MessageContainer, RpcResult
)}

View File

@ -1,7 +1,7 @@
import gzip
import struct
from . import TLObject
from .. import TLObject
class GzipPacked(TLObject):
@ -36,3 +36,7 @@ class GzipPacked(TLObject):
def read(reader):
assert reader.read_int(signed=False) == GzipPacked.CONSTRUCTOR_ID
return gzip.decompress(reader.tgread_bytes())
@classmethod
def from_reader(cls, reader):
return GzipPacked(gzip.decompress(reader.tgread_bytes()))

View File

@ -1,7 +1,7 @@
import struct
from . import TLObject
from .tl_message import TLMessage
from ..tlobject import TLObject
from .tlmessage import TLMessage
class MessageContainer(TLObject):
@ -42,3 +42,12 @@ class MessageContainer(TLObject):
def stringify(self):
return TLObject.pretty_format(self, indent=0)
@classmethod
def from_reader(cls, reader):
# This assumes that .read_* calls are done in the order they appear
return MessageContainer([TLMessage(
msg_id=reader.read_long(),
seq_no=reader.read_int(),
body=reader.read(reader.read_int())
) for _ in range(reader.read_int())])

View File

@ -0,0 +1,23 @@
from .gzippacked import GzipPacked
from ..types import RpcError
class RpcResult:
CONSTRUCTOR_ID = 0xf35c6d01
def __init__(self, req_msg_id, body, error):
self.req_msg_id = req_msg_id
self.body = body
self.error = error
@classmethod
def from_reader(cls, reader):
msg_id = reader.read_long()
inner_code = reader.read_int(signed=False)
if inner_code == RpcError.CONSTRUCTOR_ID:
return RpcResult(msg_id, None, RpcError.from_reader(reader))
if inner_code == GzipPacked.CONSTRUCTOR_ID:
return RpcResult(msg_id, GzipPacked.from_reader(reader).data, None)
reader.seek(-4)
return RpcResult(msg_id, reader.read(), None)

View File

@ -1,8 +1,9 @@
import asyncio
import struct
from . import TLObject, GzipPacked
from ..tl.functions import InvokeAfterMsgRequest
from .gzippacked import GzipPacked
from .. import TLObject
from ..functions import InvokeAfterMsgRequest
class TLMessage(TLObject):

View File

@ -49,7 +49,7 @@ new_session_created#9ec20908 first_msg_id:long unique_id:long server_salt:long =
//message msg_id:long seqno:int bytes:int body:bytes = Message;
//msg_copy#e06046b2 orig_message:Message = MessageCopy;
gzip_packed#3072cfa1 packed_data:bytes = Object;
//gzip_packed#3072cfa1 packed_data:bytes = Object;
msgs_ack#62d6b459 msg_ids:Vector<long> = MsgsAck;