Create a self-contained MTProtoState

This frees us from using entire Session objects in something
that's supposed to just send and receive items from the net.
This commit is contained in:
Lonami Exo 2018-06-09 11:34:01 +02:00
parent cc5753137c
commit adfe861e9f
5 changed files with 226 additions and 146 deletions

View File

@ -1,12 +1,7 @@
"""Various helpers not related to the Telegram API itself"""
import os
import struct
from hashlib import sha1, sha256
from .crypto import AES
from .errors import SecurityError, BrokenAuthKeyError
from .extensions import BinaryReader
# region Multiple utilities
@ -27,77 +22,6 @@ def ensure_parent_dir_exists(file_path):
# region Cryptographic related utils
def pack_message(session, message):
"""Packs a message following MtProto 2.0 guidelines"""
# See https://core.telegram.org/mtproto/description
data = struct.pack('<qq', session.salt, session.id) + bytes(message)
padding = os.urandom(-(len(data) + 12) % 16 + 12)
# Being substr(what, offset, length); x = 0 for client
# "msg_key_large = SHA256(substr(auth_key, 88+x, 32) + pt + padding)"
msg_key_large = sha256(
session.auth_key.key[88:88 + 32] + data + padding).digest()
# "msg_key = substr (msg_key_large, 8, 16)"
msg_key = msg_key_large[8:24]
aes_key, aes_iv = calc_key(session.auth_key.key, msg_key, True)
key_id = struct.pack('<Q', session.auth_key.key_id)
return key_id + msg_key + AES.encrypt_ige(data + padding, aes_key, aes_iv)
def unpack_message(session, body):
"""Unpacks a message following MtProto 2.0 guidelines"""
# See https://core.telegram.org/mtproto/description
if len(body) < 8:
if body == b'l\xfe\xff\xff':
raise BrokenAuthKeyError()
else:
raise BufferError("Can't decode packet ({})".format(body))
key_id = struct.unpack('<Q', body[:8])[0]
if key_id != session.auth_key.key_id:
raise SecurityError('Server replied with an invalid auth key')
msg_key = body[8:24]
aes_key, aes_iv = calc_key(session.auth_key.key, msg_key, False)
data = BinaryReader(AES.decrypt_ige(body[24:], aes_key, aes_iv))
data.read_long() # remote_salt
if data.read_long() != session.id:
raise SecurityError('Server replied with a wrong session ID')
remote_msg_id = data.read_long()
remote_sequence = data.read_int()
msg_len = data.read_int()
message = data.read(msg_len)
# https://core.telegram.org/mtproto/security_guidelines
# Sections "checking sha256 hash" and "message length"
if msg_key != sha256(
session.auth_key.key[96:96 + 32] + data.get_bytes()).digest()[8:24]:
raise SecurityError("Received msg_key doesn't match with expected one")
return message, remote_msg_id, remote_sequence
def calc_key(auth_key, msg_key, client):
"""
Calculate the key based on Telegram guidelines
for MtProto 2, specifying whether it's the client or not.
"""
# https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
x = 0 if client else 8
sha256a = sha256(msg_key + auth_key[x: x + 36]).digest()
sha256b = sha256(auth_key[x + 40:x + 76] + msg_key).digest()
aes_key = sha256a[:8] + sha256b[8:24] + sha256a[24:32]
aes_iv = sha256b[:8] + sha256a[8:24] + sha256b[24:32]
return aes_key, aes_iv
def generate_key_data_from_nonce(server_nonce, new_nonce):
"""Generates the key data corresponding to the given nonce"""
server_nonce = server_nonce.to_bytes(16, 'little', signed=True)

View File

@ -3,13 +3,13 @@ import logging
from . import MTProtoPlainSender, authenticator
from .connection import ConnectionTcpFull
from .. import helpers, utils
from .. import utils
from ..errors import (
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
rpc_message_to_error
)
from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl import MessageContainer, GzipPacked
from ..tl.functions.auth import LogOutRequest
from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
@ -39,8 +39,8 @@ class MTProtoSender:
A new authorization key will be generated on connection if no other
key exists yet.
"""
def __init__(self, session, retries=5):
self.session = session
def __init__(self, state, retries=5):
self.state = state
self._connection = ConnectionTcpFull()
self._ip = None
self._port = None
@ -171,21 +171,17 @@ class MTProtoSender:
# a `Future` that you need to further ``await`` instead of the
# currently double ``await (await send())``?
if utils.is_list_like(request):
if not ordered:
# False-y values must be None to do after_id = ordered and ...
ordered = None
result = []
after_id = None
after = None
for r in request:
message = TLMessage(self.session, r, after_id=after_id)
message = self.state.create_message(r, after=after)
self._pending_messages[message.msg_id] = message
after_id = ordered and message.msg_id
await self._send_queue.put(message)
result.append(message.future)
after = ordered and message
return result
else:
message = TLMessage(self.session, request)
message = self.state.create_message(request)
self._pending_messages[message.msg_id] = message
await self._send_queue.put(message)
return message.future
@ -215,13 +211,13 @@ class MTProtoSender:
raise _last_error
__log__.debug('Connection success!')
if self.session.auth_key is None:
if self.state.auth_key is None:
_last_error = SecurityError()
plain = MTProtoPlainSender(self._connection)
for retry in range(1, self._retries + 1):
try:
__log__.debug('New auth_key attempt {}...'.format(retry))
self.session.auth_key, self.session.time_offset =\
self.state.auth_key, self.state.time_offset =\
await authenticator.do_authentication(plain)
except (SecurityError, AssertionError) as e:
_last_error = e
@ -268,13 +264,14 @@ class MTProtoSender:
"""
while self._user_connected and not self._reconnecting:
if self._pending_ack:
await self._send_queue.put(TLMessage(
self.session, MsgsAck(list(self._pending_ack))))
await self._send_queue.put(self.state.create_message(
MsgsAck(list(self._pending_ack))
))
self._pending_ack.clear()
messages = await self._send_queue.get()
if isinstance(messages, list):
message = TLMessage(self.session, MessageContainer(messages))
message = self.state.create_message(MessageContainer(messages))
self._pending_messages[message.msg_id] = message
self._pending_containers.append(message)
else:
@ -283,7 +280,7 @@ class MTProtoSender:
__log__.debug('Packing {} outgoing message(s)...'
.format(len(messages)))
body = helpers.pack_message(self.session, message)
body = self.state.pack_message(message)
while not any(m.future.cancelled() for m in messages):
try:
@ -333,8 +330,7 @@ class MTProtoSender:
# TODO Check salt, session_id and sequence_number
__log__.debug('Decoding packet of {} bytes...'.format(len(body)))
try:
message, remote_msg_id, remote_seq =\
helpers.unpack_message(self.session, body)
message = self.state.unpack_message(body)
except (BrokenAuthKeyError, BufferError) as e:
# The authorization key may be broken if a message was
# sent malformed, or if the authkey truly is corrupted.
@ -346,7 +342,7 @@ class MTProtoSender:
# TODO Is it possible to detect malformed messages vs
# an actually broken authkey?
__log__.warning('Broken authorization key?: {}'.format(e))
self.session.auth_key = None
self.state.auth_key = None
asyncio.ensure_future(self._reconnect())
break
except SecurityError as e:
@ -357,28 +353,27 @@ class MTProtoSender:
continue
else:
try:
with BinaryReader(message) as reader:
await self._process_message(
remote_msg_id, remote_seq, reader)
with BinaryReader(message.body) as reader:
await self._process_message(message, reader)
except TypeNotFoundError as e:
__log__.warning('Could not decode received message: {}, '
'raw bytes: {!r}'.format(e, message))
# Response Handlers
async def _process_message(self, msg_id, seq, reader):
async def _process_message(self, message, reader):
"""
Adds the given message to the list of messages that must be
acknowledged and dispatches control to different ``_handle_*``
method based on its type.
"""
self._pending_ack.add(msg_id)
self._pending_ack.add(message.msg_id)
code = reader.read_int(signed=False)
reader.seek(-4)
handler = self._handlers.get(code, self._handle_update)
await handler(msg_id, seq, reader)
await handler(message, reader)
async def _handle_rpc_result(self, msg_id, seq, reader):
async def _handle_rpc_result(self, message, reader):
"""
Handles the result for Remote Procedure Calls:
@ -395,19 +390,14 @@ class MTProtoSender:
__log__.debug('Handling RPC result for message {}'.format(message_id))
message = self._pending_messages.pop(message_id, None)
if inner_code == 0x2144ca19: # RPC Error
# TODO Report errors if possible/enabled
reader.seek(4)
if self.session.report_errors and message:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string(),
report_method=type(message.request).CONSTRUCTOR_ID
)
else:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string()
)
error = rpc_message_to_error(reader.read_int(),
reader.tgread_string())
await self._send_queue.put(
TLMessage(self.session, MsgsAck([msg_id])))
await self._send_queue.put(self.state.create_message(
MsgsAck([message.msg_id])
))
if not message.future.cancelled():
message.future.set_exception(error)
@ -419,7 +409,7 @@ class MTProtoSender:
else:
result = message.request.read_result(reader)
self.session.process_entities(result)
# TODO Process entities
if not message.future.cancelled():
message.future.set_result(result)
return
@ -428,19 +418,18 @@ class MTProtoSender:
__log__.info('Received response without parent request: {}'
.format(reader.tgread_object()))
async def _handle_container(self, msg_id, seq, reader):
async def _handle_container(self, message, reader):
"""
Processes the inner messages of a container with many of them:
msg_container#73f1f8dc messages:vector<%Message> = MessageContainer;
"""
__log__.debug('Handling container')
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
next_position = reader.tell_position() + inner_len
await self._process_message(inner_msg_id, seq, reader)
reader.set_position(next_position) # Ensure reading correctly
for inner_message in MessageContainer.iter_read(reader):
with BinaryReader(inner_message.body) as inner_reader:
await self._process_message(inner_message, inner_reader)
async def _handle_gzip_packed(self, msg_id, seq, reader):
async def _handle_gzip_packed(self, message, reader):
"""
Unpacks the data from a gzipped object and processes it:
@ -448,16 +437,16 @@ class MTProtoSender:
"""
__log__.debug('Handling gzipped data')
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
await self._process_message(msg_id, seq, compressed_reader)
await self._process_message(message, compressed_reader)
async def _handle_update(self, msg_id, seq, reader):
async def _handle_update(self, message, reader):
obj = reader.tgread_object()
__log__.debug('Handling update {}'.format(obj.__class__.__name__))
# TODO Further handling of the update
self.session.process_entities(obj)
# TODO Process entities
async def _handle_pong(self, msg_id, seq, reader):
async def _handle_pong(self, message, reader):
"""
Handles pong results, which don't come inside a ``rpc_result``
but are still sent through a request:
@ -470,7 +459,7 @@ class MTProtoSender:
if message:
message.future.set_result(pong)
async def _handle_bad_server_salt(self, msg_id, seq, reader):
async def _handle_bad_server_salt(self, message, reader):
"""
Corrects the currently used server salt to use the right value
before enqueuing the rejected message to be re-sent:
@ -480,11 +469,10 @@ class MTProtoSender:
"""
__log__.debug('Handling bad salt')
bad_salt = reader.tgread_object()
self.session.salt = bad_salt.new_server_salt
self.session.save()
self.state.salt = bad_salt.new_server_salt
await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id])
async def _handle_bad_notification(self, msg_id, seq, reader):
async def _handle_bad_notification(self, message, reader):
"""
Adjusts the current state to be correct based on the
received bad message notification whenever possible:
@ -497,14 +485,14 @@ class MTProtoSender:
if bad_msg.error_code in (16, 17):
# Sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset.
self.session.update_time_offset(correct_msg_id=msg_id)
self.state.update_time_offset(correct_msg_id=message.msg_id)
elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID
self.session.sequence += 64
self.state._sequence += 64
elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case
self.session.sequence -= 16
self.state._sequence -= 16
else:
msg = self._pending_messages.pop(bad_msg.bad_msg_id, None)
if msg:
@ -514,7 +502,7 @@ class MTProtoSender:
# Messages are to be re-sent once we've corrected the issue
await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id])
async def _handle_detailed_info(self, msg_id, seq, reader):
async def _handle_detailed_info(self, message, reader):
"""
Updates the current status with the received detailed information:
@ -525,7 +513,7 @@ class MTProtoSender:
__log__.debug('Handling detailed info')
self._pending_ack.add(reader.tgread_object().answer_msg_id)
async def _handle_new_detailed_info(self, msg_id, seq, reader):
async def _handle_new_detailed_info(self, message, reader):
"""
Updates the current status with the received detailed information:
@ -536,7 +524,7 @@ class MTProtoSender:
__log__.debug('Handling new detailed info')
self._pending_ack.add(reader.tgread_object().answer_msg_id)
async def _handle_new_session_created(self, msg_id, seq, reader):
async def _handle_new_session_created(self, message, reader):
"""
Updates the current status with the received session information:
@ -545,7 +533,7 @@ class MTProtoSender:
"""
# TODO https://goo.gl/LMyN7A
__log__.debug('Handling new session created')
self.session.salt = reader.tgread_object().server_salt
self.state.salt = reader.tgread_object().server_salt
def _clean_containers(self, msg_ids):
"""
@ -564,7 +552,7 @@ class MTProtoSender:
del self._pending_messages[message.msg_id]
break
async def _handle_ack(self, msg_id, seq, reader):
async def _handle_ack(self, message, reader):
"""
Handles a server acknowledge about our messages. Normally
these can be ignored except in the case of ``auth.logOut``:
@ -590,7 +578,7 @@ class MTProtoSender:
del self._pending_messages[msg_id]
msg.future.set_result(True)
async def _handle_future_salts(self, msg_id, seq, reader):
async def _handle_future_salts(self, message, reader):
"""
Handles future salt results, which don't come inside a
``rpc_result`` but are still sent through a request:
@ -602,7 +590,7 @@ class MTProtoSender:
# correct one whenever the salt in use expires.
__log__.debug('Handling future salts')
salts = reader.tgread_object()
msg = self._pending_messages.pop(msg_id, None)
msg = self._pending_messages.pop(message.msg_id, None)
if msg:
msg.future.set_result(salts)

View File

@ -0,0 +1,158 @@
import os
import struct
import time
from hashlib import sha256
from ..crypto import AES
from ..errors import SecurityError, BrokenAuthKeyError
from ..extensions import BinaryReader
from ..tl import TLMessage
class MTProtoState:
"""
`telethon.network.mtprotosender.MTProtoSender` needs to hold a state
in order to be able to encrypt and decrypt incoming/outgoing messages,
as well as generating the message IDs. Instances of this class hold
together all the required information.
It doesn't make sense to use `telethon.sessions.abstract.Session` for
the sender because the sender should *not* be concerned about storing
this information to disk, as one may create as many senders as they
desire to any other data center, or some CDN. Using the same session
for all these is not a good idea as each need their own authkey, and
the concept of "copying" sessions with the unnecessary entities or
updates state for these connections doesn't make sense.
"""
def __init__(self, auth_key):
# Session IDs can be random on every connection
self.id = struct.unpack('q', os.urandom(8))[0]
self.auth_key = auth_key
self.time_offset = 0
self.salt = 0
self._sequence = 0
self._last_msg_id = 0
def create_message(self, request, after=None):
"""
Creates a new `telethon.tl.tl_message.TLMessage` from
the given `telethon.tl.tlobject.TLObject` instance.
"""
return TLMessage(
msg_id=self._get_new_msg_id(),
seq_no=self._get_seq_no(request.content_related),
request=request,
after_id=after.msg_id if after else None
)
@staticmethod
def _calc_key(auth_key, msg_key, client):
"""
Calculate the key based on Telegram guidelines for MTProto 2,
specifying whether it's the client or not. See
https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
"""
x = 0 if client else 8
sha256a = sha256(msg_key + auth_key[x: x + 36]).digest()
sha256b = sha256(auth_key[x + 40:x + 76] + msg_key).digest()
aes_key = sha256a[:8] + sha256b[8:24] + sha256a[24:32]
aes_iv = sha256b[:8] + sha256a[8:24] + sha256b[24:32]
return aes_key, aes_iv
def pack_message(self, message):
"""
Packs the given `telethon.tl.tl_message.TLMessage` using the
current authorization key following MTProto 2.0 guidelines.
See https://core.telegram.org/mtproto/description.
"""
data = struct.pack('<qq', self.salt, self.id) + bytes(message)
padding = os.urandom(-(len(data) + 12) % 16 + 12)
# Being substr(what, offset, length); x = 0 for client
# "msg_key_large = SHA256(substr(auth_key, 88+x, 32) + pt + padding)"
msg_key_large = sha256(
self.auth_key.key[88:88 + 32] + data + padding).digest()
# "msg_key = substr (msg_key_large, 8, 16)"
msg_key = msg_key_large[8:24]
aes_key, aes_iv = self._calc_key(self.auth_key.key, msg_key, True)
key_id = struct.pack('<Q', self.auth_key.key_id)
return (key_id + msg_key +
AES.encrypt_ige(data + padding, aes_key, aes_iv))
def unpack_message(self, body):
"""
Inverse of `pack_message` for incoming server messages.
"""
if len(body) < 8:
if body == b'l\xfe\xff\xff':
raise BrokenAuthKeyError()
else:
raise BufferError("Can't decode packet ({})".format(body))
key_id = struct.unpack('<Q', body[:8])[0]
if key_id != self.auth_key.key_id:
raise SecurityError('Server replied with an invalid auth key')
msg_key = body[8:24]
aes_key, aes_iv = self._calc_key(self.auth_key.key, msg_key, False)
data = BinaryReader(AES.decrypt_ige(body[24:], aes_key, aes_iv))
data.read_long() # remote_salt
if data.read_long() != self.id:
raise SecurityError('Server replied with a wrong session ID')
remote_msg_id = data.read_long()
remote_sequence = data.read_int()
msg_len = data.read_int()
message = data.read(msg_len)
# https://core.telegram.org/mtproto/security_guidelines
# Sections "checking sha256 hash" and "message length"
our_key = sha256(self.auth_key.key[96:96 + 32] + data.get_bytes())
if msg_key != our_key.digest()[8:24]:
raise SecurityError(
"Received msg_key doesn't match with expected one")
return TLMessage(remote_msg_id, remote_sequence, body=message)
def _get_new_msg_id(self):
"""
Generates a new unique message ID based on the current
time (in ms) since epoch, applying a known time offset.
"""
now = time.time() + self.time_offset
nanoseconds = int((now - int(now)) * 1e+9)
new_msg_id = (int(now) << 32) | (nanoseconds << 2)
if self._last_msg_id >= new_msg_id:
new_msg_id = self._last_msg_id + 4
self._last_msg_id = new_msg_id
return new_msg_id
def update_time_offset(self, correct_msg_id):
"""
Updates the time offset to the correct
one given a known valid message ID.
"""
now = int(time.time())
correct = correct_msg_id >> 32
self.time_offset = correct - now
self._last_msg_id = 0
def _get_seq_no(self, content_related):
"""
Generates the next sequence number depending on whether
it should be for a content-related query or not.
"""
if content_related:
result = self._sequence * 2 + 1
self._sequence += 1
return result
else:
return self._sequence * 2

View File

@ -1,6 +1,7 @@
import struct
from . import TLObject
from .tl_message import TLMessage
class MessageContainer(TLObject):
@ -33,7 +34,8 @@ class MessageContainer(TLObject):
inner_msg_id = reader.read_long()
inner_sequence = reader.read_int()
inner_length = reader.read_int()
yield inner_msg_id, inner_sequence, inner_length
yield TLMessage(inner_msg_id, inner_sequence,
body=reader.read(inner_length))
def __str__(self):
return TLObject.pretty_format(self)

View File

@ -20,15 +20,23 @@ class TLMessage(TLObject):
sent `TLMessage`, and this result can be represented as a `Future`
that will eventually be set with either a result, error or cancelled.
"""
def __init__(self, session, request, after_id=None):
def __init__(self, msg_id, seq_no, body=None, request=None, after_id=0):
super().__init__()
del self.content_related
self.msg_id = session.get_new_msg_id()
self.seq_no = session.generate_sequence(request.content_related)
self.request = request
self.msg_id = msg_id
self.seq_no = seq_no
self.container_msg_id = None
self.future = asyncio.Future()
# TODO Perhaps it's possible to merge body and request?
# We need things like rpc_result and gzip_packed to
# be readable by the ``BinaryReader`` for such purpose.
# Used for incoming, not-decoded messages
self.body = body
# Used for outgoing, not-encoded messages
self.request = request
# After which message ID this one should run. We do this so
# InvokeAfterMsgRequest is transparent to the user and we can
# easily invoke after while confirming the original request.