mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-03 21:24:35 +03:00
Create a new layer to lift encryption off the MTProtoSender
This commit is contained in:
parent
5daad2aaab
commit
9402b4a26d
|
@ -224,10 +224,9 @@ class TelegramBaseClient(abc.ABC):
|
|||
)
|
||||
)
|
||||
|
||||
state = MTProtoState(self.session.auth_key)
|
||||
self._connection = connection
|
||||
self._sender = MTProtoSender(
|
||||
state, self._loop,
|
||||
self.session.auth_key, self._loop,
|
||||
retries=self._connection_retries,
|
||||
auto_reconnect=self._auto_reconnect,
|
||||
update_callback=self._handle_update,
|
||||
|
@ -413,12 +412,11 @@ class TelegramBaseClient(abc.ABC):
|
|||
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
|
||||
# for clearly showing how to export the authorization
|
||||
dc = await self._get_dc(dc_id)
|
||||
state = MTProtoState(None)
|
||||
# Can't reuse self._sender._connection as it has its own seqno.
|
||||
#
|
||||
# If one were to do that, Telegram would reset the connection
|
||||
# with no further clues.
|
||||
sender = MTProtoSender(state, self._loop)
|
||||
sender = MTProtoSender(None, self._loop)
|
||||
await sender.connect(self._connection(
|
||||
dc.ip_address, dc.port, loop=self._loop))
|
||||
__log__.info('Exporting authorization for data center %s', dc)
|
||||
|
|
|
@ -108,3 +108,9 @@ class Connection(abc.ABC):
|
|||
the way it should be read from `self._reader`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __str__(self):
|
||||
return '{}:{}/{}'.format(
|
||||
self._ip, self._port,
|
||||
self.__class__.__name__.replace('Connection', '')
|
||||
)
|
||||
|
|
104
telethon/network/mtprotolayer.py
Normal file
104
telethon/network/mtprotolayer.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
import io
|
||||
import struct
|
||||
|
||||
from .mtprotostate import MTProtoState
|
||||
from ..tl.core.messagecontainer import MessageContainer
|
||||
|
||||
|
||||
class MTProtoLayer:
|
||||
"""
|
||||
This class is the message encryption layer between the methods defined
|
||||
in the schema and the response objects. It also holds the necessary state
|
||||
necessary for this encryption to happen.
|
||||
|
||||
The `connection` parameter is through which these messages will be sent
|
||||
and received.
|
||||
|
||||
The `auth_key` must be a valid authorization key which will be used to
|
||||
encrypt these messages. This class is not responsible for generating them.
|
||||
"""
|
||||
def __init__(self, connection, auth_key):
|
||||
self._connection = connection
|
||||
self._state = MTProtoState(auth_key)
|
||||
|
||||
def connect(self):
|
||||
"""
|
||||
Wrapper for ``self._connection.connect()``.
|
||||
"""
|
||||
return self._connection.connect()
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Wrapper for ``self._connection.disconnect()``.
|
||||
"""
|
||||
self._connection.disconnect()
|
||||
|
||||
async def send(self, data_list):
|
||||
"""
|
||||
A list of serialized RPC queries as bytes must be given to be sent.
|
||||
Nested lists imply an order is required for the messages in them.
|
||||
Message containers will be used if there is more than one item.
|
||||
|
||||
Returns ``(container_id, msg_ids)``.
|
||||
"""
|
||||
data, container_id, msg_ids = self._pack_data_list(data_list)
|
||||
await self._connection.send(self._state.encrypt_message_data(data))
|
||||
return container_id, msg_ids
|
||||
|
||||
async def recv(self):
|
||||
"""
|
||||
Reads a single message from the network, decrypts it and returns it.
|
||||
"""
|
||||
body = await self._connection.recv()
|
||||
return self._state.decrypt_message_data(body)
|
||||
|
||||
def _pack_data_list(self, data_list):
|
||||
"""
|
||||
A list of serialized RPC queries as bytes must be given to be packed.
|
||||
Nested lists imply an order is required for the messages in them.
|
||||
|
||||
Returns ``(data, container_id, msg_ids)``.
|
||||
"""
|
||||
# TODO write_data_as_message raises on invalid messages, handle it
|
||||
# TODO This method could be an iterator yielding messages while small
|
||||
# respecting the ``MessageContainer.MAXIMUM_SIZE`` limit.
|
||||
#
|
||||
# Note that the simplest case is writing a single query data into
|
||||
# a message, and returning the message data and ID. For efficiency
|
||||
# purposes this method supports more than one message and automatically
|
||||
# uses containers if deemed necessary.
|
||||
#
|
||||
# Technically the message and message container classes could be used
|
||||
# to store and serialize the data. However, to keep the context local
|
||||
# and relevant to the only place where such feature is actually used,
|
||||
# this is not done.
|
||||
msg_ids = []
|
||||
buffer = io.BytesIO()
|
||||
for data in data_list:
|
||||
if not isinstance(data, list):
|
||||
msg_ids.append(self._state.write_data_as_message(buffer, data))
|
||||
else:
|
||||
last_id = None
|
||||
for d in data:
|
||||
last_id = self._state.write_data_as_message(
|
||||
buffer, d, after_id=last_id)
|
||||
msg_ids.append(last_id)
|
||||
|
||||
if len(msg_ids) == 1:
|
||||
container_id = None
|
||||
else:
|
||||
# Inlined code to pack several messages into a container
|
||||
#
|
||||
# TODO This part and encrypting data prepend a few bytes but
|
||||
# force a potentially large payload to be appended, which
|
||||
# may be expensive. Can we do better?
|
||||
data = struct.pack(
|
||||
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(msg_ids)
|
||||
) + buffer.getvalue()
|
||||
buffer = io.BytesIO()
|
||||
container_id = self._state.write_data_as_message(buffer, data)
|
||||
|
||||
return buffer.getvalue(), container_id, msg_ids
|
||||
|
||||
def __str__(self):
|
||||
return str(self._connection)
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from . import MTProtoPlainSender, authenticator
|
||||
from .mtprotolayer import MTProtoLayer
|
||||
from .. import utils
|
||||
from ..errors import (
|
||||
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
|
||||
|
@ -40,11 +40,11 @@ class MTProtoSender:
|
|||
A new authorization key will be generated on connection if no other
|
||||
key exists yet.
|
||||
"""
|
||||
def __init__(self, state, loop, *,
|
||||
def __init__(self, auth_key, loop, *,
|
||||
retries=5, auto_reconnect=True, update_callback=None,
|
||||
auth_key_callback=None, auto_reconnect_callback=None):
|
||||
self.state = state
|
||||
self._connection = None
|
||||
self._auth_key = auth_key
|
||||
self._connection = None # MTProtoLayer, a.k.a. encrypted connection
|
||||
self._loop = loop
|
||||
self._retries = retries
|
||||
self._auto_reconnect = auto_reconnect
|
||||
|
@ -118,7 +118,7 @@ class MTProtoSender:
|
|||
__log__.info('User is already connected!')
|
||||
return
|
||||
|
||||
self._connection = connection
|
||||
self._connection = MTProtoLayer(connection, self._auth_key)
|
||||
self._user_connected = True
|
||||
await self._connect()
|
||||
|
||||
|
@ -137,7 +137,7 @@ class MTProtoSender:
|
|||
await self._disconnect()
|
||||
|
||||
async def _disconnect(self, error=None):
|
||||
__log__.info('Disconnecting from %s...', self._connection._ip)
|
||||
__log__.info('Disconnecting from %s...', self._connection)
|
||||
self._user_connected = False
|
||||
try:
|
||||
__log__.debug('Closing current connection...')
|
||||
|
@ -163,7 +163,7 @@ class MTProtoSender:
|
|||
__log__.debug('Cancelling the receive loop...')
|
||||
self._recv_loop_handle.cancel()
|
||||
|
||||
__log__.info('Disconnection from %s complete!', self._connection._ip)
|
||||
__log__.info('Disconnection from %s complete!', self._connection)
|
||||
if self._disconnected and not self._disconnected.done():
|
||||
if error:
|
||||
self._disconnected.set_exception(error)
|
||||
|
@ -235,8 +235,7 @@ class MTProtoSender:
|
|||
authorization key if necessary, and starting the send and
|
||||
receive loops.
|
||||
"""
|
||||
__log__.info('Connecting to %s:%d...',
|
||||
self._connection._ip, self._connection._port)
|
||||
__log__.info('Connecting to %s...', self._connection)
|
||||
for retry in range(1, self._retries + 1):
|
||||
try:
|
||||
__log__.debug('Connection attempt {}...'.format(retry))
|
||||
|
@ -251,6 +250,8 @@ class MTProtoSender:
|
|||
.format(self._retries))
|
||||
|
||||
__log__.debug('Connection success!')
|
||||
# TODO Handle this, maybe an empty MTProtoState that does no encryption
|
||||
"""
|
||||
if self.state.auth_key is None:
|
||||
plain = MTProtoPlainSender(self._connection)
|
||||
for retry in range(1, self._retries + 1):
|
||||
|
@ -271,6 +272,7 @@ class MTProtoSender:
|
|||
.format(self._retries))
|
||||
await self._disconnect(error=e)
|
||||
raise e
|
||||
"""
|
||||
|
||||
__log__.debug('Starting send loop')
|
||||
self._send_loop_handle = self._loop.create_task(self._send_loop())
|
||||
|
@ -281,7 +283,7 @@ class MTProtoSender:
|
|||
# First connection or manual reconnection after a failure
|
||||
if self._disconnected is None or self._disconnected.done():
|
||||
self._disconnected = self._loop.create_future()
|
||||
__log__.info('Connection to %s complete!', self._connection._ip)
|
||||
__log__.info('Connection to %s complete!', self._connection)
|
||||
|
||||
async def _reconnect(self):
|
||||
"""
|
||||
|
|
|
@ -8,7 +8,8 @@ from ..crypto import AES
|
|||
from ..errors import SecurityError, BrokenAuthKeyError
|
||||
from ..extensions import BinaryReader
|
||||
from ..tl.core import TLMessage
|
||||
from ..tl.tlobject import TLRequest
|
||||
from ..tl.functions import InvokeAfterMsgRequest
|
||||
from ..tl.core.gzippacked import GzipPacked
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
@ -37,20 +38,6 @@ class MTProtoState:
|
|||
self._sequence = 0
|
||||
self._last_msg_id = 0
|
||||
|
||||
def create_message(self, obj, *, loop, after=None):
|
||||
"""
|
||||
Creates a new `telethon.tl.tl_message.TLMessage` from
|
||||
the given `telethon.tl.tlobject.TLObject` instance.
|
||||
"""
|
||||
return TLMessage(
|
||||
msg_id=self._get_new_msg_id(),
|
||||
seq_no=self._get_seq_no(isinstance(obj, TLRequest)),
|
||||
obj=obj,
|
||||
after_id=after.msg_id if after else None,
|
||||
out=True, # Pre-convert the request into bytes
|
||||
loop=loop
|
||||
)
|
||||
|
||||
def update_message_id(self, message):
|
||||
"""
|
||||
Updates the message ID to a new one,
|
||||
|
@ -74,14 +61,30 @@ class MTProtoState:
|
|||
|
||||
return aes_key, aes_iv
|
||||
|
||||
def pack_message(self, message):
|
||||
def write_data_as_message(self, buffer, data, after_id=None):
|
||||
"""
|
||||
Packs the given `telethon.tl.tl_message.TLMessage` using the
|
||||
current authorization key following MTProto 2.0 guidelines.
|
||||
Writes a message containing the given data into buffer.
|
||||
|
||||
See https://core.telegram.org/mtproto/description.
|
||||
Returns the message id.
|
||||
"""
|
||||
data = struct.pack('<qq', self.salt, self.id) + bytes(message)
|
||||
msg_id = self._get_new_msg_id()
|
||||
seq_no = self._get_seq_no(True) # TODO ack/ping are not content-related
|
||||
if after_id is None:
|
||||
body = GzipPacked.gzip_if_smaller(data)
|
||||
else:
|
||||
body = GzipPacked.gzip_if_smaller(
|
||||
bytes(InvokeAfterMsgRequest(after_id, data)))
|
||||
|
||||
buffer.write(struct.pack('<qii', msg_id, seq_no, len(body)))
|
||||
buffer.write(body)
|
||||
return msg_id
|
||||
|
||||
def encrypt_message_data(self, data):
|
||||
"""
|
||||
Encrypts the given message data using the current authorization key
|
||||
following MTProto 2.0 guidelines core.telegram.org/mtproto/description.
|
||||
"""
|
||||
data = struct.pack('<qq', self.salt, self.id) + data
|
||||
padding = os.urandom(-(len(data) + 12) % 16 + 12)
|
||||
|
||||
# Being substr(what, offset, length); x = 0 for client
|
||||
|
@ -97,11 +100,12 @@ class MTProtoState:
|
|||
return (key_id + msg_key +
|
||||
AES.encrypt_ige(data + padding, aes_key, aes_iv))
|
||||
|
||||
def unpack_message(self, body):
|
||||
def decrypt_message_data(self, body):
|
||||
"""
|
||||
Inverse of `pack_message` for incoming server messages.
|
||||
Inverse of `encrypt_message_data` for incoming server messages.
|
||||
"""
|
||||
if len(body) < 8:
|
||||
# TODO If len == 4, raise HTTPErrorCode(-little endian int)
|
||||
if body == b'l\xfe\xff\xff':
|
||||
raise BrokenAuthKeyError()
|
||||
else:
|
||||
|
@ -136,7 +140,7 @@ class MTProtoState:
|
|||
# reader isn't used for anything else after this, it's unnecessary.
|
||||
obj = reader.tgread_object()
|
||||
|
||||
return TLMessage(remote_msg_id, remote_sequence, obj, loop=None)
|
||||
return TLMessage(remote_msg_id, remote_sequence, obj)
|
||||
|
||||
def _get_new_msg_id(self):
|
||||
"""
|
||||
|
|
|
@ -11,15 +11,15 @@ class GzipPacked(TLObject):
|
|||
self.data = data
|
||||
|
||||
@staticmethod
|
||||
def gzip_if_smaller(request):
|
||||
def gzip_if_smaller(data):
|
||||
"""Calls bytes(request), and based on a certain threshold,
|
||||
optionally gzips the resulting data. If the gzipped data is
|
||||
smaller than the original byte array, this is returned instead.
|
||||
|
||||
Note that this only applies to content related requests.
|
||||
"""
|
||||
data = bytes(request)
|
||||
if isinstance(request, TLRequest) and len(data) > 512:
|
||||
# TODO Only content-related requests should be gzipped
|
||||
if len(data) > 512:
|
||||
gzipped = bytes(GzipPacked(data))
|
||||
return gzipped if len(gzipped) < len(data) else data
|
||||
else:
|
||||
|
|
|
@ -27,11 +27,6 @@ class MessageContainer(TLObject):
|
|||
],
|
||||
}
|
||||
|
||||
def __bytes__(self):
|
||||
return struct.pack(
|
||||
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
|
||||
) + b''.join(bytes(m) for m in self.messages)
|
||||
|
||||
@classmethod
|
||||
def from_reader(cls, reader):
|
||||
# This assumes that .read_* calls are done in the order they appear
|
||||
|
@ -43,5 +38,5 @@ class MessageContainer(TLObject):
|
|||
before = reader.tell_position()
|
||||
obj = reader.tgread_object() # May over-read e.g. RpcResult
|
||||
reader.set_position(before + length)
|
||||
messages.append(TLMessage(msg_id, seq_no, obj, loop=None))
|
||||
messages.append(TLMessage(msg_id, seq_no, obj))
|
||||
return MessageContainer(messages)
|
||||
|
|
|
@ -1,10 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from .gzippacked import GzipPacked
|
||||
from .. import TLObject
|
||||
from ..functions import InvokeAfterMsgRequest
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
@ -17,83 +13,24 @@ class TLMessage(TLObject):
|
|||
message msg_id:long seqno:int bytes:int body:bytes = Message;
|
||||
|
||||
Each message has its own unique identifier, and the body is simply
|
||||
the serialized request that should be executed on the server. Then
|
||||
Telegram will, at some point, respond with the result for this msg.
|
||||
the serialized request that should be executed on the server, or
|
||||
the response object from Telegram. Since the body is always a valid
|
||||
object, it makes sense to store the object and not the bytes to
|
||||
ease working with them.
|
||||
|
||||
Thus it makes sense that requests and their result are bound to a
|
||||
sent `TLMessage`, and this result can be represented as a `Future`
|
||||
that will eventually be set with either a result, error or cancelled.
|
||||
There is no need to add serializing logic here since that can be
|
||||
inlined and is unlikely to change. Thus these are only needed to
|
||||
encapsulate responses.
|
||||
"""
|
||||
def __init__(self, msg_id, seq_no, obj, *, loop, out=False, after_id=0):
|
||||
def __init__(self, msg_id, seq_no, obj):
|
||||
self.msg_id = msg_id
|
||||
self.seq_no = seq_no
|
||||
self.obj = obj
|
||||
self.container_msg_id = None
|
||||
|
||||
# If no loop is given then it is an incoming message.
|
||||
# Only outgoing messages need the future to await them.
|
||||
self.future = loop.create_future() if loop else None
|
||||
|
||||
# After which message ID this one should run. We do this so
|
||||
# InvokeAfterMsgRequest is transparent to the user and we can
|
||||
# easily invoke after while confirming the original request.
|
||||
# TODO Currently we don't update this if another message ID changes
|
||||
self.after_id = after_id
|
||||
|
||||
# There are two use-cases for the TLMessage, outgoing and incoming.
|
||||
# Outgoing messages are meant to be serialized and sent across the
|
||||
# network so it makes sense to pack them as early as possible and
|
||||
# avoid this computation if it needs to be resent, and also shows
|
||||
# serializing-errors as early as possible (foreground task).
|
||||
#
|
||||
# We assume obj won't change so caching the bytes is safe to do.
|
||||
# Caching bytes lets us get the size in a fast way, necessary for
|
||||
# knowing whether a container can be sent (<1MB) or not (too big).
|
||||
#
|
||||
# Incoming messages don't really need this body, but we save the
|
||||
# msg_id and seq_no inside the body for consistency and raise if
|
||||
# one tries to bytes()-ify the entire message (len == 12).
|
||||
if not out:
|
||||
self._body = struct.pack('<qi', msg_id, seq_no)
|
||||
else:
|
||||
try:
|
||||
if self.after_id is None:
|
||||
body = GzipPacked.gzip_if_smaller(self.obj)
|
||||
else:
|
||||
body = GzipPacked.gzip_if_smaller(
|
||||
InvokeAfterMsgRequest(self.after_id, self.obj))
|
||||
except Exception:
|
||||
# struct.pack doesn't give a lot of information about
|
||||
# why it may fail so log the exception AND the object
|
||||
__log__.exception('Failed to pack %s', self.obj)
|
||||
raise
|
||||
|
||||
self._body = struct.pack('<qii', msg_id, seq_no, len(body)) + body
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'_': 'TLMessage',
|
||||
'msg_id': self.msg_id,
|
||||
'seq_no': self.seq_no,
|
||||
'obj': self.obj,
|
||||
'container_msg_id': self.container_msg_id
|
||||
'obj': self.obj
|
||||
}
|
||||
|
||||
@property
|
||||
def msg_id(self):
|
||||
return struct.unpack('<q', self._body[:8])[0]
|
||||
|
||||
@msg_id.setter
|
||||
def msg_id(self, value):
|
||||
self._body = struct.pack('<q', value) + self._body[8:]
|
||||
|
||||
@property
|
||||
def seq_no(self):
|
||||
return struct.unpack('<i', self._body[8:12])[0]
|
||||
|
||||
def __bytes__(self):
|
||||
if len(self._body) == 12: # msg_id, seqno
|
||||
raise TypeError('Incoming messages should not be bytes()-ed')
|
||||
|
||||
return self._body
|
||||
|
||||
def size(self):
|
||||
return len(self._body)
|
||||
|
|
Loading…
Reference in New Issue
Block a user