Create a new layer to lift encryption off the MTProtoSender

This commit is contained in:
Lonami Exo 2018-09-29 10:58:45 +02:00
parent 5daad2aaab
commit 9402b4a26d
8 changed files with 166 additions and 120 deletions

View File

@ -224,10 +224,9 @@ class TelegramBaseClient(abc.ABC):
) )
) )
state = MTProtoState(self.session.auth_key)
self._connection = connection self._connection = connection
self._sender = MTProtoSender( self._sender = MTProtoSender(
state, self._loop, self.session.auth_key, self._loop,
retries=self._connection_retries, retries=self._connection_retries,
auto_reconnect=self._auto_reconnect, auto_reconnect=self._auto_reconnect,
update_callback=self._handle_update, update_callback=self._handle_update,
@ -413,12 +412,11 @@ class TelegramBaseClient(abc.ABC):
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt # Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
# for clearly showing how to export the authorization # for clearly showing how to export the authorization
dc = await self._get_dc(dc_id) dc = await self._get_dc(dc_id)
state = MTProtoState(None)
# Can't reuse self._sender._connection as it has its own seqno. # Can't reuse self._sender._connection as it has its own seqno.
# #
# If one were to do that, Telegram would reset the connection # If one were to do that, Telegram would reset the connection
# with no further clues. # with no further clues.
sender = MTProtoSender(state, self._loop) sender = MTProtoSender(None, self._loop)
await sender.connect(self._connection( await sender.connect(self._connection(
dc.ip_address, dc.port, loop=self._loop)) dc.ip_address, dc.port, loop=self._loop))
__log__.info('Exporting authorization for data center %s', dc) __log__.info('Exporting authorization for data center %s', dc)

View File

@ -108,3 +108,9 @@ class Connection(abc.ABC):
the way it should be read from `self._reader`. the way it should be read from `self._reader`.
""" """
raise NotImplementedError raise NotImplementedError
def __str__(self):
    """
    Describes this connection as ``ip:port/Mode``.

    ``Mode`` is the name of the concrete class with every occurrence
    of ``'Connection'`` removed from it.
    """
    mode = self.__class__.__name__.replace('Connection', '')
    return '{}:{}/{}'.format(self._ip, self._port, mode)

View File

@ -0,0 +1,104 @@
import io
import struct
from .mtprotostate import MTProtoState
from ..tl.core.messagecontainer import MessageContainer
class MTProtoLayer:
    """
    This class is the message encryption layer between the methods defined
    in the schema and the response objects. It also holds the state
    required for this encryption to happen.

    The `connection` parameter is the object through which these messages
    will be sent and received.

    The `auth_key` must be a valid authorization key which will be used to
    encrypt these messages. This class is not responsible for generating it.
    """
    def __init__(self, connection, auth_key):
        self._connection = connection
        self._state = MTProtoState(auth_key)

    def connect(self):
        """
        Wrapper for ``self._connection.connect()``.

        Returns whatever the wrapped call returns.
        """
        return self._connection.connect()

    def disconnect(self):
        """
        Wrapper for ``self._connection.disconnect()``.

        Returns whatever the wrapped call returns, mirroring `connect`.
        """
        return self._connection.disconnect()

    async def send(self, data_list):
        """
        A list of serialized RPC queries as bytes must be given to be sent.
        Nested lists imply an order is required for the messages in them.
        Message containers will be used if there is more than one item.

        Returns ``(container_id, msg_ids)``, where ``container_id`` is
        ``None`` when no container was needed.
        """
        data, container_id, msg_ids = self._pack_data_list(data_list)
        await self._connection.send(self._state.encrypt_message_data(data))
        return container_id, msg_ids

    async def recv(self):
        """
        Reads a single message from the network, decrypts it and returns it.
        """
        body = await self._connection.recv()
        return self._state.decrypt_message_data(body)

    def _pack_data_list(self, data_list):
        """
        A list of serialized RPC queries as bytes must be given to be packed.
        Nested lists imply an order is required for the messages in them.

        Returns ``(data, container_id, msg_ids)``. ``data`` is the packed
        payload ready to be encrypted, ``container_id`` is the message ID
        of the wrapping container (``None`` if a single message was
        written) and ``msg_ids`` holds the ID of every message written,
        in the order they were written.
        """
        # TODO write_data_as_message raises on invalid messages, handle it
        # TODO This method could be an iterator yielding messages while
        #      respecting the ``MessageContainer.MAXIMUM_SIZE`` limit.
        #
        # Note that the simplest case is writing a single query data into
        # a message, and returning the message data and ID. For efficiency
        # purposes this method supports more than one message and
        # automatically uses containers if deemed necessary.
        #
        # Technically the message and message container classes could be
        # used to store and serialize the data. However, to keep the context
        # local and relevant to the only place where such feature is
        # actually used, this is not done.
        msg_ids = []
        buffer = io.BytesIO()
        write_message = self._state.write_data_as_message
        for data in data_list:
            if isinstance(data, list):
                # Ordered group: every message is written with the ID of
                # the previous one in the group as its ``after_id``.
                #
                # NOTE(review): one ID is recorded per written message so
                # that ``len(msg_ids)`` equals the number of messages in
                # the buffer (it is used as the container's count below)
                # -- confirm against the caller's expectations.
                last_id = None
                for d in data:
                    last_id = write_message(buffer, d, after_id=last_id)
                    msg_ids.append(last_id)
            else:
                msg_ids.append(write_message(buffer, data))

        if len(msg_ids) == 1:
            # A lone message needs no container around it.
            return buffer.getvalue(), None, msg_ids

        # Inlined code to pack several messages into a container.
        #
        # NOTE: an empty ``data_list`` also reaches this point and
        # produces a container holding zero messages.
        #
        # TODO This part and encrypting data prepend a few bytes but
        #      force a potentially large payload to be appended, which
        #      may be expensive. Can we do better?
        payload = struct.pack(
            '<Ii', MessageContainer.CONSTRUCTOR_ID, len(msg_ids)
        ) + buffer.getvalue()
        buffer = io.BytesIO()
        container_id = write_message(buffer, payload)
        return buffer.getvalue(), container_id, msg_ids

    def __str__(self):
        # Log messages format this layer directly, so show the underlying
        # connection's description.
        return str(self._connection)

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from . import MTProtoPlainSender, authenticator from .mtprotolayer import MTProtoLayer
from .. import utils from .. import utils
from ..errors import ( from ..errors import (
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError, BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
@ -40,11 +40,11 @@ class MTProtoSender:
A new authorization key will be generated on connection if no other A new authorization key will be generated on connection if no other
key exists yet. key exists yet.
""" """
def __init__(self, state, loop, *, def __init__(self, auth_key, loop, *,
retries=5, auto_reconnect=True, update_callback=None, retries=5, auto_reconnect=True, update_callback=None,
auth_key_callback=None, auto_reconnect_callback=None): auth_key_callback=None, auto_reconnect_callback=None):
self.state = state self._auth_key = auth_key
self._connection = None self._connection = None # MTProtoLayer, a.k.a. encrypted connection
self._loop = loop self._loop = loop
self._retries = retries self._retries = retries
self._auto_reconnect = auto_reconnect self._auto_reconnect = auto_reconnect
@ -118,7 +118,7 @@ class MTProtoSender:
__log__.info('User is already connected!') __log__.info('User is already connected!')
return return
self._connection = connection self._connection = MTProtoLayer(connection, self._auth_key)
self._user_connected = True self._user_connected = True
await self._connect() await self._connect()
@ -137,7 +137,7 @@ class MTProtoSender:
await self._disconnect() await self._disconnect()
async def _disconnect(self, error=None): async def _disconnect(self, error=None):
__log__.info('Disconnecting from %s...', self._connection._ip) __log__.info('Disconnecting from %s...', self._connection)
self._user_connected = False self._user_connected = False
try: try:
__log__.debug('Closing current connection...') __log__.debug('Closing current connection...')
@ -163,7 +163,7 @@ class MTProtoSender:
__log__.debug('Cancelling the receive loop...') __log__.debug('Cancelling the receive loop...')
self._recv_loop_handle.cancel() self._recv_loop_handle.cancel()
__log__.info('Disconnection from %s complete!', self._connection._ip) __log__.info('Disconnection from %s complete!', self._connection)
if self._disconnected and not self._disconnected.done(): if self._disconnected and not self._disconnected.done():
if error: if error:
self._disconnected.set_exception(error) self._disconnected.set_exception(error)
@ -235,8 +235,7 @@ class MTProtoSender:
authorization key if necessary, and starting the send and authorization key if necessary, and starting the send and
receive loops. receive loops.
""" """
__log__.info('Connecting to %s:%d...', __log__.info('Connecting to %s...', self._connection)
self._connection._ip, self._connection._port)
for retry in range(1, self._retries + 1): for retry in range(1, self._retries + 1):
try: try:
__log__.debug('Connection attempt {}...'.format(retry)) __log__.debug('Connection attempt {}...'.format(retry))
@ -251,6 +250,8 @@ class MTProtoSender:
.format(self._retries)) .format(self._retries))
__log__.debug('Connection success!') __log__.debug('Connection success!')
# TODO Handle this, maybe an empty MTProtoState that does no encryption
"""
if self.state.auth_key is None: if self.state.auth_key is None:
plain = MTProtoPlainSender(self._connection) plain = MTProtoPlainSender(self._connection)
for retry in range(1, self._retries + 1): for retry in range(1, self._retries + 1):
@ -271,6 +272,7 @@ class MTProtoSender:
.format(self._retries)) .format(self._retries))
await self._disconnect(error=e) await self._disconnect(error=e)
raise e raise e
"""
__log__.debug('Starting send loop') __log__.debug('Starting send loop')
self._send_loop_handle = self._loop.create_task(self._send_loop()) self._send_loop_handle = self._loop.create_task(self._send_loop())
@ -281,7 +283,7 @@ class MTProtoSender:
# First connection or manual reconnection after a failure # First connection or manual reconnection after a failure
if self._disconnected is None or self._disconnected.done(): if self._disconnected is None or self._disconnected.done():
self._disconnected = self._loop.create_future() self._disconnected = self._loop.create_future()
__log__.info('Connection to %s complete!', self._connection._ip) __log__.info('Connection to %s complete!', self._connection)
async def _reconnect(self): async def _reconnect(self):
""" """

View File

@ -8,7 +8,8 @@ from ..crypto import AES
from ..errors import SecurityError, BrokenAuthKeyError from ..errors import SecurityError, BrokenAuthKeyError
from ..extensions import BinaryReader from ..extensions import BinaryReader
from ..tl.core import TLMessage from ..tl.core import TLMessage
from ..tl.tlobject import TLRequest from ..tl.functions import InvokeAfterMsgRequest
from ..tl.core.gzippacked import GzipPacked
__log__ = logging.getLogger(__name__) __log__ = logging.getLogger(__name__)
@ -37,20 +38,6 @@ class MTProtoState:
self._sequence = 0 self._sequence = 0
self._last_msg_id = 0 self._last_msg_id = 0
def create_message(self, obj, *, loop, after=None):
"""
Creates a new `telethon.tl.tl_message.TLMessage` from
the given `telethon.tl.tlobject.TLObject` instance.
"""
return TLMessage(
msg_id=self._get_new_msg_id(),
seq_no=self._get_seq_no(isinstance(obj, TLRequest)),
obj=obj,
after_id=after.msg_id if after else None,
out=True, # Pre-convert the request into bytes
loop=loop
)
def update_message_id(self, message): def update_message_id(self, message):
""" """
Updates the message ID to a new one, Updates the message ID to a new one,
@ -74,14 +61,30 @@ class MTProtoState:
return aes_key, aes_iv return aes_key, aes_iv
def pack_message(self, message): def write_data_as_message(self, buffer, data, after_id=None):
""" """
Packs the given `telethon.tl.tl_message.TLMessage` using the Writes a message containing the given data into buffer.
current authorization key following MTProto 2.0 guidelines.
See https://core.telegram.org/mtproto/description. Returns the message id.
""" """
data = struct.pack('<qq', self.salt, self.id) + bytes(message) msg_id = self._get_new_msg_id()
seq_no = self._get_seq_no(True) # TODO ack/ping are not content-related
if after_id is None:
body = GzipPacked.gzip_if_smaller(data)
else:
body = GzipPacked.gzip_if_smaller(
bytes(InvokeAfterMsgRequest(after_id, data)))
buffer.write(struct.pack('<qii', msg_id, seq_no, len(body)))
buffer.write(body)
return msg_id
def encrypt_message_data(self, data):
"""
Encrypts the given message data using the current authorization key
following MTProto 2.0 guidelines core.telegram.org/mtproto/description.
"""
data = struct.pack('<qq', self.salt, self.id) + data
padding = os.urandom(-(len(data) + 12) % 16 + 12) padding = os.urandom(-(len(data) + 12) % 16 + 12)
# Being substr(what, offset, length); x = 0 for client # Being substr(what, offset, length); x = 0 for client
@ -97,11 +100,12 @@ class MTProtoState:
return (key_id + msg_key + return (key_id + msg_key +
AES.encrypt_ige(data + padding, aes_key, aes_iv)) AES.encrypt_ige(data + padding, aes_key, aes_iv))
def unpack_message(self, body): def decrypt_message_data(self, body):
""" """
Inverse of `pack_message` for incoming server messages. Inverse of `encrypt_message_data` for incoming server messages.
""" """
if len(body) < 8: if len(body) < 8:
# TODO If len == 4, raise HTTPErrorCode(-little endian int)
if body == b'l\xfe\xff\xff': if body == b'l\xfe\xff\xff':
raise BrokenAuthKeyError() raise BrokenAuthKeyError()
else: else:
@ -136,7 +140,7 @@ class MTProtoState:
# reader isn't used for anything else after this, it's unnecessary. # reader isn't used for anything else after this, it's unnecessary.
obj = reader.tgread_object() obj = reader.tgread_object()
return TLMessage(remote_msg_id, remote_sequence, obj, loop=None) return TLMessage(remote_msg_id, remote_sequence, obj)
def _get_new_msg_id(self): def _get_new_msg_id(self):
""" """

View File

@ -11,15 +11,15 @@ class GzipPacked(TLObject):
self.data = data self.data = data
@staticmethod @staticmethod
def gzip_if_smaller(request): def gzip_if_smaller(data):
"""Calls bytes(request), and based on a certain threshold, """Calls bytes(request), and based on a certain threshold,
optionally gzips the resulting data. If the gzipped data is optionally gzips the resulting data. If the gzipped data is
smaller than the original byte array, this is returned instead. smaller than the original byte array, this is returned instead.
Note that this only applies to content related requests. Note that this only applies to content related requests.
""" """
data = bytes(request) # TODO Only content-related requests should be gzipped
if isinstance(request, TLRequest) and len(data) > 512: if len(data) > 512:
gzipped = bytes(GzipPacked(data)) gzipped = bytes(GzipPacked(data))
return gzipped if len(gzipped) < len(data) else data return gzipped if len(gzipped) < len(data) else data
else: else:

View File

@ -27,11 +27,6 @@ class MessageContainer(TLObject):
], ],
} }
def __bytes__(self):
return struct.pack(
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
) + b''.join(bytes(m) for m in self.messages)
@classmethod @classmethod
def from_reader(cls, reader): def from_reader(cls, reader):
# This assumes that .read_* calls are done in the order they appear # This assumes that .read_* calls are done in the order they appear
@ -43,5 +38,5 @@ class MessageContainer(TLObject):
before = reader.tell_position() before = reader.tell_position()
obj = reader.tgread_object() # May over-read e.g. RpcResult obj = reader.tgread_object() # May over-read e.g. RpcResult
reader.set_position(before + length) reader.set_position(before + length)
messages.append(TLMessage(msg_id, seq_no, obj, loop=None)) messages.append(TLMessage(msg_id, seq_no, obj))
return MessageContainer(messages) return MessageContainer(messages)

View File

@ -1,10 +1,6 @@
import asyncio
import logging import logging
import struct
from .gzippacked import GzipPacked
from .. import TLObject from .. import TLObject
from ..functions import InvokeAfterMsgRequest
__log__ = logging.getLogger(__name__) __log__ = logging.getLogger(__name__)
@ -17,83 +13,24 @@ class TLMessage(TLObject):
message msg_id:long seqno:int bytes:int body:bytes = Message; message msg_id:long seqno:int bytes:int body:bytes = Message;
Each message has its own unique identifier, and the body is simply Each message has its own unique identifier, and the body is simply
the serialized request that should be executed on the server. Then the serialized request that should be executed on the server, or
Telegram will, at some point, respond with the result for this msg. the response object from Telegram. Since the body is always a valid
object, it makes sense to store the object and not the bytes to
ease working with them.
Thus it makes sense that requests and their result are bound to a There is no need to add serializing logic here since that can be
sent `TLMessage`, and this result can be represented as a `Future` inlined and is unlikely to change. Thus these are only needed to
that will eventually be set with either a result, error or cancelled. encapsulate responses.
""" """
def __init__(self, msg_id, seq_no, obj, *, loop, out=False, after_id=0): def __init__(self, msg_id, seq_no, obj):
self.msg_id = msg_id
self.seq_no = seq_no
self.obj = obj self.obj = obj
self.container_msg_id = None
# If no loop is given then it is an incoming message.
# Only outgoing messages need the future to await them.
self.future = loop.create_future() if loop else None
# After which message ID this one should run. We do this so
# InvokeAfterMsgRequest is transparent to the user and we can
# easily invoke after while confirming the original request.
# TODO Currently we don't update this if another message ID changes
self.after_id = after_id
# There are two use-cases for the TLMessage, outgoing and incoming.
# Outgoing messages are meant to be serialized and sent across the
# network so it makes sense to pack them as early as possible and
# avoid this computation if it needs to be resent, and also shows
# serializing-errors as early as possible (foreground task).
#
# We assume obj won't change so caching the bytes is safe to do.
# Caching bytes lets us get the size in a fast way, necessary for
# knowing whether a container can be sent (<1MB) or not (too big).
#
# Incoming messages don't really need this body, but we save the
# msg_id and seq_no inside the body for consistency and raise if
# one tries to bytes()-ify the entire message (len == 12).
if not out:
self._body = struct.pack('<qi', msg_id, seq_no)
else:
try:
if self.after_id is None:
body = GzipPacked.gzip_if_smaller(self.obj)
else:
body = GzipPacked.gzip_if_smaller(
InvokeAfterMsgRequest(self.after_id, self.obj))
except Exception:
# struct.pack doesn't give a lot of information about
# why it may fail so log the exception AND the object
__log__.exception('Failed to pack %s', self.obj)
raise
self._body = struct.pack('<qii', msg_id, seq_no, len(body)) + body
def to_dict(self): def to_dict(self):
return { return {
'_': 'TLMessage', '_': 'TLMessage',
'msg_id': self.msg_id, 'msg_id': self.msg_id,
'seq_no': self.seq_no, 'seq_no': self.seq_no,
'obj': self.obj, 'obj': self.obj
'container_msg_id': self.container_msg_id
} }
@property
def msg_id(self):
return struct.unpack('<q', self._body[:8])[0]
@msg_id.setter
def msg_id(self, value):
self._body = struct.pack('<q', value) + self._body[8:]
@property
def seq_no(self):
return struct.unpack('<i', self._body[8:12])[0]
def __bytes__(self):
if len(self._body) == 12: # msg_id, seqno
raise TypeError('Incoming messages should not be bytes()-ed')
return self._body
def size(self):
return len(self._body)