mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-16 19:41:07 +03:00
Create RpcResult class and generalise core special cases
This results in a cleaner MTProtoSender, which now can always read a TLObject with a guaranteed item, if the message is OK.
This commit is contained in:
parent
1e66cea9b7
commit
f7e8907c6f
|
@ -40,49 +40,49 @@ def report_error(code, message, report_method):
|
|||
"We really don't want to crash when just reporting an error"
|
||||
|
||||
|
||||
def rpc_message_to_error(code, message, report_method=None):
|
||||
def rpc_message_to_error(rpc_error, report_method=None):
|
||||
"""
|
||||
Converts a Telegram's RPC Error to a Python error.
|
||||
|
||||
:param code: the integer code of the error (like 400).
|
||||
:param message: the message representing the error.
|
||||
:param rpc_error: the RpcError instance.
|
||||
:param report_method: if present, the ID of the method that caused it.
|
||||
:return: the RPCError as a Python exception that represents this error.
|
||||
"""
|
||||
if report_method is not None:
|
||||
Thread(
|
||||
target=report_error,
|
||||
args=(code, message, report_method)
|
||||
args=(rpc_error.error_code, rpc_error.error_message, report_method)
|
||||
).start()
|
||||
|
||||
# Try to get the error by direct look-up, otherwise regex
|
||||
# TODO Maybe regexes could live in a separate dictionary?
|
||||
cls = rpc_errors_all.get(message, None)
|
||||
cls = rpc_errors_all.get(rpc_error.error_message, None)
|
||||
if cls:
|
||||
return cls()
|
||||
|
||||
for msg_regex, cls in rpc_errors_all.items():
|
||||
m = re.match(msg_regex, message)
|
||||
m = re.match(msg_regex, rpc_error.error_message)
|
||||
if m:
|
||||
capture = int(m.group(1)) if m.groups() else None
|
||||
return cls(capture=capture)
|
||||
|
||||
if code == 400:
|
||||
return BadRequestError(message)
|
||||
if rpc_error.error_code == 400:
|
||||
return BadRequestError(rpc_error.error_message)
|
||||
|
||||
if code == 401:
|
||||
return UnauthorizedError(message)
|
||||
if rpc_error.error_code == 401:
|
||||
return UnauthorizedError(rpc_error.error_message)
|
||||
|
||||
if code == 403:
|
||||
return ForbiddenError(message)
|
||||
if rpc_error.error_code == 403:
|
||||
return ForbiddenError(rpc_error.error_message)
|
||||
|
||||
if code == 404:
|
||||
return NotFoundError(message)
|
||||
if rpc_error.error_code == 404:
|
||||
return NotFoundError(rpc_error.error_message)
|
||||
|
||||
if code == 406:
|
||||
return AuthKeyError(message)
|
||||
if rpc_error.error_code == 406:
|
||||
return AuthKeyError(rpc_error.error_message)
|
||||
|
||||
if code == 500:
|
||||
return ServerError(message)
|
||||
if rpc_error.error_code == 500:
|
||||
return ServerError(rpc_error.error_message)
|
||||
|
||||
return RPCError('{} (code {})'.format(message, code))
|
||||
return RPCError('{} (code {})'.format(
|
||||
rpc_error.error_message, rpc_error.error_code))
|
||||
|
|
|
@ -8,6 +8,7 @@ from struct import unpack
|
|||
|
||||
from ..errors import TypeNotFoundError
|
||||
from ..tl.all_tlobjects import tlobjects
|
||||
from ..tl.core import core_objects
|
||||
|
||||
|
||||
class BinaryReader:
|
||||
|
@ -136,9 +137,11 @@ class BinaryReader:
|
|||
elif value == 0x1cb5c415: # Vector
|
||||
return [self.tgread_object() for _ in range(self.read_int())]
|
||||
|
||||
# If there was still no luck, give up
|
||||
self.seek(-4) # Go back
|
||||
raise TypeNotFoundError(constructor_id)
|
||||
clazz = core_objects.get(constructor_id, None)
|
||||
if clazz is None:
|
||||
# If there was still no luck, give up
|
||||
self.seek(-4) # Go back
|
||||
raise TypeNotFoundError(constructor_id)
|
||||
|
||||
return clazz.from_reader(self)
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from ..errors import (
|
|||
rpc_message_to_error
|
||||
)
|
||||
from ..extensions import BinaryReader
|
||||
from ..tl import MessageContainer, GzipPacked
|
||||
from ..tl.core import RpcResult, MessageContainer, GzipPacked
|
||||
from ..tl.functions.auth import LogOutRequest
|
||||
from ..tl.types import (
|
||||
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
|
||||
|
@ -80,7 +80,7 @@ class MTProtoSender:
|
|||
|
||||
# Jump table from response ID to method that handles it
|
||||
self._handlers = {
|
||||
0xf35c6d01: self._handle_rpc_result,
|
||||
RpcResult.CONSTRUCTOR_ID: self._handle_rpc_result,
|
||||
MessageContainer.CONSTRUCTOR_ID: self._handle_container,
|
||||
GzipPacked.CONSTRUCTOR_ID: self._handle_gzip_packed,
|
||||
Pong.CONSTRUCTOR_ID: self._handle_pong,
|
||||
|
@ -354,26 +354,26 @@ class MTProtoSender:
|
|||
else:
|
||||
try:
|
||||
with BinaryReader(message.body) as reader:
|
||||
await self._process_message(message, reader)
|
||||
obj = reader.tgread_object()
|
||||
except TypeNotFoundError as e:
|
||||
__log__.warning('Could not decode received message: {}, '
|
||||
'raw bytes: {!r}'.format(e, message))
|
||||
else:
|
||||
await self._process_message(message, obj)
|
||||
|
||||
# Response Handlers
|
||||
|
||||
async def _process_message(self, message, reader):
|
||||
async def _process_message(self, message, obj):
|
||||
"""
|
||||
Adds the given message to the list of messages that must be
|
||||
acknowledged and dispatches control to different ``_handle_*``
|
||||
method based on its type.
|
||||
"""
|
||||
self._pending_ack.add(message.msg_id)
|
||||
code = reader.read_int(signed=False)
|
||||
reader.seek(-4)
|
||||
handler = self._handlers.get(code, self._handle_update)
|
||||
await handler(message, reader)
|
||||
handler = self._handlers.get(obj.CONSTRUCTOR_ID, self._handle_update)
|
||||
await handler(message, obj)
|
||||
|
||||
async def _handle_rpc_result(self, message, reader):
|
||||
async def _handle_rpc_result(self, message, rpc_result):
|
||||
"""
|
||||
Handles the result for Remote Procedure Calls:
|
||||
|
||||
|
@ -381,20 +381,13 @@ class MTProtoSender:
|
|||
|
||||
This is where the future results for sent requests are set.
|
||||
"""
|
||||
# TODO Don't make this a special cased object
|
||||
reader.read_int(signed=False) # code
|
||||
message_id = reader.read_long()
|
||||
inner_code = reader.read_int(signed=False)
|
||||
reader.seek(-4)
|
||||
message = self._pending_messages.pop(rpc_result.req_msg_id, None)
|
||||
__log__.debug('Handling RPC result for message {}'
|
||||
.format(rpc_result.req_msg_id))
|
||||
|
||||
__log__.debug('Handling RPC result for message {}'.format(message_id))
|
||||
message = self._pending_messages.pop(message_id, None)
|
||||
if inner_code == 0x2144ca19: # RPC Error
|
||||
if rpc_result.error:
|
||||
# TODO Report errors if possible/enabled
|
||||
reader.seek(4)
|
||||
error = rpc_message_to_error(reader.read_int(),
|
||||
reader.tgread_string())
|
||||
|
||||
error = rpc_message_to_error(rpc_result.error)
|
||||
await self._send_queue.put(self.state.create_message(
|
||||
MsgsAck([message.msg_id])
|
||||
))
|
||||
|
@ -403,10 +396,7 @@ class MTProtoSender:
|
|||
message.future.set_exception(error)
|
||||
return
|
||||
elif message:
|
||||
if inner_code == GzipPacked.CONSTRUCTOR_ID:
|
||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||
result = message.request.read_result(compressed_reader)
|
||||
else:
|
||||
with BinaryReader(rpc_result.body) as reader:
|
||||
result = message.request.read_result(reader)
|
||||
|
||||
# TODO Process entities
|
||||
|
@ -416,37 +406,37 @@ class MTProtoSender:
|
|||
else:
|
||||
# TODO We should not get responses to things we never sent
|
||||
__log__.info('Received response without parent request: {}'
|
||||
.format(reader.tgread_object()))
|
||||
.format(rpc_result.body))
|
||||
|
||||
async def _handle_container(self, message, reader):
|
||||
async def _handle_container(self, message, container):
|
||||
"""
|
||||
Processes the inner messages of a container with many of them:
|
||||
|
||||
msg_container#73f1f8dc messages:vector<%Message> = MessageContainer;
|
||||
"""
|
||||
__log__.debug('Handling container')
|
||||
for inner_message in MessageContainer.iter_read(reader):
|
||||
with BinaryReader(inner_message.body) as inner_reader:
|
||||
await self._process_message(inner_message, inner_reader)
|
||||
for inner_message in container.messages:
|
||||
with BinaryReader(inner_message.body) as reader:
|
||||
inner_obj = reader.tgread_object()
|
||||
await self._process_message(inner_message, inner_obj)
|
||||
|
||||
async def _handle_gzip_packed(self, message, reader):
|
||||
async def _handle_gzip_packed(self, message, gzip_packed):
|
||||
"""
|
||||
Unpacks the data from a gzipped object and processes it:
|
||||
|
||||
gzip_packed#3072cfa1 packed_data:bytes = Object;
|
||||
"""
|
||||
__log__.debug('Handling gzipped data')
|
||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||
await self._process_message(message, compressed_reader)
|
||||
with BinaryReader(gzip_packed.data) as reader:
|
||||
await self._process_message(message, reader.tgread_object())
|
||||
|
||||
async def _handle_update(self, message, reader):
|
||||
obj = reader.tgread_object()
|
||||
__log__.debug('Handling update {}'.format(obj.__class__.__name__))
|
||||
async def _handle_update(self, message, update):
|
||||
__log__.debug('Handling update {}'.format(update.__class__.__name__))
|
||||
|
||||
# TODO Further handling of the update
|
||||
# TODO Process entities
|
||||
|
||||
async def _handle_pong(self, message, reader):
|
||||
async def _handle_pong(self, message, pong):
|
||||
"""
|
||||
Handles pong results, which don't come inside a ``rpc_result``
|
||||
but are still sent through a request:
|
||||
|
@ -454,12 +444,11 @@ class MTProtoSender:
|
|||
pong#347773c5 msg_id:long ping_id:long = Pong;
|
||||
"""
|
||||
__log__.debug('Handling pong')
|
||||
pong = reader.tgread_object()
|
||||
message = self._pending_messages.pop(pong.msg_id, None)
|
||||
if message:
|
||||
message.future.set_result(pong)
|
||||
|
||||
async def _handle_bad_server_salt(self, message, reader):
|
||||
async def _handle_bad_server_salt(self, message, bad_salt):
|
||||
"""
|
||||
Corrects the currently used server salt to use the right value
|
||||
before enqueuing the rejected message to be re-sent:
|
||||
|
@ -468,11 +457,10 @@ class MTProtoSender:
|
|||
error_code:int new_server_salt:long = BadMsgNotification;
|
||||
"""
|
||||
__log__.debug('Handling bad salt')
|
||||
bad_salt = reader.tgread_object()
|
||||
self.state.salt = bad_salt.new_server_salt
|
||||
await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id])
|
||||
|
||||
async def _handle_bad_notification(self, message, reader):
|
||||
async def _handle_bad_notification(self, message, bad_msg):
|
||||
"""
|
||||
Adjusts the current state to be correct based on the
|
||||
received bad message notification whenever possible:
|
||||
|
@ -481,7 +469,6 @@ class MTProtoSender:
|
|||
error_code:int = BadMsgNotification;
|
||||
"""
|
||||
__log__.debug('Handling bad message')
|
||||
bad_msg = reader.tgread_object()
|
||||
if bad_msg.error_code in (16, 17):
|
||||
# Sent msg_id too low or too high (respectively).
|
||||
# Use the current msg_id to determine the right time offset.
|
||||
|
@ -502,7 +489,7 @@ class MTProtoSender:
|
|||
# Messages are to be re-sent once we've corrected the issue
|
||||
await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id])
|
||||
|
||||
async def _handle_detailed_info(self, message, reader):
|
||||
async def _handle_detailed_info(self, message, detailed_info):
|
||||
"""
|
||||
Updates the current status with the received detailed information:
|
||||
|
||||
|
@ -511,9 +498,9 @@ class MTProtoSender:
|
|||
"""
|
||||
# TODO https://goo.gl/VvpCC6
|
||||
__log__.debug('Handling detailed info')
|
||||
self._pending_ack.add(reader.tgread_object().answer_msg_id)
|
||||
self._pending_ack.add(detailed_info.answer_msg_id)
|
||||
|
||||
async def _handle_new_detailed_info(self, message, reader):
|
||||
async def _handle_new_detailed_info(self, message, new_detailed_info):
|
||||
"""
|
||||
Updates the current status with the received detailed information:
|
||||
|
||||
|
@ -522,9 +509,9 @@ class MTProtoSender:
|
|||
"""
|
||||
# TODO https://goo.gl/G7DPsR
|
||||
__log__.debug('Handling new detailed info')
|
||||
self._pending_ack.add(reader.tgread_object().answer_msg_id)
|
||||
self._pending_ack.add(new_detailed_info.answer_msg_id)
|
||||
|
||||
async def _handle_new_session_created(self, message, reader):
|
||||
async def _handle_new_session_created(self, message, new_session):
|
||||
"""
|
||||
Updates the current status with the received session information:
|
||||
|
||||
|
@ -533,7 +520,7 @@ class MTProtoSender:
|
|||
"""
|
||||
# TODO https://goo.gl/LMyN7A
|
||||
__log__.debug('Handling new session created')
|
||||
self.state.salt = reader.tgread_object().server_salt
|
||||
self.state.salt = new_session.server_salt
|
||||
|
||||
def _clean_containers(self, msg_ids):
|
||||
"""
|
||||
|
@ -552,7 +539,7 @@ class MTProtoSender:
|
|||
del self._pending_messages[message.msg_id]
|
||||
break
|
||||
|
||||
async def _handle_ack(self, message, reader):
|
||||
async def _handle_ack(self, message, ack):
|
||||
"""
|
||||
Handles a server acknowledge about our messages. Normally
|
||||
these can be ignored except in the case of ``auth.logOut``:
|
||||
|
@ -568,7 +555,6 @@ class MTProtoSender:
|
|||
messages are acknowledged.
|
||||
"""
|
||||
__log__.debug('Handling acknowledge')
|
||||
ack = reader.tgread_object()
|
||||
if self._pending_containers:
|
||||
self._clean_containers(ack.msg_ids)
|
||||
|
||||
|
@ -578,7 +564,7 @@ class MTProtoSender:
|
|||
del self._pending_messages[msg_id]
|
||||
msg.future.set_result(True)
|
||||
|
||||
async def _handle_future_salts(self, message, reader):
|
||||
async def _handle_future_salts(self, message, salts):
|
||||
"""
|
||||
Handles future salt results, which don't come inside a
|
||||
``rpc_result`` but are still sent through a request:
|
||||
|
@ -589,7 +575,6 @@ class MTProtoSender:
|
|||
# TODO save these salts and automatically adjust to the
|
||||
# correct one whenever the salt in use expires.
|
||||
__log__.debug('Handling future salts')
|
||||
salts = reader.tgread_object()
|
||||
msg = self._pending_messages.pop(message.msg_id, None)
|
||||
if msg:
|
||||
msg.future.set_result(salts)
|
||||
|
|
|
@ -6,7 +6,7 @@ from hashlib import sha256
|
|||
from ..crypto import AES
|
||||
from ..errors import SecurityError, BrokenAuthKeyError
|
||||
from ..extensions import BinaryReader
|
||||
from ..tl import TLMessage
|
||||
from ..tl.core import TLMessage
|
||||
|
||||
|
||||
class MTProtoState:
|
||||
|
|
|
@ -1,4 +1 @@
|
|||
from .tlobject import TLObject
|
||||
from .gzip_packed import GzipPacked
|
||||
from .tl_message import TLMessage
|
||||
from .message_container import MessageContainer
|
||||
|
|
26
telethon/tl/core/__init__.py
Normal file
26
telethon/tl/core/__init__.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
"""
|
||||
This module holds core "special" types, which are more convenient ways
|
||||
to do stuff in a `telethon.network.mtprotosender.MTProtoSender` instance.
|
||||
|
||||
Only special cases are gzip-packed data, the response message (not a
|
||||
client message), the message container which references these messages
|
||||
and would otherwise conflict with the rest, and finally the RpcResult:
|
||||
|
||||
rpc_result#f35c6d01 req_msg_id:long result:bytes = RpcResult;
|
||||
|
||||
Three things to note with this definition:
|
||||
1. The constructor ID is actually ``42d36c2c``.
|
||||
2. Those bytes are not read like the rest of bytes (length + payload).
|
||||
They are actually the raw bytes of another object, which can't be
|
||||
read directly because it depends on per-request information (since
|
||||
some can return ``Vector<int>`` and ``Vector<long>``).
|
||||
3. Those bytes may be gzipped data, which needs to be treated early.
|
||||
"""
|
||||
from .tlmessage import TLMessage
|
||||
from .gzippacked import GzipPacked
|
||||
from .messagecontainer import MessageContainer
|
||||
from .rpcresult import RpcResult
|
||||
|
||||
core_objects = {x.CONSTRUCTOR_ID: x for x in (
|
||||
GzipPacked, MessageContainer, RpcResult
|
||||
)}
|
|
@ -1,7 +1,7 @@
|
|||
import gzip
|
||||
import struct
|
||||
|
||||
from . import TLObject
|
||||
from .. import TLObject
|
||||
|
||||
|
||||
class GzipPacked(TLObject):
|
||||
|
@ -36,3 +36,7 @@ class GzipPacked(TLObject):
|
|||
def read(reader):
|
||||
assert reader.read_int(signed=False) == GzipPacked.CONSTRUCTOR_ID
|
||||
return gzip.decompress(reader.tgread_bytes())
|
||||
|
||||
@classmethod
|
||||
def from_reader(cls, reader):
|
||||
return GzipPacked(gzip.decompress(reader.tgread_bytes()))
|
|
@ -1,7 +1,7 @@
|
|||
import struct
|
||||
|
||||
from . import TLObject
|
||||
from .tl_message import TLMessage
|
||||
from ..tlobject import TLObject
|
||||
from .tlmessage import TLMessage
|
||||
|
||||
|
||||
class MessageContainer(TLObject):
|
||||
|
@ -42,3 +42,12 @@ class MessageContainer(TLObject):
|
|||
|
||||
def stringify(self):
|
||||
return TLObject.pretty_format(self, indent=0)
|
||||
|
||||
@classmethod
|
||||
def from_reader(cls, reader):
|
||||
# This assumes that .read_* calls are done in the order they appear
|
||||
return MessageContainer([TLMessage(
|
||||
msg_id=reader.read_long(),
|
||||
seq_no=reader.read_int(),
|
||||
body=reader.read(reader.read_int())
|
||||
) for _ in range(reader.read_int())])
|
23
telethon/tl/core/rpcresult.py
Normal file
23
telethon/tl/core/rpcresult.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
from .gzippacked import GzipPacked
|
||||
from ..types import RpcError
|
||||
|
||||
|
||||
class RpcResult:
|
||||
CONSTRUCTOR_ID = 0xf35c6d01
|
||||
|
||||
def __init__(self, req_msg_id, body, error):
|
||||
self.req_msg_id = req_msg_id
|
||||
self.body = body
|
||||
self.error = error
|
||||
|
||||
@classmethod
|
||||
def from_reader(cls, reader):
|
||||
msg_id = reader.read_long()
|
||||
inner_code = reader.read_int(signed=False)
|
||||
if inner_code == RpcError.CONSTRUCTOR_ID:
|
||||
return RpcResult(msg_id, None, RpcError.from_reader(reader))
|
||||
if inner_code == GzipPacked.CONSTRUCTOR_ID:
|
||||
return RpcResult(msg_id, GzipPacked.from_reader(reader).data, None)
|
||||
|
||||
reader.seek(-4)
|
||||
return RpcResult(msg_id, reader.read(), None)
|
|
@ -1,8 +1,9 @@
|
|||
import asyncio
|
||||
import struct
|
||||
|
||||
from . import TLObject, GzipPacked
|
||||
from ..tl.functions import InvokeAfterMsgRequest
|
||||
from .gzippacked import GzipPacked
|
||||
from .. import TLObject
|
||||
from ..functions import InvokeAfterMsgRequest
|
||||
|
||||
|
||||
class TLMessage(TLObject):
|
|
@ -49,7 +49,7 @@ new_session_created#9ec20908 first_msg_id:long unique_id:long server_salt:long =
|
|||
//message msg_id:long seqno:int bytes:int body:bytes = Message;
|
||||
//msg_copy#e06046b2 orig_message:Message = MessageCopy;
|
||||
|
||||
gzip_packed#3072cfa1 packed_data:bytes = Object;
|
||||
//gzip_packed#3072cfa1 packed_data:bytes = Object;
|
||||
|
||||
msgs_ack#62d6b459 msg_ids:Vector<long> = MsgsAck;
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user